diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-25 15:15:00 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-25 15:15:00 +0100 |
commit | ddbc4ca9c6a785c4b91c49c0a5b20c5b545598c6 (patch) | |
tree | 63377493aaa49533292c85bd2ce80299da46262d /vespajlib | |
parent | a8922fadc07600065114606fbc0115c30c4cf2dc (diff) |
Correct sqrt
Diffstat (limited to 'vespajlib')
3 files changed, 4 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 9b248d2e528..0e96b43bd22 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -23,10 +23,10 @@ public class L2Normalize extends CompositeTensorFunction { public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); return new Join(primitiveArgument, - new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.sqrt()), + new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()), Reduce.Aggregator.sum, dimension), - ScalarFunctions.square()), + ScalarFunctions.sqrt()), ScalarFunctions.divide()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index f4aa7e2fe06..9438c6c533a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -16,7 +16,7 @@ public class ScalarFunctions { public static DoubleBinaryOperator multiply() { return new Multiplication(); } public static DoubleBinaryOperator divide() { return new Division(); } public static DoubleUnaryOperator square() { return new Square(); } - public static DoubleUnaryOperator sqrt() { return new Square(); } + public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator exp() { return new Exponent(); } public static class Addition implements DoubleBinaryOperator { @@ -61,7 +61,7 @@ public class ScalarFunctions { public static class Sqrt implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return operand * operand; } + public double applyAsDouble(double operand) { return Math.sqrt(operand); } @Override public String toString() { return "f(a)(sqrt(a))"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index aee8cedee17..b05b8172b42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -22,7 +22,6 @@ public class Softmax extends CompositeTensorFunction { @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); - // join(map(t, f(x)(exp(x))), reduce(map(t, f(x)(exp(x))), "sum", "dimension"), f(x,y)(x / y)) return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()), Reduce.Aggregator.sum, |