diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-10-10 15:46:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-10 15:46:25 +0200 |
commit | b6f022043c0e33c7462b5920dd5cd990a89f6727 (patch) | |
tree | bbd0a5c7be384ff40babddbcbd59a01611a487c6 /container-search | |
parent | 7564d59acf10f940ea244f1ec4163eb7d9ba893a (diff) | |
parent | e3401296f36818430e580d9522772f6d5ab2e43f (diff) |
Merge pull request #28757 from vespa-engine/arnej/add-normalizers
add Normalizer classes
Diffstat (limited to 'container-search')
4 files changed, 195 insertions, 0 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java new file mode 100644 index 00000000000..a3fb86bb9b5 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java @@ -0,0 +1,33 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +class LinearNormalizer extends Normalizer { + + LinearNormalizer(int maxSize) { + super(maxSize); + } + + void normalize() { + double min = Float.MAX_VALUE; + double max = -Float.MAX_VALUE; + for (int i = 0; i < size; i++) { + double val = data[i]; + if (val < Float.MAX_VALUE && val > -Float.MAX_VALUE) { + min = Math.min(min, data[i]); + max = Math.max(max, data[i]); + } + } + double scale = 0.0; + double midpoint = 0.0; + if (max > min) { + scale = 1.0 / (max - min); + midpoint = (min + max) * 0.5; + } + for (int i = 0; i < size; i++) { + double old = data[i]; + data[i] = 0.5 + scale * (old - midpoint); + } + } + + String normalizing() { return "linear"; } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java new file mode 100644 index 00000000000..269d4e6ed11 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java @@ -0,0 +1,23 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
// Normalizer.java -- package com.yahoo.search.ranking
// (The package declaration is kept as a comment here because this span holds two
// source files of the patch; both declare package com.yahoo.search.ranking.)
import java.util.Arrays;

/**
 * Base class for score normalizers: collects up to a fixed number of input values,
 * then rewrites them in place when {@link #normalize()} is called.
 *
 * <p>Usage is add-all, normalize once, then read outputs by the index that
 * {@link #addInput(double)} returned.
 */
abstract class Normalizer {

    /** Backing store; holds inputs before {@link #normalize()} and outputs after it. */
    protected final double[] data;

    /** Number of values added so far; only slots {@code [0, size)} are meaningful. */
    protected int size = 0;

    /**
     * @param maxSize number of values that can be added; used as the backing array length
     */
    Normalizer(int maxSize) {
        this.data = new double[maxSize];
    }

    /**
     * Stores one more input value.
     *
     * @param value the raw value to normalize later
     * @return the slot index assigned to this value (0 for the first call, then 1, 2, ...)
     * @throws ArrayIndexOutOfBoundsException if more than {@code maxSize} values are added
     */
    int addInput(double value) {
        data[size] = value;
        return size++;
    }

    /**
     * Reads the content of one slot; no check is made against the number of values added.
     *
     * @param index a slot index as returned by {@link #addInput(double)}
     * @return the current content of that slot
     */
    double getOutput(int index) {
        return data[index];
    }

    /** Replaces the collected inputs with their normalized values, in place. */
    abstract void normalize();

    /** @return a short description of the normalization being applied */
    abstract String normalizing();
}

// Patch entry (new file, mode 100644, index 00000000000..6716485e343):
//   container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
// Package: com.yahoo.search.ranking

/**
 * Normalizer that replaces every input by {@code 1 / (k + rank)}, where rank is the
 * 1-based position of the input when all inputs are ordered from highest to lowest.
 */
class ReciprocalRankNormalizer extends Normalizer {

    private final double k;

    /**
     * @param maxSize capacity handed on to the {@link Normalizer} base class
     * @param k       constant added to the rank in the denominator
     */
    ReciprocalRankNormalizer(int maxSize, double k) {
        super(maxSize);
        this.k = k;
    }

    /** Pairs an input's slot index with the score used for ordering. */
    record IdxScore(int index, double score) {}

    /**
     * Ranks the first {@code size} inputs by descending score and overwrites each slot
     * with {@code 1.0 / (k + 1.0 + position)}, position being 0 for the highest score.
     *
     * <p>NaN inputs are ordered as if they were negative infinity. Equal scores keep
     * their insertion order, since {@link Arrays#sort(Object[], java.util.Comparator)}
     * is a stable sort.
     */
    @Override
    void normalize() {
        if (size < 1) return;
        IdxScore[] ranked = new IdxScore[size];
        for (int i = 0; i < size; i++) {
            double val = data[i];
            // Double.compare would order NaN above everything; push it to the bottom instead.
            if (Double.isNaN(val)) val = Double.NEGATIVE_INFINITY;
            ranked[i] = new IdxScore(i, val);
        }
        Arrays.sort(ranked, (a, b) -> Double.compare(b.score(), a.score()));
        for (int position = 0; position < size; position++) {
            data[ranked[position].index()] = 1.0 / (k + 1.0 + position);
        }
    }

    /** @return {@code "reciprocal-rank{k:<k>}"} with the configured k */
    @Override
    String normalizing() {
        return "reciprocal-rank{k:" + k + "}";
    }
}

// Patch entry (new file, mode 100644, index 00000000000..7373fb489f4):
//   container-search/src/test/java/com/yahoo/search/ranking/NormalizerTestCase.java
+1,105 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author arnej + */ +public class NormalizerTestCase { + + @Test + void requireLinearNormalizing() { + var n = new LinearNormalizer(10); + assertEquals(0, n.addInput(-4.0)); + assertEquals(1, n.addInput(-1.0)); + assertEquals(2, n.addInput(-5.0)); + assertEquals(3, n.addInput(-3.0)); + n.normalize(); + assertEquals(0.0, n.getOutput(2)); + assertEquals(0.25, n.getOutput(0)); + assertEquals(0.5, n.getOutput(3)); + assertEquals(1.0, n.getOutput(1)); + assertEquals("linear", n.normalizing()); + } + + @Test + void requireLinearHandlesInfinity() { + var n = new LinearNormalizer(10); + assertEquals(0, n.addInput(Double.NEGATIVE_INFINITY)); + assertEquals(1, n.addInput(1.0)); + assertEquals(2, n.addInput(9.0)); + assertEquals(3, n.addInput(5.0)); + assertEquals(4, n.addInput(Double.NaN)); + assertEquals(5, n.addInput(3.0)); + assertEquals(6, n.addInput(Double.POSITIVE_INFINITY)); + assertEquals(7, n.addInput(8.0)); + n.normalize(); + assertEquals(Double.NEGATIVE_INFINITY, n.getOutput(0)); + assertEquals(0.0, n.getOutput(1)); + assertEquals(1.0, n.getOutput(2)); + assertEquals(0.5, n.getOutput(3)); + assertEquals(Double.NaN, n.getOutput(4)); + assertEquals(0.25, n.getOutput(5)); + assertEquals(Double.POSITIVE_INFINITY, n.getOutput(6)); + assertEquals(0.875, n.getOutput(7)); + assertEquals("linear", n.normalizing()); + } + + @Test + void requireReciprocalNormalizing() { + var n = new ReciprocalRankNormalizer(10, 0.0); + assertEquals(0, n.addInput(-4.1)); + assertEquals(1, n.addInput(11.0)); + assertEquals(2, n.addInput(-50.0)); + assertEquals(3, n.addInput(-3.0)); + n.normalize(); + assertEquals(0.25, n.getOutput(2)); + assertEquals(0.3333333, n.getOutput(0), 0.00001); + assertEquals(0.5, n.getOutput(3)); + 
assertEquals(1.0, n.getOutput(1)); + assertEquals("reciprocal-rank{k:0.0}", n.normalizing()); + } + + @Test + void requireReciprocalNormalizingWithK() { + var n = new ReciprocalRankNormalizer(10, 4.2); + assertEquals(0, n.addInput(-4.1)); + assertEquals(1, n.addInput(11.0)); + assertEquals(2, n.addInput(-50.0)); + assertEquals(3, n.addInput(-3.0)); + n.normalize(); + assertEquals(1.0/8.2, n.getOutput(2)); + assertEquals(1.0/7.2, n.getOutput(0), 0.00001); + assertEquals(1.0/6.2, n.getOutput(3)); + assertEquals(1.0/5.2, n.getOutput(1)); + assertEquals("reciprocal-rank{k:4.2}", n.normalizing()); + } + + @Test + void requireReciprocalInfinities() { + var n = new ReciprocalRankNormalizer(10, 0.0); + assertEquals(0, n.addInput(Double.NEGATIVE_INFINITY)); + assertEquals(1, n.addInput(1.0)); + assertEquals(2, n.addInput(9.0)); + assertEquals(3, n.addInput(5.0)); + assertEquals(4, n.addInput(Double.NaN)); + assertEquals(5, n.addInput(3.0)); + assertEquals(6, n.addInput(Double.POSITIVE_INFINITY)); + assertEquals(7, n.addInput(8.0)); + n.normalize(); + assertEquals(1.0/7.0, n.getOutput(0)); + assertEquals(1.0/6.0, n.getOutput(1)); + assertEquals(1.0/2.0, n.getOutput(2)); + assertEquals(1.0/4.0, n.getOutput(3)); + assertEquals(1.0/8.0, n.getOutput(4)); + assertEquals(1.0/5.0, n.getOutput(5)); + assertEquals(1.0/1.0, n.getOutput(6)); + assertEquals(1.0/3.0, n.getOutput(7)); + n.normalize(); + assertEquals("reciprocal-rank{k:0.0}", n.normalizing()); + } + +} |