summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java3
-rw-r--r--searchlib/abi-spec.json2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java3
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj4
-rw-r--r--vespajlib/abi-spec.json17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java25
7 files changed, 54 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index a6ce5e40ed3..3d36b1bfffc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -90,7 +90,7 @@ class GraphImporter {
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble()));
- case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); // approximation until we have erf in backend.
+ case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.erf());
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "expand": return new Expand(modelName, nodeName, inputs);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 20d1891adb8..7b9868d71f5 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -67,6 +67,9 @@ public class OnnxOperationsTestCase {
assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x));
assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f));
+ x = evaluate("tensor(d0[7]):[-40.0, -0.5, -0.1, 0.0, 0.1, 0.5, 40.0]");
+ assertEval("erf", x, evaluate("erf(x)", x));
+
x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]");
assertEval("log", x, evaluate("log(x)", x));
assertEval("sqrt", x, evaluate("sqrt(x)", x));
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 6bce791914c..c22d906e2b2 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1024,6 +1024,7 @@
"public static final int SQRT",
"public static final int TAN",
"public static final int TANH",
+ "public static final int ERF",
"public static final int ATAN2",
"public static final int FMOD",
"public static final int LDEXP",
@@ -1373,6 +1374,7 @@
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function sqrt",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tan",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tanh",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function erf",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function atan2",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function fmod",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function ldexp",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index c3c1c371a68..99afb3b38d0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.tensor.functions.ScalarFunctions;
+
import java.io.Serializable;
import static java.lang.Math.*;
@@ -36,6 +38,7 @@ public enum Function implements Serializable {
sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
tan { public double evaluate(double x, double y) { return tan(x); } },
tanh { public double evaluate(double x, double y) { return tanh(x); } },
+ erf { public double evaluate(double x, double y) { return ScalarFunctions.Erf.erf(x); } },
atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } },
fmod(2) { public double evaluate(double x, double y) { return x % y; } },
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 8aa10bf7b34..5f27bbcbeee 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -115,6 +115,7 @@ TOKEN :
<SQRT: "sqrt"> |
<TAN: "tan"> |
<TANH: "tanh"> |
+ <ERF: "erf"> |
<ATAN2: "atan2"> |
<FMOD: "fmod"> |
@@ -727,7 +728,8 @@ Function unaryFunctionName() : { }
<SQUARE> { return Function.square; } |
<SQRT> { return Function.sqrt; } |
<TAN> { return Function.tan; } |
- <TANH> { return Function.tanh; }
+ <TANH> { return Function.tanh; } |
+ <ERF> { return Function.erf; }
}
Function binaryFunctionName() : { }
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index d9467a41f78..154b6871392 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2075,6 +2075,22 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.ScalarFunctions$Erf": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "java.util.function.DoubleUnaryOperator"
+ ],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public double applyAsDouble(double)",
+ "public java.lang.String toString()",
+ "public static double erf(double)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.ScalarFunctions$Exp": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -2506,6 +2522,7 @@
"public static java.util.function.DoubleUnaryOperator square()",
"public static java.util.function.DoubleUnaryOperator tan()",
"public static java.util.function.DoubleUnaryOperator tanh()",
+ "public static java.util.function.DoubleUnaryOperator erf()",
"public static java.util.function.DoubleUnaryOperator elu()",
"public static java.util.function.DoubleUnaryOperator elu(double)",
"public static java.util.function.DoubleUnaryOperator leakyrelu()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index d9204e24d68..c19b07cf96f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -50,6 +50,7 @@ public class ScalarFunctions {
public static DoubleUnaryOperator square() { return new Square(); }
public static DoubleUnaryOperator tan() { return new Tan(); }
public static DoubleUnaryOperator tanh() { return new Tanh(); }
+ public static DoubleUnaryOperator erf() { return new Erf(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); }
@@ -330,6 +331,30 @@ public class ScalarFunctions {
public String toString() { return "f(a)(tanh(a))"; }
}
+ public static class Erf implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return erf(operand); }
+ @Override
+ public String toString() { return "f(a)(erf(a))"; }
+
+ // Use Horner's method
+ // From https://introcs.cs.princeton.edu/java/21function/ErrorFunction.java.html
+ public static double erf(double v) {
+ double t = 1.0 / (1.0 + 0.5 * Math.abs(v));
+ double ans = 1 - t * Math.exp(-v*v - 1.26551223 +
+ t * ( 1.00002368 +
+ t * ( 0.37409196 +
+ t * ( 0.09678418 +
+ t * (-0.18628806 +
+ t * ( 0.27886807 +
+ t * (-1.13520398 +
+ t * ( 1.48851587 +
+ t * (-0.82215223 +
+ t * ( 0.17087277))))))))));
+ if (v >= 0) return ans;
+ else return -ans;
+ }
+ }
// Variable-length operators -----------------------------------------------------------------------------