summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-25 15:15:00 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-25 15:15:00 +0100
commitddbc4ca9c6a785c4b91c49c0a5b20c5b545598c6 (patch)
tree63377493aaa49533292c85bd2ce80299da46262d /vespajlib
parenta8922fadc07600065114606fbc0115c30c4cf2dc (diff)
Correct sqrt
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java1
3 files changed, 4 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index 9b248d2e528..0e96b43bd22 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -23,10 +23,10 @@ public class L2Normalize extends CompositeTensorFunction {
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument = argument.toPrimitive();
return new Join(primitiveArgument,
- new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.sqrt()),
+ new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
Reduce.Aggregator.sum,
dimension),
- ScalarFunctions.square()),
+ ScalarFunctions.sqrt()),
ScalarFunctions.divide());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index f4aa7e2fe06..9438c6c533a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -16,7 +16,7 @@ public class ScalarFunctions {
public static DoubleBinaryOperator multiply() { return new Multiplication(); }
public static DoubleBinaryOperator divide() { return new Division(); }
public static DoubleUnaryOperator square() { return new Square(); }
- public static DoubleUnaryOperator sqrt() { return new Square(); }
+ public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator exp() { return new Exponent(); }
public static class Addition implements DoubleBinaryOperator {
@@ -61,7 +61,7 @@ public class ScalarFunctions {
public static class Sqrt implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return operand * operand; }
+ public double applyAsDouble(double operand) { return Math.sqrt(operand); }
@Override
public String toString() { return "f(a)(sqrt(a))"; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index aee8cedee17..b05b8172b42 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -22,7 +22,6 @@ public class Softmax extends CompositeTensorFunction {
@Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument = argument.toPrimitive();
- // join(map(t, f(x)(exp(x))), reduce(map(t, f(x)(exp(x))), "sum", "dimension"), f(x,y)(x / y))
return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()),
Reduce.Aggregator.sum,