diff options
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java | 1 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 17 |
2 files changed, 17 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 6cf15837da1..88abbe279aa 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -103,6 +103,7 @@ public class TensorValue extends Value { case SMALLEREQUAL: return value.smallerOrEqual(argument); case EQUAL: return value.equal(argument); case NOTEQUAL: return value.notEqual(argument); + case APPROX_EQUAL: return value.approxEqual(argument); default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8fc80e3b440..10098e24e76 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -177,6 +177,7 @@ public interface Tensor { default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); } default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } + default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); } default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); } default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } @@ -261,11 +262,25 @@ public interface Tensor { Cell aCell = aIterator.next(); double aValue = aCell.getValue(); double bValue = b.get(aCell.getKey()); - if (Math.abs(aValue-bValue) > 1e-7) return false; // TODO: determine relative precision + if (!approxEquals(aValue, bValue, 1e-6)) return false; } return true; } + static boolean approxEquals(double x, double y, double tolerance) { + return Math.abs(x-y) < tolerance; + } + + static boolean approxEquals(double x, double y) { + if (y < -1.0 || y > 1.0) { + x = Math.nextAfter(x/y, 1.0); + y = 1.0; + } else { + x = Math.nextAfter(x, y); + } + return x==y; + } + // ----------------- Factories /** |