summaryrefslogtreecommitdiffstats
path: root/indexinglanguage/src/test
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2024-02-02 12:28:53 +0100
committerJon Bratseth <bratseth@vespa.ai>2024-02-02 12:28:53 +0100
commit1a25431ab58c752c7fc26dd8223bf1ba1079b24a (patch)
tree954d7e2f3e43bb0636a6af7a93195a84e41e609b /indexinglanguage/src/test
parent2191193c6e107eb68611ddb106e5f572bea32903 (diff)
Support embedding into rank 3 tensors
Diffstat (limited to 'indexinglanguage/src/test')
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java250
1 files changed, 202 insertions, 48 deletions
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 6206c2efe7a..7fe55b738df 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -181,7 +181,7 @@ public class ScriptTestCase {
Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
Map<String, Embedder> embedder = Map.of(
- "emb1", new MockEmbedder("myDocument.myTensor")
+ "emb1", new MockIndexedEmbedder("myDocument.myTensor")
);
testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder,
"input text", "[105, 110, 112, 117]");
@@ -193,8 +193,8 @@ public class ScriptTestCase {
null, null);
Map<String, Embedder> embedders = Map.of(
- "emb1", new MockEmbedder("myDocument.myTensor"),
- "emb2", new MockEmbedder("myDocument.myTensor", 1)
+ "emb1", new MockIndexedEmbedder("myDocument.myTensor"),
+ "emb2", new MockIndexedEmbedder("myDocument.myTensor", 1)
);
testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders,
"my input", "[109.0, 121.0, 32.0, 105.0]");
@@ -243,7 +243,7 @@ public class ScriptTestCase {
@SuppressWarnings("unchecked")
@Test
public void testArrayEmbed() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray"));
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.myTensorArray"));
TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'",
@@ -277,7 +277,7 @@ public class ScriptTestCase {
@Test
public void testArrayEmbedWithConcatenation() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'",
@@ -314,9 +314,10 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ /** Multiple paragraphs */
@Test
- public void testArrayEmbedToMixedTensor() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+ public void testArrayEmbedTo2dMixedTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'",
@@ -348,17 +349,125 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("sec"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter)));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ // The two "passages" are [first, sec], the middle (d=1) token encodes those letters
+ assertEquals(Tensor.from(tensorType,
+ """
+ {
+ {passage:0, token:0, d:0}: 101,
+ {passage:0, token:0, d:1}: 102,
+ {passage:0, token:0, d:2}: 103,
+ {passage:0, token:1, d:0}: 104,
+ {passage:0, token:1, d:1}: 105,
+ {passage:0, token:1, d:2}: 106,
+ {passage:0, token:2, d:0}: 113,
+ {passage:0, token:2, d:1}: 114,
+ {passage:0, token:2, d:2}: 115,
+ {passage:0, token:3, d:0}: 114,
+ {passage:0, token:3, d:1}: 115,
+ {passage:0, token:3, d:2}: 116,
+ {passage:0, token:4, d:0}: 115,
+ {passage:0, token:4, d:1}: 116,
+ {passage:0, token:4, d:2}: 117,
+ {passage:1, token:0, d:0}: 114,
+ {passage:1, token:0, d:1}: 115,
+ {passage:1, token:0, d:2}: 116,
+ {passage:1, token:1, d:0}: 100,
+ {passage:1, token:1, d:1}: 101,
+ {passage:1, token:1, d:2}: 102,
+ {passage:1, token:2, d:0}: 98,
+ {passage:1, token:2, d:1}: 99,
+ {passage:1, token:2, d:2}: 100
+ }
+ """),
+ sparseTensor.getTensor().get());
+ }
+
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType)));
+
+ try {
+ expression.verify(new VerificationContext(adapter));
+ fail("Expected exception");
+ }
+ catch (VerificationException e) {
+ assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...",
+ e.getMessage());
+ }
+ }
+
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType)));
+
+ try {
+ expression.verify(new VerificationContext(adapter));
+ fail("Expected exception");
+ }
+ catch (VerificationException e) {
+ assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})",
+ e.getMessage());
+ }
+ }
+
@SuppressWarnings("OptionalGetWithoutIsPresent")
@Test
public void testEmbedToSparseTensor() throws ParseException {
-
- Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true);
+ Embedder mappedEmbedder = new MockMappedEmbedder("myDocument.mySparseTensor", 0);
Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder);
TensorType tensorType = TensorType.fromSpec("tensor(t{})");
var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'",
- new SimpleLinguistics(),
- embedders);
+ new SimpleLinguistics(),
+ embedders);
SimpleTestAdapter adapter = new SimpleTestAdapter();
adapter.createField(new Field("text", DataType.STRING));
@@ -383,30 +492,23 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
- // An embedder which returns the char value of each letter in the input. */
- private static class MockEmbedder implements Embedder {
-
- private final String expectedDestination;
- private final int addition;
-
- private final boolean mappedTensor;
-
-
- public MockEmbedder(String expectedDestination) {
- this(expectedDestination, 0, false);
- }
- public MockEmbedder(String expectedDestination, boolean mapped) {
- this(expectedDestination, 0,mapped);
+ private void assertThrows(Runnable r, String msg) {
+ try {
+ r.run();
+ fail();
+ } catch (IllegalStateException e) {
+ assertEquals(e.getMessage(), msg);
}
+ }
- public MockEmbedder(String expectedDestination,int addition) {
- this(expectedDestination, addition,false);
- }
+ private static abstract class MockEmbedder implements Embedder {
- public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) {
+ final String expectedDestination;
+ final int addition;
+
+ public MockEmbedder(String expectedDestination, int addition) {
this.expectedDestination = expectedDestination;
this.addition = addition;
- this.mappedTensor = mappedTensor;
}
@Override
@@ -414,32 +516,84 @@ public class ScriptTestCase {
return null;
}
+ void verifyDestination(Embedder.Context context) {
+ assertEquals(expectedDestination, context.getDestination());
+ }
+
+ }
+
+ /** An embedder which returns the char value of each letter in the input as a 1d indexed tensor. */
+ private static class MockIndexedEmbedder extends MockEmbedder {
+
+ public MockIndexedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockIndexedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
@Override
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
- assertEquals(expectedDestination, context.getDestination());
+ verifyDestination(context);
var b = Tensor.Builder.of(tensorType);
- if (mappedTensor) {
- for(int i = 0; i < text.length(); i++) {
- var value = text.charAt(i) + addition;
- b.cell().
- label(tensorType.dimensions().get(0).name(), text.charAt(i))
- .value(value);
- }
- } else {
- for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
- b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
+ b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ return b.build();
+ }
- }
+ }
+
+ /** An embedder which returns the char value of each letter in the input as a 1d mapped tensor. */
+ private static class MockMappedEmbedder extends MockEmbedder {
+
+ public MockMappedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockMappedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ verifyDestination(context);
+ var b = Tensor.Builder.of(tensorType);
+ for (int i = 0; i < text.length(); i++)
+ b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition);
return b.build();
}
+
}
- private void assertThrows(Runnable r, String msg) {
- try {
- r.run();
- fail();
- } catch (IllegalStateException e) {
- assertEquals(e.getMessage(), msg);
+ /**
+ * An embedder which returns the char value of each letter in the input as a 2d mixed tensor where each input
+ * char becomes an indexed dimension containing input-1, input, input+1.
+ */
+ private static class MockMixedEmbedder extends MockEmbedder {
+
+ public MockMixedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockMixedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ verifyDestination(context);
+ var b = Tensor.Builder.of(tensorType);
+ String mappedDimension = tensorType.mappedSubtype().dimensions().get(0).name();
+ String indexedDimension = tensorType.indexedSubtype().dimensions().get(0).name();
+ for (int i = 0; i < text.length(); i++) {
+ for (int j = 0; j < 3; j++) {
+ b.cell().label(mappedDimension, i)
+ .label(indexedDimension, j)
+ .value(text.charAt(i) + addition + j - 1);
+ }
+ }
+ return b.build();
}
}