From ac42a2d1ab601e0b494fdbee46e96fc92c00fe13 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Thu, 18 Jan 2024 23:39:52 +0100 Subject: Since both value and log(value) are monotonically increasing for value >= 1, we can just gather max(value) and do log at the end. Avoiding general Math.max which seems to have very costly NaN handling was quite beneficial. --- .../src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'model-integration/src') diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 7a6d8a49a87..853009873a1 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -152,21 +152,21 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token long [] tokens = new long[1]; - for(int v = 0; v < vocabSize; v++) { - double maxLogOfRelu = Double.MIN_VALUE; - for(int s = 0; s < sequenceLength; s++) { + for (int v = 0; v < vocabSize; v++) { + double maxValue = 0.0d; + for (int s = 0; s < sequenceLength; s++) { double value = modelOutput.get(0, s, v); // batch, sequence, vocab - double logOfRelu = Math.log(1 + Math.max(0, value)); - if(logOfRelu > maxLogOfRelu) { - maxLogOfRelu = logOfRelu; + if (value > maxValue) { + maxValue = value; } } - if (maxLogOfRelu > termScoreThreshold) { + double logOfRelu = Math.log(1 + maxValue); + if (logOfRelu > termScoreThreshold) { tokens[0] = v; String term = tokenizer.decode(tokens); builder.cell() .label(dimension, term) - .value(maxLogOfRelu); + .value(logOfRelu); } } return builder.build(); -- cgit v1.2.3