diff options
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 7a6d8a49a87..853009873a1 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -152,21 +152,21 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; - for(int v = 0; v < vocabSize; v++) { - double maxLogOfRelu = Double.MIN_VALUE; - for(int s = 0; s < sequenceLength; s++) { + for (int v = 0; v < vocabSize; v++) { + double maxValue = 0.0d; + for (int s = 0; s < sequenceLength; s++) { double value = modelOutput.get(0, s, v); // batch, sequence, vocab - double logOfRelu = Math.log(1 + Math.max(0, value)); - if(logOfRelu > maxLogOfRelu) { - maxLogOfRelu = logOfRelu; + if (value > maxValue) { + maxValue = value; } } - if (maxLogOfRelu > termScoreThreshold) { + double logOfRelu = Math.log(1 + maxValue); + if (logOfRelu > termScoreThreshold) { tokens[0] = v; String term = tokenizer.decode(tokens); builder.cell() .label(dimension, term) - .value(maxLogOfRelu); + .value(logOfRelu); } } return builder.build(); |