aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java | 4
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java | 10
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 22
-rw-r--r-- model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java | 4
4 files changed, 18 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
index 96554e91d38..038a6cb78c8 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
@@ -13,7 +13,6 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI
public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConfig.Producer {
private final OnnxModelOptions onnxModelOptions;
-
private final ModelReference modelRef;
private final ModelReference vocabRef;
private final Integer maxTokens;
@@ -21,7 +20,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf
private final String transformerAttentionMask;
private final String transformerTokenTypeIds;
private final String transformerOutput;
-
private final Double termScoreThreshold;
public SpladeEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
@@ -52,8 +50,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
-
-
@Override
public void getConfig(SpladeEmbedderConfig.Builder b) {
b.transformerModel(modelRef).tokenizerPath(vocabRef);
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 29399a38fa9..1a9caaa5ca1 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -119,7 +119,7 @@ public class EmbedExpression extends Expression {
"Don't know what tensor type to embed into");
targetType = toTargetTensor(context.getInputType(this, outputField));
if ( ! validTarget(targetType))
- throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " +
+ throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, " +
"an array of dense 1d tensors, or a mixed 2d tensor");
context.setValueType(createdOutputType());
}
@@ -134,14 +134,14 @@ public class EmbedExpression extends Expression {
if ( ! ( dataType instanceof TensorDataType))
throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
return ((TensorDataType)dataType).getTensorType();
-
}
private boolean validTarget(TensorType target) {
- if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1)
- return true;
- if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1)
+ if (target.dimensions().size() == 1) //indexed or mapped 1d tensor
return true;
+ if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1
+ && target.mappedSubtype().rank() == 1)
+ return true; //mixed mapped-indexed 2d tensor
return false;
}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 891be44a5d2..4af7820274f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -1,3 +1,4 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
@@ -9,13 +10,14 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
-import com.yahoo.tensor.*;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
-
import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
/**
@@ -137,16 +139,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x)));
Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1");
IndexedTensor vocab = (IndexedTensor) maxReduced;
- Tensor.Builder sparseTensor = MappedTensor.Builder.of(tensorType);
+ var builder = Tensor.Builder.of(tensorType);
for(int i = 0; i < vocab.size(); i++) {
- var value = vocab.get(i);
- if (value > termScoreThreshold) {
- String t = tokenizer.decode(List.of((long) i));
- TensorAddress label = TensorAddress.of(List.of(t).toArray(new String[0]));
- sparseTensor.cell(label, value);
+ var score = vocab.get(i);
+ if (score > termScoreThreshold) {
+ String term = tokenizer.decode(List.of((long) i));
+ builder.cell().
+ label(tensorType.dimensions().get(0).name(), term)
+ .value(score);
}
}
- return sparseTensor.build();
+ return builder.build();
}
@@ -159,7 +162,6 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
}
return builder.build();
}
-
@Override
public void deconstruct() {
evaluator.close();
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index 0c49d75cbe0..e0d940ca5fe 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -9,10 +9,9 @@ import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
import java.util.List;
+import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
@@ -68,7 +67,6 @@ public class SpladeEmbedderTest {
SpladeEmbedderConfig.Builder builder = new SpladeEmbedderConfig.Builder();
builder.tokenizerPath(ModelReference.valueOf(vocabPath));
builder.transformerModel(ModelReference.valueOf(modelPath));
- builder.transformerOutput("logits");
builder.termScoreThreshold(scoreThreshold);
builder.transformerGpuDevice(-1);
return new SpladeEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());