diff options
4 files changed, 18 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java index 96554e91d38..038a6cb78c8 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java @@ -13,7 +13,6 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATI public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConfig.Producer { private final OnnxModelOptions onnxModelOptions; - private final ModelReference modelRef; private final ModelReference vocabRef; private final Integer maxTokens; @@ -21,7 +20,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf private final String transformerAttentionMask; private final String transformerTokenTypeIds; private final String transformerOutput; - private final Double termScoreThreshold; public SpladeEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { @@ -52,8 +50,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } - - @Override public void getConfig(SpladeEmbedderConfig.Builder b) { b.transformerModel(modelRef).tokenizerPath(vocabRef); diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 29399a38fa9..1a9caaa5ca1 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -119,7 +119,7 @@ public class EmbedExpression extends Expression { "Don't know what tensor type to embed into"); targetType = toTargetTensor(context.getInputType(this, outputField)); if ( ! validTarget(targetType)) - throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " + + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," + "an array of dense 1d tensors, or a mixed 2d tensor"); context.setValueType(createdOutputType()); } @@ -134,14 +134,14 @@ public class EmbedExpression extends Expression { if ( ! ( dataType instanceof TensorDataType)) throw new IllegalArgumentException("Expected a tensor data type but got " + dataType); return ((TensorDataType)dataType).getTensorType(); - } private boolean validTarget(TensorType target) { - if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1) - return true; - if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1) + if (target.dimensions().size() == 1) //indexed or mapped 1d tensor return true; + if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 + && target.mappedSubtype().rank() == 1) + return true; //mixed mapped-indexed 2d tensor return false; } diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 891be44a5d2..4af7820274f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -1,3 +1,4 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; @@ -9,13 +10,14 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; -import com.yahoo.tensor.*; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; import java.nio.file.Paths; import java.util.List; import java.util.Map; - import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; /** @@ -137,16 +139,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { Tensor logOfRelu = modelOutput.map((x) -> Math.log(1 + Math.max(0, x))); Tensor maxReduced = logOfRelu.reduce(Reduce.Aggregator.max, "d1"); IndexedTensor vocab = (IndexedTensor) maxReduced; - Tensor.Builder sparseTensor = MappedTensor.Builder.of(tensorType); + var builder = Tensor.Builder.of(tensorType); for(int i = 0; i < vocab.size(); i++) { - var value = vocab.get(i); - if (value > termScoreThreshold) { - String t = tokenizer.decode(List.of((long) i)); - TensorAddress label = TensorAddress.of(List.of(t).toArray(new String[0])); - sparseTensor.cell(label, value); + var score = vocab.get(i); + if (score > termScoreThreshold) { + String term = tokenizer.decode(List.of((long) i)); + builder.cell(). + label(tensorType.dimensions().get(0).name(), term) + .value(score); } } - return sparseTensor.build(); + return builder.build(); } @@ -159,7 +162,6 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { } return builder.build(); } - @Override public void deconstruct() { evaluator.close(); diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 0c49d75cbe0..e0d940ca5fe 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -9,10 +9,9 @@ import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import org.junit.Test; - import java.util.List; +import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; @@ -68,7 +67,6 @@ public class SpladeEmbedderTest { SpladeEmbedderConfig.Builder builder = new SpladeEmbedderConfig.Builder(); builder.tokenizerPath(ModelReference.valueOf(vocabPath)); builder.transformerModel(ModelReference.valueOf(modelPath)); - builder.transformerOutput("logits"); builder.termScoreThreshold(scoreThreshold); builder.transformerGpuDevice(-1); return new SpladeEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); |