summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-01-22 13:25:32 +0100
committerLester Solbakken <lesters@oath.com>2018-01-22 13:35:58 +0100
commitce34b6dd37afdce666e3b0b058c524ef9ebb5ef6 (patch)
tree3076e6377c4a938cffdd24c2b460e11f5833f47c /vespajlib
parent4148debe89932119346b102a81164921af007d00 (diff)
Add batch normalization test case
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java58
1 files changed, 55 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 4d1abf3978f..da473941463 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -24,12 +24,18 @@ public class ScalarFunctions {
public static DoubleBinaryOperator add() { return new Add(); }
public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
+ public static DoubleBinaryOperator max() { return new Max(); }
+ public static DoubleBinaryOperator min() { return new Min(); }
public static DoubleBinaryOperator multiply() { return new Multiply(); }
+ public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
+ public static DoubleBinaryOperator subtract() { return new Subtract(); }
public static DoubleUnaryOperator acos() { return new Acos(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator relu() { return new Relu(); }
+ public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
+ public static DoubleUnaryOperator selu() { return new Selu(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
@@ -54,11 +60,18 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a==b)"; }
}
- public static class Exp implements DoubleUnaryOperator {
+ public static class Max implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double operand) { return Math.exp(operand); }
+ public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@Override
- public String toString() { return "f(a)(exp(a))"; }
+ public String toString() { return "f(a,b)(max(a, b)"; }
+ }
+
+ public static class Min implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return Math.min(left, right); }
+ @Override
+ public String toString() { return "f(a,b)(min(a, b)"; }
}
public static class Multiply implements DoubleBinaryOperator {
@@ -75,6 +88,21 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a / b)"; }
}
+ public static class SquaredDifference implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return (left - right) * (left - right); }
+ @Override
+ public String toString() { return "f(a,b)((a-b) * (a-b))"; }
+ }
+
+ public static class Subtract implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left - right; }
+ @Override
+ public String toString() { return "f(a,b)(a - b)"; }
+ }
+
+
// Unary operators ------------------------------------------------------------------------------
public static class Acos implements DoubleUnaryOperator {
@@ -91,6 +119,13 @@ public class ScalarFunctions {
public String toString() { return "f(a)(if(a < 0, exp(a)-1, a))"; }
}
+ public static class Exp implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
+ @Override
+ public String toString() { return "f(a)(exp(a))"; }
+ }
+
public static class Relu implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return Math.max(operand, 0); }
@@ -98,6 +133,23 @@ public class ScalarFunctions {
public String toString() { return "f(a)(max(0, a))"; }
}
+ public static class Selu implements DoubleUnaryOperator {
+ // See https://arxiv.org/abs/1706.02515
+ private static final double scale = 1.0507009873554804934193349852946;
+ private static final double alpha = 1.6732632423543772848170429916717;
+ @Override
+ public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); }
+ @Override
+ public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1))", scale, alpha); }
+ }
+
+ public static class Rsqrt implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
+ @Override
+ public String toString() { return "f(a)(1.0 / sqrt(a))"; }
+ }
+
public static class Sigmoid implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return 1.0 / (1.0 + Math.exp(-operand)); }