diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-04-20 11:31:34 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-04-20 11:42:22 +0000 |
commit | 328fc7f2156d65f193499ff27a3ed85dce04b1f2 (patch) | |
tree | cc0868535f6eb5f1682942665310b6c6707e1962 | |
parent | 901237f0a48223b8971c56c95e0d7b41e3974d33 (diff) |
new implementation of erf()
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java | 86 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java | 66 |
2 files changed, 137 insertions, 15 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index c19b07cf96f..3ee9e67cdd6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -3,7 +3,9 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; +import java.util.Comparator; import java.util.List; +import java.util.PriorityQueue; import java.util.concurrent.ThreadLocalRandom; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; @@ -332,27 +334,81 @@ public class ScalarFunctions { } public static class Erf implements DoubleUnaryOperator { + static final Comparator<Double> byAbs = (x,y) -> Double.compare(Math.abs(x), Math.abs(y)); + + static double kummer(double a, double b, double z) { + PriorityQueue<Double> terms = new PriorityQueue<>(byAbs); + double term = 1.0; + long n = 0; + while (Math.abs(term) > Double.MIN_NORMAL) { + terms.add(term); + term *= (a+n); + term /= (b+n); + ++n; + term *= z; + term /= n; + } + double sum = terms.remove(); + while (! terms.isEmpty()) { + sum += terms.remove(); + terms.add(sum); + sum = terms.remove(); + } + return sum; + } + + static double approx_erfc(double x) { + double sq = x*x; + double mult = Math.exp(-sq) / (x * Math.sqrt(Math.PI)); + double term = 1.0; + long n = 1; + double sum = 0.0; + while ((sum + term) != sum) { + double pterm = term; + sum += term; + term = 0.5 * pterm * n / sq; + if (term > pterm) { + sum -= 0.5 * pterm; + return sum*mult; + } + n += 2; + pterm = term; + sum -= term; + term = 0.5 * pterm * n / sq; + if (term > pterm) { + sum += 0.5 * pterm; + return sum*mult; + } + n += 2; + } + return sum*mult; + } + @Override public double applyAsDouble(double operand) { return erf(operand); } @Override public String toString() { return "f(a)(erf(a))"; } - // Use Horner's method - // From https://introcs.cs.princeton.edu/java/21function/ErrorFunction.java.html + static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI); + public static double erf(double v) { - double t = 1.0 / (1.0 + 0.5 * Math.abs(v)); - double ans = 1 - t * Math.exp(-v*v - 1.26551223 + - t * ( 1.00002368 + - t * ( 0.37409196 + - t * ( 0.09678418 + - t * (-0.18628806 + - t * ( 0.27886807 + - t * (-1.13520398 + - t * ( 1.48851587 + - t * (-0.82215223 + - t * ( 0.17087277)))))))))); - if (v >= 0) return ans; - else return -ans; + if (v < 0) { + return -erf(Math.abs(v)); + } + if (v < 1.0e-10) { + // Just use the derivate when very near zero: + return v * nearZeroMultiplier; + } + if (v <= 1.0) { + // works best when v is small + return v * nearZeroMultiplier * kummer(0.5, 1.5, -v*v); + } + if (v < 4.3) { + // slower, but works with bigger v + return v * nearZeroMultiplier * Math.exp(-v*v) * kummer(1.0, 1.5, v*v); + } + // works only with "very big" v + return 1.0 - approx_erfc(v); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java new file mode 100644 index 00000000000..5890bac6c96 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ScalarFunctionsTestCase.java @@ -0,0 +1,66 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.functions; + +import java.util.function.DoubleUnaryOperator; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; + +public class ScalarFunctionsTestCase { + + void expect_oddf(DoubleUnaryOperator foo, double input, double output) { + double res = foo.applyAsDouble(input); + assertEquals("apply("+foo+","+input+") -> ", output, res, 0.000000001); + input *= -1; + output *= -1; + res = foo.applyAsDouble(input); + assertEquals("apply("+foo+","+input+") -> "+res, output, res, 0.000000001); + } + + @Test + public void testErrorFunction() { + var func = ScalarFunctions.erf(); + // from wikipedia: + expect_oddf(func, 0.0, 0.0); + expect_oddf(func, 0.02, 0.022564575); + expect_oddf(func, 0.04, 0.045111106); + expect_oddf(func, 0.06, 0.067621594); + expect_oddf(func, 0.08, 0.090078126); + expect_oddf(func, 0.1, 0.112462916); + expect_oddf(func, 0.2, 0.222702589); + expect_oddf(func, 0.3, 0.328626759); + expect_oddf(func, 0.4, 0.428392355); + expect_oddf(func, 0.5, 0.520499878); + expect_oddf(func, 0.6, 0.603856091); + expect_oddf(func, 0.7, 0.677801194); + expect_oddf(func, 0.8, 0.742100965); + expect_oddf(func, 0.9, 0.796908212); + expect_oddf(func, 1.0, 0.842700793); + expect_oddf(func, 1.1, 0.88020507); + expect_oddf(func, 1.2, 0.910313978); + expect_oddf(func, 1.3, 0.934007945); + expect_oddf(func, 1.4, 0.95228512); + expect_oddf(func, 1.5, 0.966105146); + expect_oddf(func, 1.6, 0.976348383); + expect_oddf(func, 1.7, 0.983790459); + expect_oddf(func, 1.8, 0.989090502); + expect_oddf(func, 1.9, 0.992790429); + expect_oddf(func, 2.0, 0.995322265); + expect_oddf(func, 2.1, 0.997020533); + expect_oddf(func, 2.2, 0.998137154); + expect_oddf(func, 2.3, 0.998856823); + expect_oddf(func, 2.4, 0.999311486); + expect_oddf(func, 2.5, 0.999593048); + expect_oddf(func, 3.0, 0.99997791); + expect_oddf(func, 3.5, 0.999999257); + // from MPFR: + expect_oddf(func, 4.0, 0.99999998458); + expect_oddf(func, 4.2412109375, 0.9999999980); + expect_oddf(func, 4.2734375, 0.99999999849); + expect_oddf(func, 4.3203125, 0.9999999990); + expect_oddf(func, 5.0, 0.999999999998); + expect_oddf(func, 5.921875, 1.0); + } + +} |