diff options
8 files changed, 314 insertions, 94 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java index 9e0a3a0ba5c..92c930e16e0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java @@ -61,5 +61,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); } + } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 7c5e8912e49..5daf74a9723 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -27,6 +27,7 @@ public class EmbedExpression extends Expression { private final Embedder embedder; private final String embedderId; + private final List<String> embedderArguments; /** The destination the embedding will be written to on the form [schema name].[field name] */ private String destination; @@ -34,22 +35,23 @@ public class EmbedExpression extends Expression { /** The target type we are embedding into. */ private TensorType targetType; - public EmbedExpression(Map<String, Embedder> embedders, String embedderId) { + public EmbedExpression(Map<String, Embedder> embedders, String embedderId, List<String> embedderArguments) { super(null); this.embedderId = embedderId; + this.embedderArguments = List.copyOf(embedderArguments); - boolean embedderIdProvided = embedderId != null && embedderId.length() > 0; + boolean embedderIdProvided = embedderId != null && !embedderId.isEmpty(); if (embedders.size() == 0) { throw new IllegalStateException("No embedders provided"); // should never happen } + else if (embedders.size() == 1 && ! embedderIdProvided) { + this.embedder = embedders.entrySet().stream().findFirst().get().getValue(); + } else if (embedders.size() > 1 && ! embedderIdProvided) { this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " + "Valid embedders are " + validEmbedders(embedders)); } - else if (embedders.size() == 1 && ! embedderIdProvided) { - this.embedder = embedders.entrySet().stream().findFirst().get().getValue(); - } else if ( ! embedders.containsKey(embedderId)) { this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " + "Valid embedders are " + validEmbedders(embedders)); @@ -91,17 +93,51 @@ public class EmbedExpression extends Expression { private Tensor embedArrayValue(ExecutionContext context) { var input = (Array<StringFieldValue>)context.getValue(); var builder = Tensor.Builder.of(targetType); + if (targetType.rank() == 2) + embedArrayValueToRank2Tensor(input, builder, context); + else + embedArrayValueToRank3Tensor(input, builder, context); + return builder.build(); + } + + private void embedArrayValueToRank2Tensor(Array<StringFieldValue> input, + Tensor.Builder builder, + ExecutionContext context) { + String mappedDimension = targetType.mappedSubtype().dimensions().get(0).name(); + String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name(); for (int i = 0; i < input.size(); i++) { Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context); for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) { Tensor.Cell cell = cells.next(); builder.cell() - .label(targetType.mappedSubtype().dimensions().get(0).name(), i) - .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().numericLabel(0)) + .label(mappedDimension, i) + .label(indexedDimension, cell.getKey().numericLabel(0)) + .value(cell.getValue()); + } + } + } + + private void embedArrayValueToRank3Tensor(Array<StringFieldValue> input, + Tensor.Builder builder, + ExecutionContext context) { + String outerMappedDimension = embedderArguments.get(0); + String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get(); + String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name(); + long indexedDimensionSize = targetType.indexedSubtype().dimensions().get(0).size().get(); + var innerType = new TensorType.Builder().mapped(innerMappedDimension).indexed(indexedDimension,indexedDimensionSize).build(); + int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension); + int indexedDimensionIndex = innerType.indexOfDimensionAsInt(indexedDimension); + for (int i = 0; i < input.size(); i++) { + Tensor tensor = embed(input.get(i).getString(), innerType, context); + for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) { + Tensor.Cell cell = cells.next(); + builder.cell() + .label(outerMappedDimension, i) + .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex)) + .label(indexedDimension, cell.getKey().numericLabel(indexedDimensionIndex)) .value(cell.getValue()); } } - return builder.build(); } private Tensor embed(String input, TensorType targetType, ExecutionContext context) { @@ -120,7 +156,17 @@ public class EmbedExpression extends Expression { targetType = toTargetTensor(context.getInputType(this, outputField)); if ( ! validTarget(targetType)) throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," + - "an array of dense 1d tensors, or a mixed 2d tensor"); + "an array of dense 1d tensors, or a mixed 2d or 3d tensor"); + if (targetType.rank() == 3) { + if (embedderArguments.size() != 1) + throw new VerificationException(this, "When the embedding target field is a 3d tensor " + + "the name of the tensor dimension that corresponds to the input array elements must " + + "be given as a second argument to embed, e.g: ... | embed colbert paragraph | ..."); + if ( ! targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) + throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " + + "is not a sparse dimension of the target type " + targetType); + } + context.setValueType(createdOutputType()); } @@ -137,11 +183,12 @@ public class EmbedExpression extends Expression { } private boolean validTarget(TensorType target) { - if (target.dimensions().size() == 1) //indexed or mapped 1d tensor + if (target.rank() == 1) // indexed or mapped 1d tensor return true; - if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 - && target.mappedSubtype().rank() == 1) - return true; //mixed mapped-indexed 2d tensor + if (target.rank() == 2 && target.indexedSubtype().rank() == 1) + return true; // mixed 2d tensor + if (target.rank() == 3 && target.indexedSubtype().rank() == 1) + return true; // mixed 3d tensor return false; } diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index 42bbd26cee6..a3b4039408a 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -37,7 +37,6 @@ import com.yahoo.language.Linguistics; /** * @author Simon Thoresen Hult - * @version $Id$ */ public class IndexingParser { @@ -386,11 +385,16 @@ Expression echoExp() : { } Expression embedExp() : { - String val = ""; + String embedderId = ""; + String embedderArgument; + List<String> embedderArguments = new ArrayList<String>(); } { - ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] ) - { return new EmbedExpression(embedders, val); } + ( + <EMBED> [ LOOKAHEAD(2) embedderId = identifier() ] + ( LOOKAHEAD(2) embedderArgument = identifier() { embedderArguments.add(embedderArgument); } )* + ) + { return new EmbedExpression(embedders, embedderId, embedderArguments); } } Expression exactExp() : { } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 6206c2efe7a..7fe55b738df 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -181,7 +181,7 @@ public class ScriptTestCase { Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); Map<String, Embedder> embedder = Map.of( - "emb1", new MockEmbedder("myDocument.myTensor") + "emb1", new MockIndexedEmbedder("myDocument.myTensor") ); testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "input text", "[105, 110, 112, 117]"); @@ -193,8 +193,8 @@ public class ScriptTestCase { null, null); Map<String, Embedder> embedders = Map.of( - "emb1", new MockEmbedder("myDocument.myTensor"), - "emb2", new MockEmbedder("myDocument.myTensor", 1) + "emb1", new MockIndexedEmbedder("myDocument.myTensor"), + "emb2", new MockIndexedEmbedder("myDocument.myTensor", 1) ); testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "my input", "[109.0, 121.0, 32.0, 105.0]"); @@ -243,7 +243,7 @@ public class ScriptTestCase { @SuppressWarnings("unchecked") @Test public void testArrayEmbed() throws ParseException { - Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray")); + Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.myTensorArray")); TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", @@ -277,7 +277,7 @@ public class ScriptTestCase { @Test public void testArrayEmbedWithConcatenation() throws ParseException { - Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); + Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor")); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'", @@ -314,9 +314,10 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } + /** Multiple paragraphs */ @Test - public void testArrayEmbedToMixedTensor() throws ParseException { - Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); + public void testArrayEmbedTo2dMixedTensor() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor")); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", @@ -348,17 +349,125 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("sec")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter))); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + // The two "passages" are [first, sec], the middle (d=1) token encodes those letters + assertEquals(Tensor.from(tensorType, + """ + { + {passage:0, token:0, d:0}: 101, + {passage:0, token:0, d:1}: 102, + {passage:0, token:0, d:2}: 103, + {passage:0, token:1, d:0}: 104, + {passage:0, token:1, d:1}: 105, + {passage:0, token:1, d:2}: 106, + {passage:0, token:2, d:0}: 113, + {passage:0, token:2, d:1}: 114, + {passage:0, token:2, d:2}: 115, + {passage:0, token:3, d:0}: 114, + {passage:0, token:3, d:1}: 115, + {passage:0, token:3, d:2}: 116, + {passage:0, token:4, d:0}: 115, + {passage:0, token:4, d:1}: 116, + {passage:0, token:4, d:2}: 117, + {passage:1, token:0, d:0}: 114, + {passage:1, token:0, d:1}: 115, + {passage:1, token:0, d:2}: 116, + {passage:1, token:1, d:0}: 100, + {passage:1, token:1, d:1}: 101, + {passage:1, token:1, d:2}: 102, + {passage:1, token:2, d:0}: 98, + {passage:1, token:2, d:1}: 99, + {passage:1, token:2, d:2}: 100 + } + """), + sparseTensor.getTensor().get()); + } + + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...", + e.getMessage()); + } + } + + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})", + e.getMessage()); + } + } + @SuppressWarnings("OptionalGetWithoutIsPresent") @Test public void testEmbedToSparseTensor() throws ParseException { - - Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true); + Embedder mappedEmbedder = new MockMappedEmbedder("myDocument.mySparseTensor", 0); Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder); TensorType tensorType = TensorType.fromSpec("tensor(t{})"); var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), + embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("text", DataType.STRING)); @@ -383,30 +492,23 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } - // An embedder which returns the char value of each letter in the input. */ - private static class MockEmbedder implements Embedder { - - private final String expectedDestination; - private final int addition; - - private final boolean mappedTensor; - - - public MockEmbedder(String expectedDestination) { - this(expectedDestination, 0, false); - } - public MockEmbedder(String expectedDestination, boolean mapped) { - this(expectedDestination, 0,mapped); + private void assertThrows(Runnable r, String msg) { + try { + r.run(); + fail(); + } catch (IllegalStateException e) { + assertEquals(e.getMessage(), msg); } + } - public MockEmbedder(String expectedDestination,int addition) { - this(expectedDestination, addition,false); - } + private static abstract class MockEmbedder implements Embedder { - public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) { + final String expectedDestination; + final int addition; + + public MockEmbedder(String expectedDestination, int addition) { this.expectedDestination = expectedDestination; this.addition = addition; - this.mappedTensor = mappedTensor; } @Override @@ -414,32 +516,84 @@ public class ScriptTestCase { return null; } + void verifyDestination(Embedder.Context context) { + assertEquals(expectedDestination, context.getDestination()); + } + + } + + /** An embedder which returns the char value of each letter in the input as a 1d indexed tensor. */ + private static class MockIndexedEmbedder extends MockEmbedder { + + public MockIndexedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockIndexedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + @Override public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { - assertEquals(expectedDestination, context.getDestination()); + verifyDestination(context); var b = Tensor.Builder.of(tensorType); - if (mappedTensor) { - for(int i = 0; i < text.length(); i++) { - var value = text.charAt(i) + addition; - b.cell(). - label(tensorType.dimensions().get(0).name(), text.charAt(i)) - .value(value); - } - } else { - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) - b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) + b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + return b.build(); + } - } + } + + /** An embedder which returns the char value of each letter in the input as a 1d mapped tensor. */ + private static class MockMappedEmbedder extends MockEmbedder { + + public MockMappedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockMappedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + verifyDestination(context); + var b = Tensor.Builder.of(tensorType); + for (int i = 0; i < text.length(); i++) + b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition); return b.build(); } + } - private void assertThrows(Runnable r, String msg) { - try { - r.run(); - fail(); - } catch (IllegalStateException e) { - assertEquals(e.getMessage(), msg); + /** + * An embedder which returns the char value of each letter in the input as a 2d mixed tensor where each input + * char becomes an indexed dimension containing input-1, input, input+1. + */ + private static class MockMixedEmbedder extends MockEmbedder { + + public MockMixedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockMixedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + verifyDestination(context); + var b = Tensor.Builder.of(tensorType); + String mappedDimension = tensorType.mappedSubtype().dimensions().get(0).name(); + String indexedDimension = tensorType.indexedSubtype().dimensions().get(0).name(); + for (int i = 0; i < text.length(); i++) { + for (int j = 0; j < 3; j++) { + b.cell().label(mappedDimension, i) + .label(indexedDimension, j) + .value(text.charAt(i) + addition + j - 1); + } + } + return b.build(); } } diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 8c39cc8c813..f76bfd28abf 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -18,7 +18,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; + import java.nio.file.Paths; import java.util.Map; import java.util.List; @@ -34,10 +34,14 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES * This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model. * * See col-bert-embedder.def for configurable parameters. + * * @author bergum */ @Beta public class ColBertEmbedder extends AbstractComponent implements Embedder { + + private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"; + private final Embedder.Runtime runtime; private final String inputIdsName; private final String attentionMaskName; @@ -117,7 +121,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { private void validateName(Map<String, TensorType> types, String name, String type) { if (!types.containsKey(name)) { throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " + - "Model contains: " + String.join(",", types.keySet())); + "Model contains: " + String.join(",", types.keySet())); } } @@ -128,9 +132,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String text, Context context, TensorType tensorType) { - if (!verifyTensorType(tensorType)) { + if ( ! validTensorType(tensorType)) { throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. " + - "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); } if (context.getDestination().startsWith("query")) { return embedQuery(text, context, tensorType); @@ -196,7 +200,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); if (dims != result.shape()[2]) { throw new IllegalArgumentException("Token vector dimensionality does not" + - " match indexed dimensionality of " + dims); + " match indexed dimensionality of " + dims); } Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size()); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); @@ -213,13 +217,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), - attentionMaskName, attentionMaskTensor.expand("d0")); + attentionMaskName, attentionMaskTensor.expand("d0")); Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); IndexedTensor result = (IndexedTensor) tokenEmbeddings; Tensor contextualEmbeddings; - int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens. + int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens. if (tensorType.valueType() == TensorType.Value.INT8) { contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); } else { @@ -230,7 +234,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { - if(result.shape().length != 3) + if (result.shape().length != 3) throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); int size = type.indexedSubtype().dimensions().size(); if (size != 1) @@ -253,8 +257,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) { if (type.valueType() != TensorType.Value.INT8) - throw new IllegalArgumentException("Only a int8 tensor type can be" + - " the destination of bit packing"); + throw new IllegalArgumentException("Only a int8 tensor type can be the destination of bit packing"); if(result.shape().length != 3) throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); @@ -264,8 +267,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); int resultDimensionality = (int)result.shape()[2]; if (resultDimensionality != 8 * wantedDimensionality) { - throw new IllegalArgumentException("Not possible to pack " + resultDimensionality - + " + dimensions into " + wantedDimensionality + " dimensions"); + throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + + " + dimensions into " + wantedDimensionality + " dimensions"); } Tensor.Builder builder = Tensor.Builder.of(type); for (int token = 0; token < nTokens; token++) { @@ -302,9 +305,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return unpacker.evaluate(context).asTensor(); } - protected boolean verifyTensorType(TensorType target) { - return target.dimensions().size() == 2 && - target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1; + protected boolean validTensorType(TensorType target) { + return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1; } private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { @@ -316,5 +318,5 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } return builder.build(); } - private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"; + } diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java index 3a64083c623..58bd4deb659 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java @@ -25,9 +25,12 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES /** * A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels * are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0). + * + * @author bergum */ @Beta public class SpladeEmbedder extends AbstractComponent implements Embedder { + private final Embedder.Runtime runtime; private final String inputIdsName; private final String attentionMaskName; @@ -110,7 +113,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { public Tensor embed(String text, Context context, TensorType tensorType) { if (!verifyTensorType(tensorType)) { throw new IllegalArgumentException("Invalid splade embedder tensor destination. " + - "Wanted a mapped 1-d tensor, got " + tensorType); + "Wanted a mapped 1-d tensor, got " + tensorType); } var start = System.nanoTime(); @@ -132,17 +135,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { return spladeTensor; } - /** * Sparsify the output tensor by applying a threshold on the log of the relu of the output. * This uses generic tensor reduce+map, and is slightly slower than a custom unrolled variant. + * * @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size - * of the vocabulary + * of the vocabulary * @param tensorType the type of the destination tensor * @return A mapped tensor with the terms from the vocab that has a score above the threshold */ private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) { - //Remove batch dim, batch size of 1 + // Remove batch dim, batch size of 1 Tensor output = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1"); Tensor logOfRelu = output.map((x) -> Math.log(1 + (x > 0 ? x : 0))); IndexedTensor vocab = (IndexedTensor) logOfRelu; @@ -227,6 +230,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder { } return builder.build(); } + @Override public void deconstruct() { evaluator.close(); diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index 0cae94c372a..be75c4d3351 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -19,6 +19,9 @@ import java.util.Set; import static org.junit.Assert.*; import static org.junit.Assume.assumeTrue; +/** + * @author bergum + */ public class ColBertEmbedderTest { @Test @@ -67,23 +70,24 @@ public class ColBertEmbedderTest { assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext); assertThrows(IllegalArgumentException.class, () -> { - //throws because int8 is not supported for query context + // throws because int8 is not supported for query context assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext); }); + assertThrows(IllegalArgumentException.class, () -> { - //throws because 16 is less than model output (128) and we want float + // throws because 16 is less than model output (128) and we want float assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext); }); assertThrows(IllegalArgumentException.class, () -> { - //throws because 128/8 does not fit into 15 + // throws because 128/8 does not fit into 15 assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext); }); } @Test public void testInputTensorsWordPiece() { - //wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999] + // wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999] List<Long> tokens = List.of(2023L, 2003L, 1037L, 23032L, 999L); ColBertEmbedder.TransformerInput input = embedder.buildTransformerInput(tokens,10,true); assertEquals(10,input.inputIds().size()); @@ -100,7 +104,7 @@ public class ColBertEmbedderTest { @Test public void testInputTensorsSentencePiece() { - //Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711] + // Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711] // ! is mapped to 711 and is a punctuation character List<Long> tokens = List.of(903L, 83L, 10L, 41L, 1294L, 711L); ColBertEmbedder.TransformerInput input = multiLingualEmbedder.buildTransformerInput(tokens,10,true); @@ -109,7 +113,7 @@ public class ColBertEmbedderTest { assertEquals(List.of(0L, 3L, 903L, 83L, 10L, 41L, 1294L, 711L, 2L, 250001L),input.inputIds()); assertEquals(List.of(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 0L),input.attentionMask()); - //NO padding for document side and 711 (punctuation) is now filtered out + // NO padding for document side and 711 (punctuation) is now filtered out input = multiLingualEmbedder.buildTransformerInput(tokens,10,false); assertEquals(8,input.inputIds().size()); assertEquals(8,input.attentionMask().size()); @@ -156,12 +160,12 @@ public class ColBertEmbedderTest { sb.append(" "); } String text = sb.toString(); - Long now = System.currentTimeMillis(); + long now = System.currentTimeMillis(); int n = 1000; for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); } - Long elapsed = (System.currentTimeMillis() - now); + long elapsed = (System.currentTimeMillis() - now); System.out.println("Elapsed time: " + elapsed + " ms"); } @@ -170,7 +174,7 @@ public class ColBertEmbedderTest { Tensor result = embedder.embed(text, context, destType); assertEquals(destType,result.type()); MixedTensor mixedTensor = (MixedTensor) result; - if(context == queryContext) { + if (context == queryContext) { assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size()); } return result; @@ -200,12 +204,14 @@ public class ColBertEmbedderTest { static final ColBertEmbedder multiLingualEmbedder; static final Embedder.Context indexingContext; static final Embedder.Context queryContext; + static { indexingContext = new Embedder.Context("schema.indexing"); queryContext = new Embedder.Context("query(qt)"); embedder = getEmbedder(); multiLingualEmbedder = getMultiLingualEmbedder(); } + private static ColBertEmbedder getEmbedder() { String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json"; String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; @@ -235,4 +241,5 @@ public class ColBertEmbedderTest { return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index ac9dc4e4eca..d27c7cf0168 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -634,6 +634,7 @@ public interface Tensor { public Builder value(double cellValue) { return tensorBuilder.cell(addressBuilder.build(), cellValue); } + public Builder value(float cellValue) { return tensorBuilder.cell(addressBuilder.build(), cellValue); } |