diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-01-27 09:38:43 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2023-01-27 09:38:43 +0100 |
commit | 35a1ad6eb3d59c9945cdfe8486f57e3f75b3091c (patch) | |
tree | da3ca9d331d3d67060f6d6f9450f06e5be1a411b /indexinglanguage | |
parent | 7f923d43611071bf41fcac0c0ccac9eda16bb00c (diff) |
Support embedding an array to a mixed 2d tensor
Diffstat (limited to 'indexinglanguage')
2 files changed, 93 insertions, 9 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 2e4bb701454..328cd00742f 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -6,6 +6,7 @@ import com.yahoo.document.DataType; import com.yahoo.document.DocumentType; import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.Array; import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.language.process.Embedder; @@ -13,6 +14,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Map; @@ -33,7 +35,7 @@ public class EmbedExpression extends Expression { private TensorType targetType; public EmbedExpression(Map<String, Embedder> embedders, String embedderId) { - super(DataType.STRING); + super(null); this.embedderId = embedderId; boolean embedderIdProvided = embedderId != null && embedderId.length() > 0; @@ -43,14 +45,14 @@ public class EmbedExpression extends Expression { } else if (embedders.size() > 1 && ! embedderIdProvided) { this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " + - "Valid embedders are " + validEmbedders(embedders)); + "Valid embedders are " + validEmbedders(embedders)); } else if (embedders.size() == 1 && ! embedderIdProvided) { this.embedder = embedders.entrySet().stream().findFirst().get().getValue(); } else if ( ! embedders.containsKey(embedderId)) { this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); + "Valid embedders are " + validEmbedders(embedders)); } else { this.embedder = embedders.get(embedderId); } @@ -64,11 +66,48 @@ public class EmbedExpression extends Expression { @Override protected void doExecute(ExecutionContext context) { - StringFieldValue input = (StringFieldValue) context.getValue(); - Tensor tensor = embedder.embed(input.getString(), - new Embedder.Context(destination).setLanguage(context.getLanguage()), - targetType); - context.setValue(new TensorFieldValue(tensor)); + Tensor output; + if (context.getValue().getDataType() == DataType.STRING) { + output = embedSingleValue(context); + } + else if (context.getValue().getDataType() instanceof ArrayDataType && + ((ArrayDataType)context.getValue().getDataType()).getNestedType() == DataType.STRING) { + output = embedArrayValue(context); + } + else { + throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + + context.getValue().getDataType()); + } + context.setValue(new TensorFieldValue(output)); + } + + private Tensor embedSingleValue(ExecutionContext context) { + StringFieldValue input = (StringFieldValue)context.getValue(); + return embed(input.getString(), targetType, context); + } + + @SuppressWarnings("unchecked") + private Tensor embedArrayValue(ExecutionContext context) { + var input = (Array<StringFieldValue>)context.getValue(); + var builder = Tensor.Builder.of(targetType); + for (int i = 0; i < input.size(); i++) { + Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context); + for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) { + Tensor.Cell cell = cells.next(); + builder.cell() + .label(targetType.mappedSubtype().dimensions().get(0).name(), i) + .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().label(0)) + .value(cell.getValue()); + } + } + return builder.build(); + } + + private Tensor embed(String input, TensorType targetType, ExecutionContext context) { + return embedder.embed(input, + new Embedder.Context(destination).setLanguage(context.getLanguage()), + targetType); + } @Override @@ -78,6 +117,9 @@ public class EmbedExpression extends Expression { throw new VerificationException(this, "No output field in this statement: " + "Don't know what tensor type to embed into."); targetType = toTargetTensor(context.getInputType(this, outputField)); + if ( ! validTarget(targetType)) + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " + + "an array of dense 1d tensors, or a mixed 2d tensor"); context.setValueType(createdOutputType()); } @@ -94,6 +136,14 @@ public class EmbedExpression extends Expression { } + private boolean validTarget(TensorType target) { + if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1) + return true; + if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1) + return true; + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -105,7 +155,7 @@ public class EmbedExpression extends Expression { } @Override - public int hashCode() { return 1; } + public int hashCode() { return 98857339; } @Override public boolean equals(Object o) { return o instanceof EmbedExpression; } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index e6744c010f4..c446c04065a 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -268,6 +268,40 @@ public class ScriptTestCase { assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get()); } + @Test + public void testArrayEmbedToSparseTensor() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); + var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("second")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"), + sparseTensor.getTensor().get()); + } + // An embedder which returns the char value of each letter in the input. */ private static class MockEmbedder implements Embedder { |