aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@yahoo-inc.com>2017-11-14 11:11:57 +0100
committerLester Solbakken <lesters@yahoo-inc.com>2017-11-14 11:11:57 +0100
commitfa7d6c2ec6180d69568a75a7293bc97294a5c811 (patch)
tree4d0fc3b998d455e9c2a4f4d6f766104eada1c56b /vespajlib
parentfea33aaa04ad925e2af4387ec63f59f6d9531c3d (diff)
Fix 'Tensors cannot be compared with ~='
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java17
1 files changed, 16 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8fc80e3b440..10098e24e76 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -177,6 +177,7 @@ public interface Tensor {
default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); }
default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); }
default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); }
+ default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); }
default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); }
default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); }
@@ -261,11 +262,25 @@ public interface Tensor {
Cell aCell = aIterator.next();
double aValue = aCell.getValue();
double bValue = b.get(aCell.getKey());
- if (Math.abs(aValue-bValue) > 1e-7) return false; // TODO: determine relative precision
+ if (!approxEquals(aValue, bValue, 1e-6)) return false;
}
return true;
}
+ static boolean approxEquals(double x, double y, double tolerance) {
+ return Math.abs(x-y) < tolerance;
+ }
+
+ static boolean approxEquals(double x, double y) {
+ if (y < -1.0 || y > 1.0) {
+ x = Math.nextAfter(x/y, 1.0);
+ y = 1.0;
+ } else {
+ x = Math.nextAfter(x, y);
+ }
+ return x==y;
+ }
+
// ----------------- Factories
/**