diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-18 19:25:35 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-18 19:25:35 +0200 |
commit | 06bde5687b214a97c72e41ee40ac76ad837a3d7d (patch) | |
tree | 2d7ef9a8a2017d4f986744dfde82f8b671f4227b /model-integration | |
parent | 5688a50eb92fc4459e51dccca45858aecca8264a (diff) |
Add erf (the error function)
Diffstat (limited to 'model-integration')
2 files changed, 4 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index a6ce5e40ed3..3d36b1bfffc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -90,7 +90,7 @@ class GraphImporter { case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble())); - case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); // approximation until we have erf in backend. + case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.erf()); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "expand": return new Expand(modelName, nodeName, inputs); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 20d1891adb8..7b9868d71f5 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -67,6 +67,9 @@ public class OnnxOperationsTestCase { assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x)); assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f)); + x = evaluate("tensor(d0[7]):[-40.0, -0.5, -0.1, 0.0, 0.1, 0.5, 40.0]"); + assertEval("erf", x, evaluate("erf(x)", x)); + x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]"); assertEval("log", x, evaluate("log(x)", x)); assertEval("sqrt", x, evaluate("sqrt(x)", x)); |