diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-17 08:48:45 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-17 08:48:45 +0100 |
commit | 745a8db7a8eaea7aa53736a26d64e97543900343 (patch) | |
tree | 3dc56fe5f2b0d0a300cd24470912f98c8842985f /model-integration | |
parent | 56ff2f5e971a26d81cfe5cdbac65d856118820e4 (diff) |
Allow mapped 1d tensor for embed expressions
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 22 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java | 4 |
2 files changed, 13 insertions, 13 deletions
```diff
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 891be44a5d2..4af7820274f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -1,3 +1,4 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
 package ai.vespa.embedding;
 
 import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
@@ -9,13 +10,14 @@ import com.yahoo.component.annotation.Inject;
 import com.yahoo.embedding.SpladeEmbedderConfig;
 import com.yahoo.language.huggingface.HuggingFaceTokenizer;
 import com.yahoo.language.process.Embedder;
-import com.yahoo.tensor.*;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
 import com.yahoo.tensor.functions.Reduce;
 
 import java.nio.file.Paths;
 import java.util.List;
 import java.util.Map;
-
 import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
 
 /**
@@ -137,16 +139,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
         Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x)));
         Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1");
         IndexedTensor vocab = (IndexedTensor) maxReduced;
-        Tensor.Builder sparseTensor = MappedTensor.Builder.of(tensorType);
+        var builder = Tensor.Builder.of(tensorType);
         for(int i = 0; i < vocab.size(); i++) {
-            var value = vocab.get(i);
-            if (value > termScoreThreshold) {
-                String t = tokenizer.decode(List.of((long) i));
-                TensorAddress label = TensorAddress.of(List.of(t).toArray(new String[0]));
-                sparseTensor.cell(label, value);
+            var score = vocab.get(i);
+            if (score > termScoreThreshold) {
+                String term = tokenizer.decode(List.of((long) i));
+                builder.cell().
+                        label(tensorType.dimensions().get(0).name(), term)
+                        .value(score);
             }
         }
-        return sparseTensor.build();
+        return builder.build();
     }
 
 
@@ -159,7 +162,6 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
         }
         return builder.build();
     }
-
     @Override
     public void deconstruct() {
         evaluator.close();
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index 0c49d75cbe0..e0d940ca5fe 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -9,10 +9,9 @@ import com.yahoo.tensor.MappedTensor;
 import com.yahoo.tensor.Tensor;
 import com.yahoo.tensor.TensorAddress;
 import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
 import java.util.List;
 
+import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
@@ -68,7 +67,6 @@ public class SpladeEmbedderTest {
         SpladeEmbedderConfig.Builder builder = new SpladeEmbedderConfig.Builder();
         builder.tokenizerPath(ModelReference.valueOf(vocabPath));
         builder.transformerModel(ModelReference.valueOf(modelPath));
-        builder.transformerOutput("logits");
         builder.termScoreThreshold(scoreThreshold);
         builder.transformerGpuDevice(-1);
         return new SpladeEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
```