summary refs log tree commit diff stats
path: root/indexinglanguage
diff options
context:
space:
mode:
author Jo Kristian Bergum <bergum@yahooinc.com> 2023-12-19 09:52:52 +0100
committer Jo Kristian Bergum <bergum@yahooinc.com> 2023-12-19 09:52:52 +0100
commit 79bb01aa94375b6b9ce464fbdc5db24d1549e7d9 (patch)
tree 0302c2c3b1a5e59fd6e6626381f6c0bb5722712a /indexinglanguage
parent 745a8db7a8eaea7aa53736a26d64e97543900343 (diff)
Add test coverage of mapped tensor in indexing embed
Diffstat (limited to 'indexinglanguage')
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java 67
1 files changed, 61 insertions, 6 deletions
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 2b28756a6a8..6206c2efe7a 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -315,7 +315,7 @@ public class ScriptTestCase {
}
@Test
- public void testArrayEmbedToSparseTensor() throws ParseException {
+ public void testArrayEmbedToMixedTensor() throws ParseException {
Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
@@ -348,19 +348,65 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ @SuppressWarnings("OptionalGetWithoutIsPresent")
+ @Test
+ public void testEmbedToSparseTensor() throws ParseException {
+
+ Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true);
+ Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder);
+
+ TensorType tensorType = TensorType.fromSpec("tensor(t{})");
+ var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("text", DataType.STRING));
+
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var text = new StringFieldValue("abc");
+ adapter.setValue("text", text);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(text);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"),
+ sparseTensor.getTensor().get());
+ }
+
// An embedder which returns the char value of each letter in the input.
private static class MockEmbedder implements Embedder {
private final String expectedDestination;
private final int addition;
+ private final boolean mappedTensor;
+
+
public MockEmbedder(String expectedDestination) {
- this(expectedDestination, 0);
+ this(expectedDestination, 0, false);
+ }
+ public MockEmbedder(String expectedDestination, boolean mapped) {
+ this(expectedDestination, 0,mapped);
+ }
+
+ public MockEmbedder(String expectedDestination,int addition) {
+ this(expectedDestination, addition,false);
}
- public MockEmbedder(String expectedDestination, int addition) {
+ public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) {
this.expectedDestination = expectedDestination;
this.addition = addition;
+ this.mappedTensor = mappedTensor;
}
@Override
@@ -372,11 +418,20 @@ public class ScriptTestCase {
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
assertEquals(expectedDestination, context.getDestination());
var b = Tensor.Builder.of(tensorType);
- for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
- b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ if (mappedTensor) {
+ for(int i = 0; i < text.length(); i++) {
+ var value = text.charAt(i) + addition;
+ b.cell().
+ label(tensorType.dimensions().get(0).name(), text.charAt(i))
+ .value(value);
+ }
+ } else {
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
+ b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+
+ }
return b.build();
}
-
}
private void assertThrows(Runnable r, String msg) {