diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 21:36:56 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 21:36:56 +0100 |
commit | 2a66025fe93b39f8d87201ceafe48345f7a4dc3f (patch) | |
tree | 9db3169e523ed9812b8047e00f5ce0f8ce4a753c /model-integration | |
parent | 254862ddf5d55923232abef00e6c2fff32bf463b (diff) |
Construct array right away instead of going via a single element list and the java stream api.
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 644b1ec538f..7a6d8a49a87 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -12,6 +12,7 @@ import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; @@ -139,10 +140,18 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { if (batch != 1) { throw new IllegalArgumentException("Batch size must be 1"); } - long sequenceLength = shape[1]; - long vocabSize = shape[2]; + if (shape[1] > Integer.MAX_VALUE) { + throw new IllegalArgumentException("sequenceLength=" + shape[1] + " larger than an int"); + } + if (shape[2] > Integer.MAX_VALUE) { + throw new IllegalArgumentException("vocabSize=" + shape[2] + " larger than an int"); + } + int sequenceLength = (int) shape[1]; + int vocabSize = (int) shape[2]; + String dimension = tensorType.dimensions().get(0).name(); //Iterate over the vocab dimension and find the max value for each sequence token + long [] tokens = new long[1]; for(int v = 0; v < vocabSize; v++) { double maxLogOfRelu = Double.MIN_VALUE; for(int s = 0; s < sequenceLength; s++) { @@ -153,9 +162,10 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { } } if (maxLogOfRelu > termScoreThreshold) { - String term = tokenizer.decode(List.of((long) v)); - builder.cell(). - label(tensorType.dimensions().get(0).name(), term) + tokens[0] = v; + String term = tokenizer.decode(tokens); + builder.cell() + .label(dimension, term) .value(maxLogOfRelu); } } |