aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-18 19:25:35 +0200
committerLester Solbakken <lesters@oath.com>2020-06-18 19:25:35 +0200
commit06bde5687b214a97c72e41ee40ac76ad837a3d7d (patch)
tree2d7ef9a8a2017d4f986744dfde82f8b671f4227b /vespajlib/src
parent5688a50eb92fc4459e51dccca45858aecca8264a (diff)
Add erf (the error function)
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java25
1 files changed, 25 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index d9204e24d68..c19b07cf96f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -50,6 +50,7 @@ public class ScalarFunctions {
public static DoubleUnaryOperator square() { return new Square(); }
public static DoubleUnaryOperator tan() { return new Tan(); }
public static DoubleUnaryOperator tanh() { return new Tanh(); }
+ public static DoubleUnaryOperator erf() { return new Erf(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); }
@@ -330,6 +331,30 @@ public class ScalarFunctions {
public String toString() { return "f(a)(tanh(a))"; }
}
+ public static class Erf implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return erf(operand); }
+ @Override
+ public String toString() { return "f(a)(erf(a))"; }
+
+ // Use Horner's method
+ // From https://introcs.cs.princeton.edu/java/21function/ErrorFunction.java.html
+ public static double erf(double v) {
+ double t = 1.0 / (1.0 + 0.5 * Math.abs(v));
+ double ans = 1 - t * Math.exp(-v*v - 1.26551223 +
+ t * ( 1.00002368 +
+ t * ( 0.37409196 +
+ t * ( 0.09678418 +
+ t * (-0.18628806 +
+ t * ( 0.27886807 +
+ t * (-1.13520398 +
+ t * ( 1.48851587 +
+ t * (-0.82215223 +
+ t * ( 0.17087277))))))))));
+ if (v >= 0) return ans;
+ else return -ans;
+ }
+ }
// Variable-length operators -----------------------------------------------------------------------------