aboutsummaryrefslogtreecommitdiffstats
path: root/indexinglanguage/src/test/java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
commit35a1ad6eb3d59c9945cdfe8486f57e3f75b3091c (patch)
treeda3ca9d331d3d67060f6d6f9450f06e5be1a411b /indexinglanguage/src/test/java
parent7f923d43611071bf41fcac0c0ccac9eda16bb00c (diff)
Support embedding an array to a mixed 2d tensor
Diffstat (limited to 'indexinglanguage/src/test/java')
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java34
1 files changed, 34 insertions, 0 deletions
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index e6744c010f4..c446c04065a 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -268,6 +268,40 @@ public class ScriptTestCase {
assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get());
}
+ @Test
+ public void testArrayEmbedToSparseTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
+ var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("second"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"),
+ sparseTensor.getTensor().get());
+ }
+
// An embedder which returns the char value of each letter in the input. */
private static class MockEmbedder implements Embedder {