aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java17
2 files changed, 17 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 6cf15837da1..88abbe279aa 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -103,6 +103,7 @@ public class TensorValue extends Value {
case SMALLEREQUAL: return value.smallerOrEqual(argument);
case EQUAL: return value.equal(argument);
case NOTEQUAL: return value.notEqual(argument);
+ case APPROX_EQUAL: return value.approxEqual(argument);
default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8fc80e3b440..10098e24e76 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -177,6 +177,7 @@ public interface Tensor {
default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); }
default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); }
+ default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); }
default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); }
default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
@@ -261,11 +262,25 @@ public interface Tensor {
Cell aCell = aIterator.next();
double aValue = aCell.getValue();
double bValue = b.get(aCell.getKey());
- if (Math.abs(aValue-bValue) > 1e-7) return false; // TODO: determine relative precision
+ if (!approxEquals(aValue, bValue, 1e-6)) return false;
}
return true;
}
+ static boolean approxEquals(double x, double y, double tolerance) {
+ return Math.abs(x-y) < tolerance;
+ }
+
+ static boolean approxEquals(double x, double y) {
+ if (y < -1.0 || y > 1.0) {
+ x = Math.nextAfter(x/y, 1.0);
+ y = 1.0;
+ } else {
+ x = Math.nextAfter(x, y);
+ }
+ return x==y;
+ }
+
// ----------------- Factories
/**