diff options
author | Lester Solbakken <lesters@oath.com> | 2019-11-22 11:40:14 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-11-22 11:40:14 +0100 |
commit | 296340ac996edac09a4f53997ae1a8a803d302c1 (patch) | |
tree | f7abfd7b4c30529bbdbc03334f523fcbebdd27e2 /vespajlib | |
parent | 69e4f6bf072d8ebfb12761c450f2bdacf86e226c (diff) |
Add additional ONNX operations
Diffstat (limited to 'vespajlib')
3 files changed, 25 insertions, 1 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 47b066b15a6..4ec1a5a234b 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2063,6 +2063,21 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.ScalarFunctions$LeakyRelu": { + "superClass": "java.lang.Object", + "interfaces": [ + "java.util.function.DoubleUnaryOperator" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>()", + "public double applyAsDouble(double)", + "public java.lang.String toString()" + ], + "fields": [] + }, "com.yahoo.tensor.functions.ScalarFunctions$Less": { "superClass": "java.lang.Object", "interfaces": [ @@ -2429,6 +2444,7 @@ "public static java.util.function.DoubleUnaryOperator relu()", "public static java.util.function.DoubleUnaryOperator rsqrt()", "public static java.util.function.DoubleUnaryOperator selu()", + "public static java.util.function.DoubleUnaryOperator leakyrelu()", "public static java.util.function.DoubleUnaryOperator sin()", "public static java.util.function.DoubleUnaryOperator sigmoid()", "public static java.util.function.DoubleUnaryOperator sqrt()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index eabeb8905f7..e8e329cd75c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -47,6 +47,7 @@ public class ScalarFunctions { public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } public static DoubleUnaryOperator selu() { return new Selu(); } + public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } @@ -248,6 +249,13 @@ public class ScalarFunctions { public String toString() { return "f(a)(" + scale + " * if(a >= 0, a, " + alpha + " * (exp(a) - 1)))"; } } + public static class LeakyRelu implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.max(0.01 * operand, operand); } + @Override + public String toString() { return "f(a)(max(0.01*a, a))"; } + } + public static class Sin implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return Math.sin(operand); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 1086d91da31..810651bbcfb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -39,7 +39,7 @@ public abstract class TensorFunction { /** * Evaluates this tensor. * - * @param context a context which must be passed to all nexted functions when evaluating + * @param context a context which must be passed to all nested functions when evaluating */ public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); |