aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-20 11:31:34 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-20 11:42:22 +0000
commit328fc7f2156d65f193499ff27a3ed85dce04b1f2 (patch)
treecc0868535f6eb5f1682942665310b6c6707e1962 /vespajlib
parent901237f0a48223b8971c56c95e0d7b41e3974d33 (diff)
new implementation of erf()
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java86
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java66
2 files changed, 137 insertions, 15 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index c19b07cf96f..3ee9e67cdd6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -3,7 +3,9 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
+import java.util.Comparator;
import java.util.List;
+import java.util.PriorityQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
@@ -332,27 +334,81 @@ public class ScalarFunctions {
}
public static class Erf implements DoubleUnaryOperator {
+ static final Comparator<Double> byAbs = (x,y) -> Double.compare(Math.abs(x), Math.abs(y));
+
+ static double kummer(double a, double b, double z) {
+ PriorityQueue<Double> terms = new PriorityQueue<>(byAbs);
+ double term = 1.0;
+ long n = 0;
+ while (Math.abs(term) > Double.MIN_NORMAL) {
+ terms.add(term);
+ term *= (a+n);
+ term /= (b+n);
+ ++n;
+ term *= z;
+ term /= n;
+ }
+ double sum = terms.remove();
+ while (! terms.isEmpty()) {
+ sum += terms.remove();
+ terms.add(sum);
+ sum = terms.remove();
+ }
+ return sum;
+ }
+
+ static double approx_erfc(double x) {
+ double sq = x*x;
+ double mult = Math.exp(-sq) / (x * Math.sqrt(Math.PI));
+ double term = 1.0;
+ long n = 1;
+ double sum = 0.0;
+ while ((sum + term) != sum) {
+ double pterm = term;
+ sum += term;
+ term = 0.5 * pterm * n / sq;
+ if (term > pterm) {
+ sum -= 0.5 * pterm;
+ return sum*mult;
+ }
+ n += 2;
+ pterm = term;
+ sum -= term;
+ term = 0.5 * pterm * n / sq;
+ if (term > pterm) {
+ sum += 0.5 * pterm;
+ return sum*mult;
+ }
+ n += 2;
+ }
+ return sum*mult;
+ }
+
@Override
public double applyAsDouble(double operand) { return erf(operand); }
@Override
public String toString() { return "f(a)(erf(a))"; }
- // Use Horner's method
- // From https://introcs.cs.princeton.edu/java/21function/ErrorFunction.java.html
+ static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI);
+
public static double erf(double v) {
- double t = 1.0 / (1.0 + 0.5 * Math.abs(v));
- double ans = 1 - t * Math.exp(-v*v - 1.26551223 +
- t * ( 1.00002368 +
- t * ( 0.37409196 +
- t * ( 0.09678418 +
- t * (-0.18628806 +
- t * ( 0.27886807 +
- t * (-1.13520398 +
- t * ( 1.48851587 +
- t * (-0.82215223 +
- t * ( 0.17087277))))))))));
- if (v >= 0) return ans;
- else return -ans;
+ if (v < 0) {
+ return -erf(Math.abs(v));
+ }
+ if (v < 1.0e-10) {
+ // Just use the derivate when very near zero:
+ return v * nearZeroMultiplier;
+ }
+ if (v <= 1.0) {
+ // works best when v is small
+ return v * nearZeroMultiplier * kummer(0.5, 1.5, -v*v);
+ }
+ if (v < 4.3) {
+ // slower, but works with bigger v
+ return v * nearZeroMultiplier * Math.exp(-v*v) * kummer(1.0, 1.5, v*v);
+ }
+ // works only with "very big" v
+ return 1.0 - approx_erfc(v);
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java
new file mode 100644
index 00000000000..5890bac6c96
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java
@@ -0,0 +1,66 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.functions;
+
+import java.util.function.DoubleUnaryOperator;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+
+public class ScalarFunctionsTestCase {
+
+ void expect_oddf(DoubleUnaryOperator foo, double input, double output) {
+ double res = foo.applyAsDouble(input);
+ assertEquals("apply("+foo+","+input+") -> ", output, res, 0.000000001);
+ input *= -1;
+ output *= -1;
+ res = foo.applyAsDouble(input);
+ assertEquals("apply("+foo+","+input+") -> "+res, output, res, 0.000000001);
+ }
+
+ @Test
+ public void testErrorFunction() {
+ var func = ScalarFunctions.erf();
+ // from wikipedia:
+ expect_oddf(func, 0.0, 0.0);
+ expect_oddf(func, 0.02, 0.022564575);
+ expect_oddf(func, 0.04, 0.045111106);
+ expect_oddf(func, 0.06, 0.067621594);
+ expect_oddf(func, 0.08, 0.090078126);
+ expect_oddf(func, 0.1, 0.112462916);
+ expect_oddf(func, 0.2, 0.222702589);
+ expect_oddf(func, 0.3, 0.328626759);
+ expect_oddf(func, 0.4, 0.428392355);
+ expect_oddf(func, 0.5, 0.520499878);
+ expect_oddf(func, 0.6, 0.603856091);
+ expect_oddf(func, 0.7, 0.677801194);
+ expect_oddf(func, 0.8, 0.742100965);
+ expect_oddf(func, 0.9, 0.796908212);
+ expect_oddf(func, 1.0, 0.842700793);
+ expect_oddf(func, 1.1, 0.88020507);
+ expect_oddf(func, 1.2, 0.910313978);
+ expect_oddf(func, 1.3, 0.934007945);
+ expect_oddf(func, 1.4, 0.95228512);
+ expect_oddf(func, 1.5, 0.966105146);
+ expect_oddf(func, 1.6, 0.976348383);
+ expect_oddf(func, 1.7, 0.983790459);
+ expect_oddf(func, 1.8, 0.989090502);
+ expect_oddf(func, 1.9, 0.992790429);
+ expect_oddf(func, 2.0, 0.995322265);
+ expect_oddf(func, 2.1, 0.997020533);
+ expect_oddf(func, 2.2, 0.998137154);
+ expect_oddf(func, 2.3, 0.998856823);
+ expect_oddf(func, 2.4, 0.999311486);
+ expect_oddf(func, 2.5, 0.999593048);
+ expect_oddf(func, 3.0, 0.99997791);
+ expect_oddf(func, 3.5, 0.999999257);
+ // from MPFR:
+ expect_oddf(func, 4.0, 0.99999998458);
+ expect_oddf(func, 4.2412109375, 0.9999999980);
+ expect_oddf(func, 4.2734375, 0.99999999849);
+ expect_oddf(func, 4.3203125, 0.9999999990);
+ expect_oddf(func, 5.0, 0.999999999998);
+ expect_oddf(func, 5.921875, 1.0);
+ }
+
+}