aboutsummaryrefslogtreecommitdiffstats
path: root/indexinglanguage
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
commit35a1ad6eb3d59c9945cdfe8486f57e3f75b3091c (patch)
treeda3ca9d331d3d67060f6d6f9450f06e5be1a411b /indexinglanguage
parent7f923d43611071bf41fcac0c0ccac9eda16bb00c (diff)
Support embedding an array to a mixed 2d tensor
Diffstat (limited to 'indexinglanguage')
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java68
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java34
2 files changed, 93 insertions, 9 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 2e4bb701454..328cd00742f 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -6,6 +6,7 @@ import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
+import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.process.Embedder;
@@ -13,6 +14,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -33,7 +35,7 @@ public class EmbedExpression extends Expression {
private TensorType targetType;
public EmbedExpression(Map<String, Embedder> embedders, String embedderId) {
- super(DataType.STRING);
+ super(null);
this.embedderId = embedderId;
boolean embedderIdProvided = embedderId != null && embedderId.length() > 0;
@@ -43,14 +45,14 @@ public class EmbedExpression extends Expression {
}
else if (embedders.size() > 1 && ! embedderIdProvided) {
this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " +
- "Valid embedders are " + validEmbedders(embedders));
+ "Valid embedders are " + validEmbedders(embedders));
}
else if (embedders.size() == 1 && ! embedderIdProvided) {
this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
}
else if ( ! embedders.containsKey(embedderId)) {
this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " +
- "Valid embedders are " + validEmbedders(embedders));
+ "Valid embedders are " + validEmbedders(embedders));
} else {
this.embedder = embedders.get(embedderId);
}
@@ -64,11 +66,48 @@ public class EmbedExpression extends Expression {
@Override
protected void doExecute(ExecutionContext context) {
- StringFieldValue input = (StringFieldValue) context.getValue();
- Tensor tensor = embedder.embed(input.getString(),
- new Embedder.Context(destination).setLanguage(context.getLanguage()),
- targetType);
- context.setValue(new TensorFieldValue(tensor));
+ Tensor output;
+ if (context.getValue().getDataType() == DataType.STRING) {
+ output = embedSingleValue(context);
+ }
+ else if (context.getValue().getDataType() instanceof ArrayDataType &&
+ ((ArrayDataType)context.getValue().getDataType()).getNestedType() == DataType.STRING) {
+ output = embedArrayValue(context);
+ }
+ else {
+ throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " +
+ context.getValue().getDataType());
+ }
+ context.setValue(new TensorFieldValue(output));
+ }
+
+ private Tensor embedSingleValue(ExecutionContext context) {
+ StringFieldValue input = (StringFieldValue)context.getValue();
+ return embed(input.getString(), targetType, context);
+ }
+
+ @SuppressWarnings("unchecked")
+ private Tensor embedArrayValue(ExecutionContext context) {
+ var input = (Array<StringFieldValue>)context.getValue();
+ var builder = Tensor.Builder.of(targetType);
+ for (int i = 0; i < input.size(); i++) {
+ Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context);
+ for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
+ Tensor.Cell cell = cells.next();
+ builder.cell()
+ .label(targetType.mappedSubtype().dimensions().get(0).name(), i)
+ .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().label(0))
+ .value(cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
+ return embedder.embed(input,
+ new Embedder.Context(destination).setLanguage(context.getLanguage()),
+ targetType);
+
}
@Override
@@ -78,6 +117,9 @@ public class EmbedExpression extends Expression {
throw new VerificationException(this, "No output field in this statement: " +
"Don't know what tensor type to embed into.");
targetType = toTargetTensor(context.getInputType(this, outputField));
+ if ( ! validTarget(targetType))
+ throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " +
+ "an array of dense 1d tensors, or a mixed 2d tensor");
context.setValueType(createdOutputType());
}
@@ -94,6 +136,14 @@ public class EmbedExpression extends Expression {
}
+ private boolean validTarget(TensorType target) {
+ if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1)
+ return true;
+ if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1)
+ return true;
+ return false;
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
@@ -105,7 +155,7 @@ public class EmbedExpression extends Expression {
}
@Override
- public int hashCode() { return 1; }
+ public int hashCode() { return 98857339; }
@Override
public boolean equals(Object o) { return o instanceof EmbedExpression; }
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index e6744c010f4..c446c04065a 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -268,6 +268,40 @@ public class ScriptTestCase {
assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get());
}
+ @Test
+ public void testArrayEmbedToSparseTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
+ var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("second"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"),
+ sparseTensor.getTensor().get());
+ }
+
// An embedder which returns the char value of each letter in the input. */
private static class MockEmbedder implements Embedder {