summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2023-12-17 08:48:45 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2023-12-17 08:48:45 +0100
commit745a8db7a8eaea7aa53736a26d64e97543900343 (patch)
tree3dc56fe5f2b0d0a300cd24470912f98c8842985f /model-integration
parent56ff2f5e971a26d81cfe5cdbac65d856118820e4 (diff)
Allow mapped 1d tensor for embed expressions
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java22
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java4
2 files changed, 13 insertions, 13 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 891be44a5d2..4af7820274f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -1,3 +1,4 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
@@ -9,13 +10,14 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
-import com.yahoo.tensor.*;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
-
import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
/**
@@ -137,16 +139,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x)));
Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1");
IndexedTensor vocab = (IndexedTensor) maxReduced;
- Tensor.Builder sparseTensor = MappedTensor.Builder.of(tensorType);
+ var builder = Tensor.Builder.of(tensorType);
for(int i = 0; i < vocab.size(); i++) {
- var value = vocab.get(i);
- if (value > termScoreThreshold) {
- String t = tokenizer.decode(List.of((long) i));
- TensorAddress label = TensorAddress.of(List.of(t).toArray(new String[0]));
- sparseTensor.cell(label, value);
+ var score = vocab.get(i);
+ if (score > termScoreThreshold) {
+ String term = tokenizer.decode(List.of((long) i));
+ builder.cell().
+ label(tensorType.dimensions().get(0).name(), term)
+ .value(score);
}
}
- return sparseTensor.build();
+ return builder.build();
}
@@ -159,7 +162,6 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
}
return builder.build();
}
-
@Override
public void deconstruct() {
evaluator.close();
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index 0c49d75cbe0..e0d940ca5fe 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -9,10 +9,9 @@ import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
import java.util.List;
+import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
@@ -68,7 +67,6 @@ public class SpladeEmbedderTest {
SpladeEmbedderConfig.Builder builder = new SpladeEmbedderConfig.Builder();
builder.tokenizerPath(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
- builder.transformerOutput("logits");
builder.termScoreThreshold(scoreThreshold);
builder.transformerGpuDevice(-1);
return new SpladeEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());