diff options
author | Lester Solbakken <lesters@oath.com> | 2019-12-05 09:09:43 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-12-05 09:09:43 +0100 |
commit | cd4e23a47c1993d5c9dbe17dfb23bdce3e037844 (patch) | |
tree | 7bf90e97261c246f4a3fe78b9401c50357fdac7f /vespajlib | |
parent | 7cd2264c56253a1e9745cb063b8868a5589c6b51 (diff) |
Add unit tests for ONNX operators (and fix some of the implementations)
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/abi-spec.json | 14 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java | 45 |
2 files changed, 45 insertions, 14 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 59474021de2..b6cf3547df2 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1985,6 +1985,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2075,6 +2076,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2271,6 +2273,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double, double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2437,22 +2440,25 @@ "public static java.util.function.DoubleUnaryOperator atan()", "public static java.util.function.DoubleUnaryOperator ceil()", "public static java.util.function.DoubleUnaryOperator cos()", - "public static java.util.function.DoubleUnaryOperator elu()", "public static java.util.function.DoubleUnaryOperator exp()", "public static java.util.function.DoubleUnaryOperator floor()", "public static java.util.function.DoubleUnaryOperator log()", "public static java.util.function.DoubleUnaryOperator neg()", "public static java.util.function.DoubleUnaryOperator reciprocal()", - "public static java.util.function.DoubleUnaryOperator relu()", "public static java.util.function.DoubleUnaryOperator rsqrt()", - "public static java.util.function.DoubleUnaryOperator selu()", - "public static java.util.function.DoubleUnaryOperator leakyrelu()", "public static java.util.function.DoubleUnaryOperator sin()", "public static java.util.function.DoubleUnaryOperator sigmoid()", "public static java.util.function.DoubleUnaryOperator sqrt()", "public static java.util.function.DoubleUnaryOperator square()", "public static java.util.function.DoubleUnaryOperator tan()", "public static java.util.function.DoubleUnaryOperator tanh()", + "public static java.util.function.DoubleUnaryOperator elu()", + "public static java.util.function.DoubleUnaryOperator elu(double)", + "public static java.util.function.DoubleUnaryOperator leakyrelu()", + "public static java.util.function.DoubleUnaryOperator leakyrelu(double)", + "public static java.util.function.DoubleUnaryOperator relu()", + "public static java.util.function.DoubleUnaryOperator selu()", + "public static java.util.function.DoubleUnaryOperator selu(double, double)", "public static java.util.function.Function random()", "public static java.util.function.Function equal(java.util.List)", "public static java.util.function.Function sum(java.util.List)" diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index e8e329cd75c..d9204e24d68 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -38,16 +38,12 @@ public class ScalarFunctions { public static DoubleUnaryOperator atan() { return new Atan(); } public static DoubleUnaryOperator ceil() { return new Ceil(); } public static DoubleUnaryOperator cos() { return new Cos(); } - public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator floor() { return new Floor(); } public static DoubleUnaryOperator log() { return new Log(); } public static DoubleUnaryOperator neg() { return new Neg(); } public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } - public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } - public static DoubleUnaryOperator selu() { return new Selu(); } - public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } @@ -55,6 +51,14 @@ public class ScalarFunctions { public static DoubleUnaryOperator tan() { return new Tan(); } public static DoubleUnaryOperator tanh() { return new Tanh(); } + public static DoubleUnaryOperator elu() { return new Elu(); } + public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); } + public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } + public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); } + public static DoubleUnaryOperator relu() { return new Relu(); } + public static DoubleUnaryOperator selu() { return new Selu(); } + public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); } + public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } @@ -191,10 +195,17 @@ public class ScalarFunctions { } public static class Elu implements DoubleUnaryOperator { + private final double alpha; + public Elu() { + this(1.0); + } + public Elu(double alpha) { + this.alpha = alpha; + } @Override - public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; } + public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; } @Override - public String toString() { return "f(a)(if(a < 0, exp(a)-1, a))"; } + public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; } } public static class Exp implements DoubleUnaryOperator { @@ -241,8 +252,15 @@ public class ScalarFunctions { public static class Selu implements DoubleUnaryOperator { // See https://arxiv.org/abs/1706.02515 - private static final double scale = 1.0507009873554804934193349852946; - private static final double alpha = 1.6732632423543772848170429916717; + private final double scale; // 1.0507009873554804934193349852946; + private final double alpha; // 1.6732632423543772848170429916717; + public Selu() { + this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717); + } + public Selu(double scale, double alpha) { + this.scale = scale; + this.alpha = alpha; + } @Override public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); } @Override @@ -250,10 +268,17 @@ public class ScalarFunctions { } public static class LeakyRelu implements DoubleUnaryOperator { + private final double alpha; + public LeakyRelu() { + this(0.01); + } + public LeakyRelu(double alpha) { + this.alpha = alpha; + } @Override - public double applyAsDouble(double operand) { return Math.max(0.01 * operand, operand); } + public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); } @Override - public String toString() { return "f(a)(max(0.01*a, a))"; } + public String toString() { return "f(a)(max(" + alpha + " * a, a))"; } } public static class Sin implements DoubleUnaryOperator { |