summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
author: Henning Baldersheim <balder@yahoo-inc.com> 2024-01-18 23:39:52 +0100
committer: Henning Baldersheim <balder@yahoo-inc.com> 2024-01-18 23:39:52 +0100
commit: ac42a2d1ab601e0b494fdbee46e96fc92c00fe13 (patch)
tree: 7ecdb07411934e3cf97b8dd8e1c194c956e7343b /model-integration
parent: 2a66025fe93b39f8d87201ceafe48345f7a4dc3f (diff)
Since both value and log(value) are monotonically increasing for value >= 1,
we can just gather max(value) and do the log at the end. Avoiding the general Math.max, which seems to have very costly NaN handling, was quite beneficial.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java16
1 files changed, 8 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 7a6d8a49a87..853009873a1 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -152,21 +152,21 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
String dimension = tensorType.dimensions().get(0).name();
//Iterate over the vocab dimension and find the max value for each sequence token
long [] tokens = new long[1];
- for(int v = 0; v < vocabSize; v++) {
- double maxLogOfRelu = Double.MIN_VALUE;
- for(int s = 0; s < sequenceLength; s++) {
+ for (int v = 0; v < vocabSize; v++) {
+ double maxValue = 0.0d;
+ for (int s = 0; s < sequenceLength; s++) {
double value = modelOutput.get(0, s, v); // batch, sequence, vocab
- double logOfRelu = Math.log(1 + Math.max(0, value));
- if(logOfRelu > maxLogOfRelu) {
- maxLogOfRelu = logOfRelu;
+ if (value > maxValue) {
+ maxValue = value;
}
}
- if (maxLogOfRelu > termScoreThreshold) {
+ double logOfRelu = Math.log(1 + maxValue);
+ if (logOfRelu > termScoreThreshold) {
tokens[0] = v;
String term = tokenizer.decode(tokens);
builder.cell()
.label(dimension, term)
- .value(maxLogOfRelu);
+ .value(logOfRelu);
}
}
return builder.build();