From 1a25431ab58c752c7fc26dd8223bf1ba1079b24a Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Fri, 2 Feb 2024 12:28:53 +0100 Subject: Support embedding into rank 3 tensors --- .../vespa/indexinglanguage/ScriptTestCase.java | 250 +++++++++++++++++---- 1 file changed, 202 insertions(+), 48 deletions(-) (limited to 'indexinglanguage/src/test') diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 6206c2efe7a..7fe55b738df 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -181,7 +181,7 @@ public class ScriptTestCase { Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); Map embedder = Map.of( - "emb1", new MockEmbedder("myDocument.myTensor") + "emb1", new MockIndexedEmbedder("myDocument.myTensor") ); testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "input text", "[105, 110, 112, 117]"); @@ -193,8 +193,8 @@ public class ScriptTestCase { null, null); Map embedders = Map.of( - "emb1", new MockEmbedder("myDocument.myTensor"), - "emb2", new MockEmbedder("myDocument.myTensor", 1) + "emb1", new MockIndexedEmbedder("myDocument.myTensor"), + "emb2", new MockIndexedEmbedder("myDocument.myTensor", 1) ); testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "my input", "[109.0, 121.0, 32.0, 105.0]"); @@ -243,7 +243,7 @@ public class ScriptTestCase { @SuppressWarnings("unchecked") @Test public void testArrayEmbed() throws ParseException { - Map embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray")); + Map embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.myTensorArray")); TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", @@ -277,7 +277,7 @@ public class ScriptTestCase { @Test public void testArrayEmbedWithConcatenation() throws ParseException { - Map embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); + Map embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor")); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'", @@ -314,9 +314,10 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } + /** Multiple paragraphs */ @Test - public void testArrayEmbedToMixedTensor() throws ParseException { - Map embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); + public void testArrayEmbedTo2dMixedTensor() throws ParseException { + Map embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor")); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", @@ -348,17 +349,125 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor() throws ParseException { + Map embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("sec")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter))); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + // The two "passages" are [first, sec], the middle (d=1) token encodes those letters + assertEquals(Tensor.from(tensorType, + """ + { + {passage:0, token:0, d:0}: 101, + {passage:0, token:0, d:1}: 102, + {passage:0, token:0, d:2}: 103, + {passage:0, token:1, d:0}: 104, + {passage:0, token:1, d:1}: 105, + {passage:0, token:1, d:2}: 106, + {passage:0, token:2, d:0}: 113, + {passage:0, token:2, d:1}: 114, + {passage:0, token:2, d:2}: 115, + {passage:0, token:3, d:0}: 114, + {passage:0, token:3, d:1}: 115, + {passage:0, token:3, d:2}: 116, + {passage:0, token:4, d:0}: 115, + {passage:0, token:4, d:1}: 116, + {passage:0, token:4, d:2}: 117, + {passage:1, token:0, d:0}: 114, + {passage:1, token:0, d:1}: 115, + {passage:1, token:0, d:2}: 116, + {passage:1, token:1, d:0}: 100, + {passage:1, token:1, d:1}: 101, + {passage:1, token:1, d:2}: 102, + {passage:1, token:2, d:0}: 98, + {passage:1, token:2, d:1}: 99, + {passage:1, token:2, d:2}: 100 + } + """), + sparseTensor.getTensor().get()); + } + + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException { + Map embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...", + e.getMessage()); + } + } + + /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */ + @Test + public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException { + Map embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])"); + var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType))); + + try { + expression.verify(new VerificationContext(adapter)); + fail("Expected exception"); + } + catch (VerificationException e) { + assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})", + e.getMessage()); + } + } + @SuppressWarnings("OptionalGetWithoutIsPresent") @Test public void testEmbedToSparseTensor() throws ParseException { - - Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true); + Embedder mappedEmbedder = new MockMappedEmbedder("myDocument.mySparseTensor", 0); Map embedders = Map.of("emb1",mappedEmbedder); TensorType tensorType = TensorType.fromSpec("tensor(t{})"); var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", - new SimpleLinguistics(), - embedders); + new SimpleLinguistics(), + embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("text", DataType.STRING)); @@ -383,30 +492,23 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } - // An embedder which returns the char value of each letter in the input. */ - private static class MockEmbedder implements Embedder { - - private final String expectedDestination; - private final int addition; - - private final boolean mappedTensor; - - - public MockEmbedder(String expectedDestination) { - this(expectedDestination, 0, false); - } - public MockEmbedder(String expectedDestination, boolean mapped) { - this(expectedDestination, 0,mapped); + private void assertThrows(Runnable r, String msg) { + try { + r.run(); + fail(); + } catch (IllegalStateException e) { + assertEquals(e.getMessage(), msg); } + } - public MockEmbedder(String expectedDestination,int addition) { - this(expectedDestination, addition,false); - } + private static abstract class MockEmbedder implements Embedder { - public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) { + final String expectedDestination; + final int addition; + + public MockEmbedder(String expectedDestination, int addition) { this.expectedDestination = expectedDestination; this.addition = addition; - this.mappedTensor = mappedTensor; } @Override @@ -414,32 +516,84 @@ public class ScriptTestCase { return null; } + void verifyDestination(Embedder.Context context) { + assertEquals(expectedDestination, context.getDestination()); + } + + } + + /** An embedder which returns the char value of each letter in the input as a 1d indexed tensor. */ + private static class MockIndexedEmbedder extends MockEmbedder { + + public MockIndexedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockIndexedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + @Override public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { - assertEquals(expectedDestination, context.getDestination()); + verifyDestination(context); var b = Tensor.Builder.of(tensorType); - if (mappedTensor) { - for(int i = 0; i < text.length(); i++) { - var value = text.charAt(i) + addition; - b.cell(). - label(tensorType.dimensions().get(0).name(), text.charAt(i)) - .value(value); - } - } else { - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) - b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) + b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + return b.build(); + } - } + } + + /** An embedder which returns the char value of each letter in the input as a 1d mapped tensor. */ + private static class MockMappedEmbedder extends MockEmbedder { + + public MockMappedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockMappedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + verifyDestination(context); + var b = Tensor.Builder.of(tensorType); + for (int i = 0; i < text.length(); i++) + b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition); return b.build(); } + } - private void assertThrows(Runnable r, String msg) { - try { - r.run(); - fail(); - } catch (IllegalStateException e) { - assertEquals(e.getMessage(), msg); + /** + * An embedder which returns the char value of each letter in the input as a 2d mixed tensor where each input + * char becomes an indexed dimension containing input-1, input, input+1. + */ + private static class MockMixedEmbedder extends MockEmbedder { + + public MockMixedEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } + + public MockMixedEmbedder(String expectedDestination, int addition) { + super(expectedDestination, addition); + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + verifyDestination(context); + var b = Tensor.Builder.of(tensorType); + String mappedDimension = tensorType.mappedSubtype().dimensions().get(0).name(); + String indexedDimension = tensorType.indexedSubtype().dimensions().get(0).name(); + for (int i = 0; i < text.length(); i++) { + for (int j = 0; j < 3; j++) { + b.cell().label(mappedDimension, i) + .label(indexedDimension, j) + .value(text.charAt(i) + addition + j - 1); + } + } + return b.build(); } } -- cgit v1.2.3