diff options
author | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-14 11:11:57 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-14 11:11:57 +0100 |
commit | fa7d6c2ec6180d69568a75a7293bc97294a5c811 (patch) | |
tree | 4d0fc3b998d455e9c2a4f4d6f766104eada1c56b /vespajlib | |
parent | fea33aaa04ad925e2af4387ec63f59f6d9531c3d (diff) |
Fix 'Tensors cannot be compared with ~='
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8fc80e3b440..10098e24e76 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -177,6 +177,7 @@ public interface Tensor { default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); } default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } + default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); } default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); } default Tensor avg(List<String> dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } @@ -261,11 +262,25 @@ public interface Tensor { Cell aCell = aIterator.next(); double aValue = aCell.getValue(); double bValue = b.get(aCell.getKey()); - if (Math.abs(aValue-bValue) > 1e-7) return false; // TODO: determine relative precision + if (!approxEquals(aValue, bValue, 1e-6)) return false; } return true; } + static boolean approxEquals(double x, double y, double tolerance) { + return Math.abs(x-y) < tolerance; + } + + static boolean approxEquals(double x, double y) { + if (y < -1.0 || y > 1.0) { + x = Math.nextAfter(x/y, 1.0); + y = 1.0; + } else { + x = Math.nextAfter(x, y); + } + return x==y; + } + // ----------------- Factories /** |