diff options
author | Jo Kristian Bergum <bergum@yahoo-inc.com> | 2024-01-04 15:45:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-04 15:45:02 +0100 |
commit | 6d2226e2bf35e32cf618ff12a7e2968c85eabf1f (patch) | |
tree | 5ac11755dbe8998c2b3f4dc5cae8859b0a2e1b9f /indexinglanguage | |
parent | b10d1fd87d7013847b19fc89a620fbe9c7136e61 (diff) | |
parent | 79bb01aa94375b6b9ce464fbdc5db24d1549e7d9 (diff) |
Merge pull request #29667 from vespa-engine/jobergum/splade-embedder
Add SPLADE embedder
Diffstat (limited to 'indexinglanguage')
2 files changed, 66 insertions, 11 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 29399a38fa9..1a9caaa5ca1 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -119,7 +119,7 @@ public class EmbedExpression extends Expression { "Don't know what tensor type to embed into"); targetType = toTargetTensor(context.getInputType(this, outputField)); if ( ! validTarget(targetType)) - throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " + + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," + "an array of dense 1d tensors, or a mixed 2d tensor"); context.setValueType(createdOutputType()); } @@ -134,14 +134,14 @@ public class EmbedExpression extends Expression { if ( ! ( dataType instanceof TensorDataType)) throw new IllegalArgumentException("Expected a tensor data type but got " + dataType); return ((TensorDataType)dataType).getTensorType(); - } private boolean validTarget(TensorType target) { - if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1) - return true; - if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1) + if (target.dimensions().size() == 1) //indexed or mapped 1d tensor return true; + if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 + && target.mappedSubtype().rank() == 1) + return true; //mixed mapped-indexed 2d tensor return false; } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 2b28756a6a8..6206c2efe7a 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -315,7 +315,7 @@ public class ScriptTestCase { } @Test - public void testArrayEmbedToSparseTensor() throws ParseException { + public void testArrayEmbedToMixedTensor() throws ParseException { Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); @@ -348,19 +348,65 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } + @SuppressWarnings("OptionalGetWithoutIsPresent") + @Test + public void testEmbedToSparseTensor() throws ParseException { + + Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true); + Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder); + + TensorType tensorType = TensorType.fromSpec("tensor(t{})"); + var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("text", DataType.STRING)); + + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var text = new StringFieldValue("abc"); + adapter.setValue("text", text); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(text); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"), + sparseTensor.getTensor().get()); + } + // An embedder which returns the char value of each letter in the input. */ private static class MockEmbedder implements Embedder { private final String expectedDestination; private final int addition; + private final boolean mappedTensor; + + public MockEmbedder(String expectedDestination) { - this(expectedDestination, 0); + this(expectedDestination, 0, false); + } + public MockEmbedder(String expectedDestination, boolean mapped) { + this(expectedDestination, 0,mapped); + } + + public MockEmbedder(String expectedDestination,int addition) { + this(expectedDestination, addition,false); } - public MockEmbedder(String expectedDestination, int addition) { + public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) { this.expectedDestination = expectedDestination; this.addition = addition; + this.mappedTensor = mappedTensor; } @Override @@ -372,11 +418,20 @@ public class ScriptTestCase { public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedDestination, context.getDestination()); var b = Tensor.Builder.of(tensorType); - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) - b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + if (mappedTensor) { + for(int i = 0; i < text.length(); i++) { + var value = text.charAt(i) + addition; + b.cell(). + label(tensorType.dimensions().get(0).name(), text.charAt(i)) + .value(value); + } + } else { + for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) + b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + + } return b.build(); } - } private void assertThrows(Runnable r, String msg) { |