summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 21:36:56 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 21:36:56 +0100
commit2a66025fe93b39f8d87201ceafe48345f7a4dc3f (patch)
tree9db3169e523ed9812b8047e00f5ce0f8ce4a753c /model-integration
parent254862ddf5d55923232abef00e6c2fff32bf463b (diff)
Construct array right away instead of going via a single element list and the java stream api.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java20
1 files changed, 15 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 644b1ec538f..7a6d8a49a87 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -12,6 +12,7 @@ import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
@@ -139,10 +140,18 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
if (batch != 1) {
throw new IllegalArgumentException("Batch size must be 1");
}
- long sequenceLength = shape[1];
- long vocabSize = shape[2];
+ if (shape[1] > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int");
+ }
+ if (shape[2] > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int");
+ }
+ int sequenceLength = (int) shape[1];
+ int vocabSize = (int) shape[2];
+ String dimension = tensorType.dimensions().get(0).name();
//Iterate over the vocab dimension and find the max value for each sequence token
+ long [] tokens = new long[1];
for(int v = 0; v < vocabSize; v++) {
double maxLogOfRelu = Double.MIN_VALUE;
for(int s = 0; s < sequenceLength; s++) {
@@ -153,9 +162,10 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
}
}
if (maxLogOfRelu > termScoreThreshold) {
- String term = tokenizer.decode(List.of((long) v));
- builder.cell().
- label(tensorType.dimensions().get(0).name(), term)
+ tokens[0] = v;
+ String term = tokenizer.decode(tokens);
+ builder.cell()
+ .label(dimension, term)
.value(maxLogOfRelu);
}
}