diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 23:39:52 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 23:39:52 +0100 |
commit | ac42a2d1ab601e0b494fdbee46e96fc92c00fe13 (patch) | |
tree | 7ecdb07411934e3cf97b8dd8e1c194c956e7343b /model-integration | |
parent | 2a66025fe93b39f8d87201ceafe48345f7a4dc3f (diff) |
Since log(1 + x) is monotonically increasing for x >= 0 (i.e. for an argument >= 1),
we can just gather max(value), clamped at 0, and take the log once at the end.
Avoiding the general Math.max, which seems to have very costly NaN handling, was quite beneficial.
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 7a6d8a49a87..853009873a1 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -152,21 +152,21 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; - for(int v = 0; v < vocabSize; v++) { - double maxLogOfRelu = Double.MIN_VALUE; - for(int s = 0; s < sequenceLength; s++) { + for (int v = 0; v < vocabSize; v++) { + double maxValue = 0.0d; + for (int s = 0; s < sequenceLength; s++) { double value = modelOutput.get(0, s, v); // batch, sequence, vocab - double logOfRelu = Math.log(1 + Math.max(0, value)); - if(logOfRelu > maxLogOfRelu) { - maxLogOfRelu = logOfRelu; + if (value > maxValue) { + maxValue = value; } } - if (maxLogOfRelu > termScoreThreshold) { + double logOfRelu = Math.log(1 + maxValue); + if (logOfRelu > termScoreThreshold) { tokens[0] = v; String term = tokenizer.decode(tokens); builder.cell() .label(dimension, term) - .value(maxLogOfRelu); + .value(logOfRelu); } } return builder.build(); |