diff options
author | Lester Solbakken <lesters@oath.com> | 2018-01-22 13:25:32 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-01-22 13:35:58 +0100 |
commit | ce34b6dd37afdce666e3b0b058c524ef9ebb5ef6 (patch) | |
tree | 3076e6377c4a938cffdd24c2b460e11f5833f47c /vespajlib | |
parent | 4148debe89932119346b102a81164921af007d00 (diff) |
Add batch normalization test case
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java | 58 |
1 files changed, 55 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 4d1abf3978f..da473941463 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -24,12 +24,18 @@ public class ScalarFunctions { public static DoubleBinaryOperator add() { return new Add(); } public static DoubleBinaryOperator divide() { return new Divide(); } public static DoubleBinaryOperator equal() { return new Equal(); } + public static DoubleBinaryOperator max() { return new Max(); } + public static DoubleBinaryOperator min() { return new Min(); } public static DoubleBinaryOperator multiply() { return new Multiply(); } + public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); } + public static DoubleBinaryOperator subtract() { return new Subtract(); } public static DoubleUnaryOperator acos() { return new Acos(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator relu() { return new Relu(); } + public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } + public static DoubleUnaryOperator selu() { return new Selu(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } @@ -54,11 +60,18 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a==b)"; } } - public static class Exp implements DoubleUnaryOperator { + public static class Max implements DoubleBinaryOperator { @Override - public double applyAsDouble(double operand) { return Math.exp(operand); } + public double applyAsDouble(double left, double right) { return Math.max(left, right); } @Override - public String toString() { return "f(a)(exp(a))"; } + public String toString() { return "f(a,b)(max(a, b)"; } + } + + public static class Min implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return Math.min(left, right); } + @Override + public String toString() { return "f(a,b)(min(a, b)"; } } public static class Multiply implements DoubleBinaryOperator { @@ -75,6 +88,21 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a / b)"; } } + public static class SquaredDifference implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return (left - right) * (left - right); } + @Override + public String toString() { return "f(a,b)((a-b) * (a-b))"; } + } + + public static class Subtract implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left - right; } + @Override + public String toString() { return "f(a,b)(a - b)"; } + } + + // Unary operators ------------------------------------------------------------------------------ public static class Acos implements DoubleUnaryOperator { @@ -91,6 +119,13 @@ public class ScalarFunctions { public String toString() { return "f(a)(if(a < 0, exp(a)-1, a))"; } } + public static class Exp implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.exp(operand); } + @Override + public String toString() { return "f(a)(exp(a))"; } + } + public static class Relu implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return Math.max(operand, 0); } @@ -98,6 +133,23 @@ public class ScalarFunctions { public String toString() { return "f(a)(max(0, a))"; } } + public static class Selu implements DoubleUnaryOperator { + // See https://arxiv.org/abs/1706.02515 + private static final double scale = 1.0507009873554804934193349852946; + private static final double alpha = 1.6732632423543772848170429916717; + @Override + public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); } + @Override + public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1))", scale, alpha); } + } + + public static class Rsqrt implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); } + @Override + public String toString() { return "f(a)(1.0 / sqrt(a))"; } + } + public static class Sigmoid implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return 1.0 / (1.0 + Math.exp(-operand)); } |