summary | refs | log | tree | commit | diff | stats
path: root/indexinglanguage
diff options
context:
space:
mode:
author Jo Kristian Bergum <bergum@yahoo-inc.com> 2024-01-04 15:45:02 +0100
committer GitHub <noreply@github.com> 2024-01-04 15:45:02 +0100
commit 6d2226e2bf35e32cf618ff12a7e2968c85eabf1f (patch)
tree 5ac11755dbe8998c2b3f4dc5cae8859b0a2e1b9f /indexinglanguage
parent b10d1fd87d7013847b19fc89a620fbe9c7136e61 (diff)
parent 79bb01aa94375b6b9ce464fbdc5db24d1549e7d9 (diff)
Merge pull request #29667 from vespa-engine/jobergum/splade-embedder
Add SPLADE embedder
Diffstat (limited to 'indexinglanguage')
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java 10
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java 67
2 files changed, 66 insertions, 11 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 29399a38fa9..1a9caaa5ca1 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -119,7 +119,7 @@ public class EmbedExpression extends Expression {
"Don't know what tensor type to embed into");
targetType = toTargetTensor(context.getInputType(this, outputField));
if ( ! validTarget(targetType))
- throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " +
+ throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," +
"an array of dense 1d tensors, or a mixed 2d tensor");
context.setValueType(createdOutputType());
}
@@ -134,14 +134,14 @@ public class EmbedExpression extends Expression {
if ( ! ( dataType instanceof TensorDataType))
throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
return ((TensorDataType)dataType).getTensorType();
-
}
private boolean validTarget(TensorType target) {
- if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1)
- return true;
- if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1)
+ if (target.dimensions().size() == 1) //indexed or mapped 1d tensor
return true;
+ if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1
+ && target.mappedSubtype().rank() == 1)
+ return true; //mixed mapped-indexed 2d tensor
return false;
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 2b28756a6a8..6206c2efe7a 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -315,7 +315,7 @@ public class ScriptTestCase {
}
@Test
- public void testArrayEmbedToSparseTensor() throws ParseException {
+ public void testArrayEmbedToMixedTensor() throws ParseException {
Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
@@ -348,19 +348,65 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ @SuppressWarnings("OptionalGetWithoutIsPresent")
+ @Test
+ public void testEmbedToSparseTensor() throws ParseException {
+
+ Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true);
+ Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder);
+
+ TensorType tensorType = TensorType.fromSpec("tensor(t{})");
+ var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("text", DataType.STRING));
+
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var text = new StringFieldValue("abc");
+ adapter.setValue("text", text);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(text);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"),
+ sparseTensor.getTensor().get());
+ }
+
// An embedder which returns the char value of each letter in the input. */
private static class MockEmbedder implements Embedder {
private final String expectedDestination;
private final int addition;
+ private final boolean mappedTensor;
+
+
public MockEmbedder(String expectedDestination) {
- this(expectedDestination, 0);
+ this(expectedDestination, 0, false);
+ }
+ public MockEmbedder(String expectedDestination, boolean mapped) {
+ this(expectedDestination, 0,mapped);
+ }
+
+ public MockEmbedder(String expectedDestination,int addition) {
+ this(expectedDestination, addition,false);
}
- public MockEmbedder(String expectedDestination, int addition) {
+ public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) {
this.expectedDestination = expectedDestination;
this.addition = addition;
+ this.mappedTensor = mappedTensor;
}
@Override
@@ -372,11 +418,20 @@ public class ScriptTestCase {
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
assertEquals(expectedDestination, context.getDestination());
var b = Tensor.Builder.of(tensorType);
- for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
- b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ if (mappedTensor) {
+ for(int i = 0; i < text.length(); i++) {
+ var value = text.charAt(i) + addition;
+ b.cell().
+ label(tensorType.dimensions().get(0).name(), text.charAt(i))
+ .value(value);
+ }
+ } else {
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
+ b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+
+ }
return b.build();
}
-
}
private void assertThrows(Runnable r, String msg) {