diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-01-09 17:00:35 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2023-01-09 17:00:35 +0100 |
commit | 17d534d070b89a4cc7f908e945c974f8f31c5e12 (patch) | |
tree | c154f778b630826fc5a5903b49331e36a81b8f60 /indexinglanguage/src/test/java | |
parent | baf54beef9119768f99d41d950373f742a42df62 (diff) |
Improve test
Diffstat (limited to 'indexinglanguage/src/test/java')
-rw-r--r-- | indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java | 60 |
1 files changed, 37 insertions, 23 deletions
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index de31f6fcb1e..3a32b0049fe 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -181,35 +181,41 @@ public class ScriptTestCase { Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); Map<String, Embedder> embedder = Map.of( - "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]") + "emb1", new MockEmbedder("myDocument.myTensor") ); - testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "[1,2,0,0]"); - testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, "[1,2,0,0]"); - testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, "[1,2,0,0]"); + testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, + "input text", "[105, 110, 112, 117]"); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, + "input text", "[105, 110, 112, 117]"); + testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, + "input text", "[105, 110, 112, 117]"); Map<String, Embedder> embedders = Map.of( - "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]"), - "emb2", new MockEmbedder("myDocument.myTensor", "[3,4,5,0]") + "emb1", new MockEmbedder("myDocument.myTensor"), + "emb2", new MockEmbedder("myDocument.myTensor", 1) ); - testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "[1,2,0,0]"); - testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, "[3,4,5,0]"); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, + "my input", "[109.0, 121.0, 32.0, 105.0]"); + testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, + "my input", "[110.0, 122.0, 33.0, 106.0]"); - assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "[3,4,5,0]"), + assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1,emb2"); - assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "[3,4,5,0]"), + assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "input text", "[105, 110, 112, 117]"), "Can't find embedder 'emb3'. Valid embedders are emb1,emb2"); } - private void testEmbedStatement(String exp, Map<String, Embedder> embedders, String expected) { + private void testEmbedStatement(String expressionString, Map<String, Embedder> embedders, String input, String expected) { try { - var expression = Expression.fromString(exp, new SimpleLinguistics(), embedders); + var expression = Expression.fromString(expressionString, new SimpleLinguistics(), embedders); TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myText", DataType.STRING)); var tensorField = new Field("myTensor", new TensorDataType(tensorType)); adapter.createField(tensorField); - adapter.setValue("myText", new StringFieldValue("input text")); + if (input != null) + adapter.setValue("myText", new StringFieldValue(input)); expression.setStatementOutput(new DocumentType("myDocument"), tensorField); // Necessary to resolve output type @@ -217,12 +223,12 @@ public class ScriptTestCase { assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); ExecutionContext context = new ExecutionContext(adapter); - context.setValue(new StringFieldValue("input text")); expression.execute(context); assertTrue(adapter.values.containsKey("myTensor")); assertEquals(Tensor.from(tensorType, expected), - ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); - } catch (ParseException e) { + ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); + } + catch (ParseException e) { throw new IllegalArgumentException(e); } } @@ -230,7 +236,7 @@ public class ScriptTestCase { @SuppressWarnings("unchecked") @Test public void testArrayEmbed() throws ParseException { - Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray", "[7,3,0,0]")); + Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray")); TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", @@ -258,18 +264,23 @@ public class ScriptTestCase { expression.execute(context); assertTrue(adapter.values.containsKey("myTensorArray")); var tensorArray = (Array<TensorFieldValue>)adapter.values.get("myTensorArray"); - assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), tensorArray.get(0).getTensor().get()); - assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), tensorArray.get(1).getTensor().get()); + assertEquals(Tensor.from(tensorType, "[102, 105, 114, 115]"), tensorArray.get(0).getTensor().get()); + assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get()); } + // An embedder which returns the char value of each letter in the input. */ private static class MockEmbedder implements Embedder { private final String expectedDestination; - private final String tensorString; + private final int addition; + + public MockEmbedder(String expectedDestination) { + this(expectedDestination, 0); + } - public MockEmbedder(String expectedDestination, String tensorString) { + public MockEmbedder(String expectedDestination, int addition) { this.expectedDestination = expectedDestination; - this.tensorString = tensorString; + this.addition = addition; } @Override @@ -280,7 +291,10 @@ public class ScriptTestCase { @Override public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedDestination, context.getDestination()); - return Tensor.from(tensorType, tensorString); + var b = Tensor.Builder.of(tensorType); + for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) + b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + return b.build(); } } |