diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-19 09:52:52 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-19 09:52:52 +0100 |
commit | 79bb01aa94375b6b9ce464fbdc5db24d1549e7d9 (patch) | |
tree | 0302c2c3b1a5e59fd6e6626381f6c0bb5722712a /indexinglanguage | |
parent | 745a8db7a8eaea7aa53736a26d64e97543900343 (diff) |
Add test coverage of mapped tensor in indexing embed
Diffstat (limited to 'indexinglanguage')
-rw-r--r-- | indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java | 67 |
1 files changed, 61 insertions, 6 deletions
```diff
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 2b28756a6a8..6206c2efe7a 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -315,7 +315,7 @@ public class ScriptTestCase {
     }
 
     @Test
-    public void testArrayEmbedToSparseTensor() throws ParseException {
+    public void testArrayEmbedToMixedTensor() throws ParseException {
         Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
 
         TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
@@ -348,19 +348,65 @@ public class ScriptTestCase {
                      sparseTensor.getTensor().get());
     }
 
+    @SuppressWarnings("OptionalGetWithoutIsPresent")
+    @Test
+    public void testEmbedToSparseTensor() throws ParseException {
+
+        Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true);
+        Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder);
+
+        TensorType tensorType = TensorType.fromSpec("tensor(t{})");
+        var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'",
+                                               new SimpleLinguistics(),
+                                               embedders);
+
+        SimpleTestAdapter adapter = new SimpleTestAdapter();
+        adapter.createField(new Field("text", DataType.STRING));
+
+        var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+        adapter.createField(tensorField);
+
+        var text = new StringFieldValue("abc");
+        adapter.setValue("text", text);
+        expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+        // Necessary to resolve output type
+        VerificationContext verificationContext = new VerificationContext(adapter);
+        assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext));
+
+        ExecutionContext context = new ExecutionContext(adapter);
+        context.setValue(text);
+        expression.execute(context);
+        assertTrue(adapter.values.containsKey("mySparseTensor"));
+        var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+        assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"),
+                     sparseTensor.getTensor().get());
+    }
+
     // An embedder which returns the char value of each letter in the input. */
     private static class MockEmbedder implements Embedder {
 
         private final String expectedDestination;
         private final int addition;
+        private final boolean mappedTensor;
+
+
 
         public MockEmbedder(String expectedDestination) {
-            this(expectedDestination, 0);
+            this(expectedDestination, 0, false);
+        }
+        public MockEmbedder(String expectedDestination, boolean mapped) {
+            this(expectedDestination, 0,mapped);
+        }
+
+        public MockEmbedder(String expectedDestination,int addition) {
+            this(expectedDestination, addition,false);
         }
 
-        public MockEmbedder(String expectedDestination, int addition) {
+        public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) {
             this.expectedDestination = expectedDestination;
             this.addition = addition;
+            this.mappedTensor = mappedTensor;
         }
 
         @Override
@@ -372,11 +418,20 @@ public class ScriptTestCase {
         public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
             assertEquals(expectedDestination, context.getDestination());
             var b = Tensor.Builder.of(tensorType);
-            for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
-                b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+            if (mappedTensor) {
+                for(int i = 0; i < text.length(); i++) {
+                    var value = text.charAt(i) + addition;
+                    b.cell().
+                            label(tensorType.dimensions().get(0).name(), text.charAt(i))
+                            .value(value);
+                }
+            } else {
+                for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
+                    b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+
+            }
             return b.build();
         }
-
     }
 
     private void assertThrows(Runnable r, String msg) {
```