From 06bde5687b214a97c72e41ee40ac76ad837a3d7d Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 18 Jun 2020 19:25:35 +0200 Subject: Add erf (the error function) --- .../importer/onnx/GraphImporter.java | 2 +- .../importer/onnx/OnnxOperationsTestCase.java | 3 +++ searchlib/abi-spec.json | 2 ++ .../searchlib/rankingexpression/rule/Function.java | 3 +++ .../src/main/javacc/RankingExpressionParser.jj | 4 +++- vespajlib/abi-spec.json | 17 +++++++++++++++ .../yahoo/tensor/functions/ScalarFunctions.java | 25 ++++++++++++++++++++++ 7 files changed, 54 insertions(+), 2 deletions(-) diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index a6ce5e40ed3..3d36b1bfffc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -90,7 +90,7 @@ class GraphImporter { case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble())); - case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); // approximation until we have erf in backend. + case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.erf()); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "expand": return new Expand(modelName, nodeName, inputs); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 20d1891adb8..7b9868d71f5 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -67,6 +67,9 @@ public class OnnxOperationsTestCase { assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x)); assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f)); + x = evaluate("tensor(d0[7]):[-40.0, -0.5, -0.1, 0.0, 0.1, 0.5, 40.0]"); + assertEval("erf", x, evaluate("erf(x)", x)); + x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]"); assertEval("log", x, evaluate("log(x)", x)); assertEval("sqrt", x, evaluate("sqrt(x)", x)); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 6bce791914c..c22d906e2b2 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1024,6 +1024,7 @@ "public static final int SQRT", "public static final int TAN", "public static final int TANH", + "public static final int ERF", "public static final int ATAN2", "public static final int FMOD", "public static final int LDEXP", @@ -1373,6 +1374,7 @@ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function sqrt", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tan", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tanh", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function erf", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function atan2", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function fmod", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function ldexp", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index c3c1c371a68..99afb3b38d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.tensor.functions.ScalarFunctions; + import java.io.Serializable; import static java.lang.Math.*; @@ -36,6 +38,7 @@ public enum Function implements Serializable { sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, tan { public double evaluate(double x, double y) { return tan(x); } }, tanh { public double evaluate(double x, double y) { return tanh(x); } }, + erf { public double evaluate(double x, double y) { return ScalarFunctions.Erf.erf(x); } }, atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } }, fmod(2) { public double evaluate(double x, double y) { return x % y; } }, diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 8aa10bf7b34..5f27bbcbeee 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -115,6 +115,7 @@ TOKEN : | | | + | | | @@ -727,7 +728,8 @@ Function unaryFunctionName() : { } { return Function.square; } | { return Function.sqrt; } | { return Function.tan; } | - { return Function.tanh; } + { return Function.tanh; } | + { return Function.erf; } } Function binaryFunctionName() : { } diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index d9467a41f78..154b6871392 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2075,6 +2075,22 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.ScalarFunctions$Erf": { + "superClass": "java.lang.Object", + "interfaces": [ + "java.util.function.DoubleUnaryOperator" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void ()", + "public double applyAsDouble(double)", + "public java.lang.String toString()", + "public static double erf(double)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.ScalarFunctions$Exp": { "superClass": "java.lang.Object", "interfaces": [ @@ -2506,6 +2522,7 @@ "public static java.util.function.DoubleUnaryOperator square()", "public static java.util.function.DoubleUnaryOperator tan()", "public static java.util.function.DoubleUnaryOperator tanh()", + "public static java.util.function.DoubleUnaryOperator erf()", "public static java.util.function.DoubleUnaryOperator elu()", "public static java.util.function.DoubleUnaryOperator elu(double)", "public static java.util.function.DoubleUnaryOperator leakyrelu()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index d9204e24d68..c19b07cf96f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -50,6 +50,7 @@ public class ScalarFunctions { public static DoubleUnaryOperator square() { return new Square(); } public static DoubleUnaryOperator tan() { return new Tan(); } public static DoubleUnaryOperator tanh() { return new Tanh(); } + public static DoubleUnaryOperator erf() { return new Erf(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); } @@ -330,6 +331,30 @@ public class ScalarFunctions { public String toString() { return "f(a)(tanh(a))"; } } + public static class Erf implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return erf(operand); } + @Override + public String toString() { return "f(a)(erf(a))"; } + + // Use Horner's method + // From https://introcs.cs.princeton.edu/java/21function/ErrorFunction.java.html + public static double erf(double v) { + double t = 1.0 / (1.0 + 0.5 * Math.abs(v)); + double ans = 1 - t * Math.exp(-v*v - 1.26551223 + + t * ( 1.00002368 + + t * ( 0.37409196 + + t * ( 0.09678418 + + t * (-0.18628806 + + t * ( 0.27886807 + + t * (-1.13520398 + + t * ( 1.48851587 + + t * (-0.82215223 + + t * ( 0.17087277)))))))))); + if (v >= 0) return ans; + else return -ans; + } + } // Variable-length operators ----------------------------------------------------------------------------- -- cgit v1.2.3