diff options
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java | 47 |
1 files changed, 44 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 28f8c4e252f..3a64083c623 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -14,6 +14,8 @@ import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; + import java.nio.file.Paths; import java.util.List; import java.util.Map; @@ -32,17 +34,22 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { private final String tokenTypeIdsName; private final String outputName; private final double termScoreThreshold; + private final boolean useCustomReduce; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; @Inject public SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config) { + this(onnx, runtime, config, true); + } + SpladeEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, SpladeEmbedderConfig config, boolean useCustomReduce) { this.runtime = runtime; inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); tokenTypeIdsName = config.transformerTokenTypeIds(); termScoreThreshold = config.termScoreThreshold(); + this.useCustomReduce = useCustomReduce; var tokenizerPath = Paths.get(config.tokenizerPath().toString()); var builder = new HuggingFaceTokenizer.Builder() @@ -117,20 +124,54 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"), attentionMaskName, attentionMask.expand("d0"), tokenTypeIdsName, tokenTypeIds.expand("d0")); - Tensor spladeTensor = sparsify((IndexedTensor) evaluator.evaluate(inputs).get(outputName), tensorType); + IndexedTensor output = (IndexedTensor) evaluator.evaluate(inputs).get(outputName); + Tensor spladeTensor = useCustomReduce + ? sparsifyCustomReduce(output, tensorType) + : sparsifyReduce(output, tensorType); runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); return spladeTensor; } /** - * Sparsify the model output tensor. + * Sparsify the output tensor by applying a threshold on the log of the relu of the output. + * This uses generic tensor reduce+map, and is slightly slower than a custom unrolled variant. + * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size + * of the vocabulary + * @param tensorType the type of the destination tensor + * @return A mapped tensor with the terms from the vocab that has a score above the threshold + */ + private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) { + //Remove batch dim, batch size of 1 + Tensor output = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1"); + Tensor logOfRelu = output.map((x) -> Math.log(1 + (x > 0 ? x : 0))); + IndexedTensor vocab = (IndexedTensor) logOfRelu; + var builder = Tensor.Builder.of(tensorType); + long[] tokens = new long[1]; + for (int i = 0; i < vocab.size(); i++) { + var score = vocab.get(i); + if (score > termScoreThreshold) { + tokens[0] = i; + String term = tokenizer.decode(tokens); + builder.cell(). + label(tensorType.dimensions().get(0).name(), term) + .value(score); + } + } + return builder.build(); + } + + + + /** + * Sparsify the model output tensor.This uses an unrolled custom reduce and is 15-20% faster than the using + * generic tensor reduce. * * @param modelOutput the model output tensor of type tensorType * @param tensorType the type of the destination tensor * @return A mapped tensor with the terms from the vocab that has a score above the threshold */ - public Tensor sparsify(IndexedTensor modelOutput, TensorType tensorType) { + public Tensor sparsifyCustomReduce(IndexedTensor modelOutput, TensorType tensorType) { var builder = Tensor.Builder.of(tensorType); long[] shape = modelOutput.shape(); if(shape.length != 3) { |