diff options
| | | |
|---|---|---|
| author | Arne Juul <arnej@yahooinc.com> | 2023-10-05 12:19:47 +0000 |
| committer | Arne Juul <arnej@yahooinc.com> | 2023-10-05 12:19:47 +0000 |
| commit | 54c1885ae2c0a672b67618e3871482c374d753ae (patch) | |
| tree | d9c4ecb40ef0a6adfbdd7e7cc0153d58f30489eb /container-search | |
| parent | 437bc6d89b13e1ac745d5e0ccfb47415d8b8bd2a (diff) | |
unit test and handle infinity and NaN
Diffstat (limited to 'container-search')
3 files changed, 65 insertions, 10 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java index 0f87d0f0b52..dfba337c8e0 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java @@ -8,18 +8,20 @@ class LinearNormalizer extends Normalizer { } void normalize() { - double min = data[0]; - double max = data[0]; - for (int i = 1; i < size; i++) { - min = Math.min(min, data[i]); - max = Math.max(max, data[i]); + double min = Float.MAX_VALUE; + double max = -Float.MAX_VALUE; + for (int i = 0; i < size; i++) { + double val = data[i]; + if (val < Float.MAX_VALUE && val > -Float.MAX_VALUE) { + min = Math.min(min, data[i]); + max = Math.max(max, data[i]); + } } - min = Math.max(min, -Float.MAX_VALUE); - max = Math.min(max, Float.MAX_VALUE); double scale = 0.0; - double midpoint = (min + max) * 0.5; + double midpoint = 0.0; if (max > min) { scale = 1.0 / (max - min); + midpoint = (min + max) * 0.5; } for (int i = 0; i < size; i++) { double old = data[i]; diff --git a/container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java index 7862e54c32e..d3cdfc4bc78 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java @@ -9,7 +9,7 @@ class ReciprocalRankNormalizer extends Normalizer { ReciprocalRankNormalizer(String name, String input, int maxSize, double k) { super(name, input, maxSize); - this.k = k; + this.k = k; } static record IdxScore(int index, double score) {} @@ -18,7 +18,9 @@ class ReciprocalRankNormalizer extends Normalizer { if (size < 1) return; IdxScore[] temp = new IdxScore[size]; for (int i = 0; i < size; i++) { - temp[i] = new 
IdxScore(i, data[i]); + double val = data[i]; + if (Double.isNaN(val)) val = Double.NEGATIVE_INFINITY; + temp[i] = new IdxScore(i, val); } Arrays.sort(temp, (a, b) -> Double.compare(b.score, a.score)); for (int i = 0; i < size; i++) { diff --git a/container-search/src/test/java/com/yahoo/search/ranking/NormalizerTestCase.java b/container-search/src/test/java/com/yahoo/search/ranking/NormalizerTestCase.java index 1c10ad083a3..144ca0e8bae 100644 --- a/container-search/src/test/java/com/yahoo/search/ranking/NormalizerTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/ranking/NormalizerTestCase.java @@ -28,6 +28,31 @@ public class NormalizerTestCase { } @Test + void requireLinearHandlesInfinity() { + var n = new LinearNormalizer("foo", "bar", 10); + assertEquals(0, n.addInput(Double.NEGATIVE_INFINITY)); + assertEquals(1, n.addInput(1.0)); + assertEquals(2, n.addInput(9.0)); + assertEquals(3, n.addInput(5.0)); + assertEquals(4, n.addInput(Double.NaN)); + assertEquals(5, n.addInput(3.0)); + assertEquals(6, n.addInput(Double.POSITIVE_INFINITY)); + assertEquals(7, n.addInput(8.0)); + n.normalize(); + assertEquals(Double.NEGATIVE_INFINITY, n.getOutput(0)); + assertEquals(0.0, n.getOutput(1)); + assertEquals(1.0, n.getOutput(2)); + assertEquals(0.5, n.getOutput(3)); + assertEquals(Double.NaN, n.getOutput(4)); + assertEquals(0.25, n.getOutput(5)); + assertEquals(Double.POSITIVE_INFINITY, n.getOutput(6)); + assertEquals(0.875, n.getOutput(7)); + assertEquals("foo", n.name()); + assertEquals("bar", n.input()); + assertEquals("linear", n.normalizing()); + } + + @Test void requireReciprocalNormalizing() { var n = new ReciprocalRankNormalizer("foo", "bar", 10, 0.0); assertEquals(0, n.addInput(-4.1)); @@ -61,4 +86,30 @@ public class NormalizerTestCase { assertEquals("reciprocal-rank{k:4.2}", n.normalizing()); } + @Test + void requireReciprocalInfinities() { + var n = new ReciprocalRankNormalizer("foo", "bar", 10, 0.0); + assertEquals(0, 
n.addInput(Double.NEGATIVE_INFINITY)); + assertEquals(1, n.addInput(1.0)); + assertEquals(2, n.addInput(9.0)); + assertEquals(3, n.addInput(5.0)); + assertEquals(4, n.addInput(Double.NaN)); + assertEquals(5, n.addInput(3.0)); + assertEquals(6, n.addInput(Double.POSITIVE_INFINITY)); + assertEquals(7, n.addInput(8.0)); + n.normalize(); + assertEquals(1.0/7.0, n.getOutput(0)); + assertEquals(1.0/6.0, n.getOutput(1)); + assertEquals(1.0/2.0, n.getOutput(2)); + assertEquals(1.0/4.0, n.getOutput(3)); + assertEquals(1.0/8.0, n.getOutput(4)); + assertEquals(1.0/5.0, n.getOutput(5)); + assertEquals(1.0/1.0, n.getOutput(6)); + assertEquals(1.0/3.0, n.getOutput(7)); + n.normalize(); + assertEquals("foo", n.name()); + assertEquals("bar", n.input()); + assertEquals("reciprocal-rank{k:0.0}", n.normalizing()); + } + } |