summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 11:40:14 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 11:40:14 +0100
commit296340ac996edac09a4f53997ae1a8a803d302c1 (patch)
treef7abfd7b4c30529bbdbc03334f523fcbebdd27e2 /vespajlib
parent69e4f6bf072d8ebfb12761c450f2bdacf86e226c (diff)
Add additional ONNX operations
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java2
3 files changed, 25 insertions, 1 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 47b066b15a6..4ec1a5a234b 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2063,6 +2063,21 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.ScalarFunctions$LeakyRelu": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "java.util.function.DoubleUnaryOperator"
+ ],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public double applyAsDouble(double)",
+ "public java.lang.String toString()"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.ScalarFunctions$Less": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -2429,6 +2444,7 @@
"public static java.util.function.DoubleUnaryOperator relu()",
"public static java.util.function.DoubleUnaryOperator rsqrt()",
"public static java.util.function.DoubleUnaryOperator selu()",
+ "public static java.util.function.DoubleUnaryOperator leakyrelu()",
"public static java.util.function.DoubleUnaryOperator sin()",
"public static java.util.function.DoubleUnaryOperator sigmoid()",
"public static java.util.function.DoubleUnaryOperator sqrt()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index eabeb8905f7..e8e329cd75c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -47,6 +47,7 @@ public class ScalarFunctions {
public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
public static DoubleUnaryOperator selu() { return new Selu(); }
+ public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); }
public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
@@ -248,6 +249,13 @@ public class ScalarFunctions {
public String toString() { return "f(a)(" + scale + " * if(a >= 0, a, " + alpha + " * (exp(a) - 1)))"; }
}
+ public static class LeakyRelu implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.max(0.01 * operand, operand); }
+ @Override
+ public String toString() { return "f(a)(max(0.01*a, a))"; }
+ }
+
public static class Sin implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return Math.sin(operand); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 1086d91da31..810651bbcfb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -39,7 +39,7 @@ public abstract class TensorFunction {
/**
* Evaluates this tensor.
*
- * @param context a context which must be passed to all nexted functions when evaluating
+ * @param context a context which must be passed to all nested functions when evaluating
*/
public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context);