summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-12-05 09:09:43 +0100
committerLester Solbakken <lesters@oath.com>2019-12-05 09:09:43 +0100
commitcd4e23a47c1993d5c9dbe17dfb23bdce3e037844 (patch)
tree7bf90e97261c246f4a3fe78b9401c50357fdac7f /vespajlib
parent7cd2264c56253a1e9745cb063b8868a5589c6b51 (diff)
Add unit tests for ONNX operators (and fix some of the implementations)
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java45
2 files changed, 45 insertions, 14 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 59474021de2..b6cf3547df2 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1985,6 +1985,7 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(double)",
"public double applyAsDouble(double)",
"public java.lang.String toString()"
],
@@ -2075,6 +2076,7 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(double)",
"public double applyAsDouble(double)",
"public java.lang.String toString()"
],
@@ -2271,6 +2273,7 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(double, double)",
"public double applyAsDouble(double)",
"public java.lang.String toString()"
],
@@ -2437,22 +2440,25 @@
"public static java.util.function.DoubleUnaryOperator atan()",
"public static java.util.function.DoubleUnaryOperator ceil()",
"public static java.util.function.DoubleUnaryOperator cos()",
- "public static java.util.function.DoubleUnaryOperator elu()",
"public static java.util.function.DoubleUnaryOperator exp()",
"public static java.util.function.DoubleUnaryOperator floor()",
"public static java.util.function.DoubleUnaryOperator log()",
"public static java.util.function.DoubleUnaryOperator neg()",
"public static java.util.function.DoubleUnaryOperator reciprocal()",
- "public static java.util.function.DoubleUnaryOperator relu()",
"public static java.util.function.DoubleUnaryOperator rsqrt()",
- "public static java.util.function.DoubleUnaryOperator selu()",
- "public static java.util.function.DoubleUnaryOperator leakyrelu()",
"public static java.util.function.DoubleUnaryOperator sin()",
"public static java.util.function.DoubleUnaryOperator sigmoid()",
"public static java.util.function.DoubleUnaryOperator sqrt()",
"public static java.util.function.DoubleUnaryOperator square()",
"public static java.util.function.DoubleUnaryOperator tan()",
"public static java.util.function.DoubleUnaryOperator tanh()",
+ "public static java.util.function.DoubleUnaryOperator elu()",
+ "public static java.util.function.DoubleUnaryOperator elu(double)",
+ "public static java.util.function.DoubleUnaryOperator leakyrelu()",
+ "public static java.util.function.DoubleUnaryOperator leakyrelu(double)",
+ "public static java.util.function.DoubleUnaryOperator relu()",
+ "public static java.util.function.DoubleUnaryOperator selu()",
+ "public static java.util.function.DoubleUnaryOperator selu(double, double)",
"public static java.util.function.Function random()",
"public static java.util.function.Function equal(java.util.List)",
"public static java.util.function.Function sum(java.util.List)"
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index e8e329cd75c..d9204e24d68 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -38,16 +38,12 @@ public class ScalarFunctions {
public static DoubleUnaryOperator atan() { return new Atan(); }
public static DoubleUnaryOperator ceil() { return new Ceil(); }
public static DoubleUnaryOperator cos() { return new Cos(); }
- public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator floor() { return new Floor(); }
public static DoubleUnaryOperator log() { return new Log(); }
public static DoubleUnaryOperator neg() { return new Neg(); }
public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
- public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
- public static DoubleUnaryOperator selu() { return new Selu(); }
- public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); }
public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
@@ -55,6 +51,14 @@ public class ScalarFunctions {
public static DoubleUnaryOperator tan() { return new Tan(); }
public static DoubleUnaryOperator tanh() { return new Tanh(); }
+ public static DoubleUnaryOperator elu() { return new Elu(); }
+ public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); }
+ public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); }
+ public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); }
+ public static DoubleUnaryOperator relu() { return new Relu(); }
+ public static DoubleUnaryOperator selu() { return new Selu(); }
+ public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); }
+
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
@@ -191,10 +195,17 @@ public class ScalarFunctions {
}
public static class Elu implements DoubleUnaryOperator {
+ private final double alpha;
+ public Elu() {
+ this(1.0);
+ }
+ public Elu(double alpha) {
+ this.alpha = alpha;
+ }
@Override
- public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; }
+ public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; }
@Override
- public String toString() { return "f(a)(if(a < 0, exp(a)-1, a))"; }
+ public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; }
}
public static class Exp implements DoubleUnaryOperator {
@@ -241,8 +252,15 @@ public class ScalarFunctions {
public static class Selu implements DoubleUnaryOperator {
// See https://arxiv.org/abs/1706.02515
- private static final double scale = 1.0507009873554804934193349852946;
- private static final double alpha = 1.6732632423543772848170429916717;
+ private final double scale; // 1.0507009873554804934193349852946;
+ private final double alpha; // 1.6732632423543772848170429916717;
+ public Selu() {
+ this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717);
+ }
+ public Selu(double scale, double alpha) {
+ this.scale = scale;
+ this.alpha = alpha;
+ }
@Override
public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); }
@Override
@@ -250,10 +268,17 @@ public class ScalarFunctions {
}
public static class LeakyRelu implements DoubleUnaryOperator {
+ private final double alpha;
+ public LeakyRelu() {
+ this(0.01);
+ }
+ public LeakyRelu(double alpha) {
+ this.alpha = alpha;
+ }
@Override
- public double applyAsDouble(double operand) { return Math.max(0.01 * operand, operand); }
+ public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); }
@Override
- public String toString() { return "f(a)(max(0.01*a, a))"; }
+ public String toString() { return "f(a)(max(" + alpha + " * a, a))"; }
}
public static class Sin implements DoubleUnaryOperator {