From 8791830cb01bd3cc6370472cd7918cb0c7e24e34 Mon Sep 17 00:00:00 2001 From: Jo Kristian Bergum Date: Fri, 15 Mar 2024 23:17:33 +0100 Subject: Attempt of supporting mapping array to mapped 2d tensor for sparse models --- .../expressions/EmbedExpression.java | 45 ++++++++++++++- .../vespa/indexinglanguage/ScriptTestCase.java | 67 ++++++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 1cfe8532c92..5d5410c2ef0 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -17,7 +17,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Objects; + /** * Embeds a string in a tensor space using the configured Embedder component @@ -95,7 +95,12 @@ public class EmbedExpression extends Expression { var input = (Array)context.getValue(); var builder = Tensor.Builder.of(targetType); if (targetType.rank() == 2) - embedArrayValueToRank2Tensor(input, builder, context); + if (targetType.indexedSubtype().rank() == 1) + embedArrayValueToRank2Tensor(input, builder, context); + else if(targetType.mappedSubtype().rank() == 2) + embedArrayValueToRank2MappedTensor(input, builder, context); + else + throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported"); else embedArrayValueToRank3Tensor(input, builder, context); return builder.build(); @@ -141,6 +146,27 @@ public class EmbedExpression extends Expression { } } + private void embedArrayValueToRank2MappedTensor(Array input, + Tensor.Builder builder, + ExecutionContext context) { + String outerMappedDimension = embedderArguments.get(0); + String 
innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get(); + + var innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).build(); + int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension); + + for (int i = 0; i < input.size(); i++) { + Tensor tensor = embed(input.get(i).getString(), innerType, context); + for (Iterator cells = tensor.cellIterator(); cells.hasNext(); ) { + Tensor.Cell cell = cells.next(); + builder.cell() + .label(outerMappedDimension, i) + .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex)) + .value(cell.getValue()); + } + } + } + private Tensor embed(String input, TensorType targetType, ExecutionContext context) { return embedder.embed(input, new Embedder.Context(destination).setLanguage(context.getLanguage()).setEmbedderId(embedderId), @@ -156,8 +182,19 @@ public class EmbedExpression extends Expression { "Don't know what tensor type to embed into"); targetType = toTargetTensor(context.getInputType(this, outputField)); if ( ! validTarget(targetType)) - throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," + + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, " + "an array of dense 1d tensors, or a mixed 2d or 3d tensor"); + if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) { + if (embedderArguments.size() != 1) + throw new VerificationException(this, "When the embedding target field is a 2d mapped tensor " + + "the name of the tensor dimension that corresponds to the input array elements must " + + "be given as a second argument to embed, e.g: ... | embed splade paragraph | ..."); + if ( ! 
targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) { + throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " + + "is not a sparse dimension of the target type " + targetType); + + } + } if (targetType.rank() == 3) { if (embedderArguments.size() != 1) throw new VerificationException(this, "When the embedding target field is a 3d tensor " + @@ -188,6 +225,8 @@ public class EmbedExpression extends Expression { return true; if (target.rank() == 2 && target.indexedSubtype().rank() == 1) return true; // mixed 2d tensor + if(target.rank() == 2 && target.mappedSubtype().rank() == 2) + return true; // mapped 2d tensor if (target.rank() == 3 && target.indexedSubtype().rank() == 1) return true; // mixed 3d tensor return false; diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 06ff6bc85b1..4e1eae2ed46 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -493,6 +493,73 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } + /** Verification rejects an embed dimension argument that is not a dimension of the 2d mapped target tensor */ + @Test + public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws ParseException { + Map embedders = Map.of("emb1", new MockMappedEmbedder("myDocument.my2DSparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); + var expression = Expression.fromString("input myTextArray | embed emb1 doh | attribute 'my2DSparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("my2DSparseTensor", new 
TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("The dimension 'doh' given to embed is not a sparse dimension of the target type tensor(passage{},token{})", + e.getMessage()); + } + } + + /** Multiple paragraphs with sparse encoding (splade style) */ + @Test + @SuppressWarnings("OptionalGetWithoutIsPresent") + public void testArrayEmbedTo2MappedTensor() throws ParseException { + Map embedders = Map.of("emb1", new MockMappedEmbedder("myDocument.my2DSparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})"); + var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'my2DSparseTensor'", + new SimpleLinguistics(), + embedders); + assertEquals("input myTextArray | embed emb1 passage | attribute my2DSparseTensor", expression.toString()); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + var tensorField = new Field("my2DSparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("abc")); + array.add(new StringFieldValue("cde")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter))); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("my2DSparseTensor")); + var sparse2DTensor = (TensorFieldValue)adapter.values.get("my2DSparseTensor"); + assertEquals(Tensor.from( + tensorType, + "tensor(passage{},token{}):" + + "{{passage:0,token:97}:97.0, " + + "{passage:0,token:98}:98.0, " + + 
"{passage:0,token:99}:99.0, " + + "{passage:1,token:100}:100.0, " + + "{passage:1,token:101}:101.0, " + + "{passage:1,token:99}:99.0}"), + sparse2DTensor.getTensor().get()); + } + + private void assertThrows(Runnable r, String msg) { try { r.run(); -- cgit v1.2.3