diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2023-01-27 10:01:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-27 10:01:07 +0100 |
commit | a31e9abd429b252e9bf67c4e31e2f2cb43c7f9e4 (patch) | |
tree | da3ca9d331d3d67060f6d6f9450f06e5be1a411b | |
parent | 7f923d43611071bf41fcac0c0ccac9eda16bb00c (diff) | |
parent | 35a1ad6eb3d59c9945cdfe8486f57e3f75b3091c (diff) |
Merge pull request #25761 from vespa-engine/bratseth/embed-to-multivalue
Support embedding an array to a mixed 2d tensor
7 files changed, 133 insertions, 31 deletions
diff --git a/document/src/main/java/com/yahoo/document/datatypes/Array.java b/document/src/main/java/com/yahoo/document/datatypes/Array.java index 790cc5d4cde..dba1c0783cf 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/Array.java +++ b/document/src/main/java/com/yahoo/document/datatypes/Array.java @@ -140,9 +140,8 @@ public final class Array<T extends FieldValue> extends CollectionFieldValue<T> i @Override public boolean equals(Object o) { if (this == o) return true; - if (!(o instanceof Array)) return false; + if (!(o instanceof Array a)) return false; if (!super.equals(o)) return false; - Array a = (Array) o; if (values.size() != a.values.size()) return false; if (values instanceof ListWrapper && !(a.values instanceof ListWrapper)) { return equalsWithListWrapper(a.values, (ListWrapper<? extends FieldValue>) values); diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 2e4bb701454..328cd00742f 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -6,6 +6,7 @@ import com.yahoo.document.DataType; import com.yahoo.document.DocumentType; import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.Array; import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.language.process.Embedder; @@ -13,6 +14,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; import java.util.Map; @@ -33,7 +35,7 @@ public class EmbedExpression extends Expression { private TensorType targetType; public EmbedExpression(Map<String, Embedder> embedders, String embedderId) { - super(DataType.STRING); + super(null); this.embedderId = embedderId; boolean embedderIdProvided = embedderId != null && embedderId.length() > 0; @@ -43,14 +45,14 @@ public class EmbedExpression extends Expression { } else if (embedders.size() > 1 && ! embedderIdProvided) { this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " + - "Valid embedders are " + validEmbedders(embedders)); + "Valid embedders are " + validEmbedders(embedders)); } else if (embedders.size() == 1 && ! embedderIdProvided) { this.embedder = embedders.entrySet().stream().findFirst().get().getValue(); } else if ( ! embedders.containsKey(embedderId)) { this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); + "Valid embedders are " + validEmbedders(embedders)); } else { this.embedder = embedders.get(embedderId); } @@ -64,11 +66,48 @@ public class EmbedExpression extends Expression { @Override protected void doExecute(ExecutionContext context) { - StringFieldValue input = (StringFieldValue) context.getValue(); - Tensor tensor = embedder.embed(input.getString(), - new Embedder.Context(destination).setLanguage(context.getLanguage()), - targetType); - context.setValue(new TensorFieldValue(tensor)); + Tensor output; + if (context.getValue().getDataType() == DataType.STRING) { + output = embedSingleValue(context); + } + else if (context.getValue().getDataType() instanceof ArrayDataType && + ((ArrayDataType)context.getValue().getDataType()).getNestedType() == DataType.STRING) { + output = embedArrayValue(context); + } + else { + throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + + context.getValue().getDataType()); + } + context.setValue(new TensorFieldValue(output)); + } + + private Tensor embedSingleValue(ExecutionContext context) { + StringFieldValue input = (StringFieldValue)context.getValue(); + return embed(input.getString(), targetType, context); + } + + @SuppressWarnings("unchecked") + private Tensor embedArrayValue(ExecutionContext context) { + var input = (Array<StringFieldValue>)context.getValue(); + var builder = Tensor.Builder.of(targetType); + for (int i = 0; i < input.size(); i++) { + Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context); + for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) { + Tensor.Cell cell = cells.next(); + builder.cell() + .label(targetType.mappedSubtype().dimensions().get(0).name(), i) + .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().label(0)) + .value(cell.getValue()); + } + } + return builder.build(); + } + + private Tensor embed(String input, TensorType targetType, ExecutionContext context) { + return embedder.embed(input, + new Embedder.Context(destination).setLanguage(context.getLanguage()), + targetType); + } @Override @@ -78,6 +117,9 @@ public class EmbedExpression extends Expression { throw new VerificationException(this, "No output field in this statement: " + "Don't know what tensor type to embed into."); targetType = toTargetTensor(context.getInputType(this, outputField)); + if ( ! validTarget(targetType)) + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " + + "an array of dense 1d tensors, or a mixed 2d tensor"); context.setValueType(createdOutputType()); } @@ -94,6 +136,14 @@ public class EmbedExpression extends Expression { } + private boolean validTarget(TensorType target) { + if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1) + return true; + if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1) + return true; + return false; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -105,7 +155,7 @@ public class EmbedExpression extends Expression { } @Override - public int hashCode() { return 1; } + public int hashCode() { return 98857339; } @Override public boolean equals(Object o) { return o instanceof EmbedExpression; } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index e6744c010f4..c446c04065a 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -268,6 +268,40 @@ public class ScriptTestCase { assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get()); } + @Test + public void testArrayEmbedToSparseTensor() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); + + TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); + var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("second")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"), + sparseTensor.getTensor().get()); + } + // An embedder which returns the char value of each letter in the input. */ private static class MockEmbedder implements Embedder { diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 7f4a19b029d..418f3ed5911 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -976,6 +976,7 @@ "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", "public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])", "public com.yahoo.tensor.MixedTensor build()", + "public static com.yahoo.tensor.MixedTensor$BoundBuilder of(com.yahoo.tensor.TensorType)", "public bridge synthetic com.yahoo.tensor.Tensor build()" ], "fields" : [ ] @@ -1026,6 +1027,7 @@ "public com.yahoo.tensor.MixedTensor build()", "public void trackBounds(com.yahoo.tensor.TensorAddress)", "public com.yahoo.tensor.TensorType createBoundType()", + "public static com.yahoo.tensor.MixedTensor$UnboundBuilder of(com.yahoo.tensor.TensorType)", "public bridge synthetic com.yahoo.tensor.Tensor build()" ], "fields" : [ ] @@ -1466,6 +1468,7 @@ "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", "public com.yahoo.tensor.TensorType mappedSubtype()", + "public com.yahoo.tensor.TensorType indexedSubtype()", "public int rank()", "public java.util.List dimensions()", "public java.util.Set dimensionNames()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 2027dcfb60f..33e83c00e74 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -305,6 +305,10 @@ public class MixedTensor implements Tensor { return new MixedTensor(type, builder, indexBuilder.build()); } + public static BoundBuilder of(TensorType type) { + return new BoundBuilder(type); + } + } /** @@ -371,6 +375,10 @@ public class MixedTensor implements Tensor { return typeBuilder.build(); } + public static UnboundBuilder of(TensorType type) { + return new UnboundBuilder(type); + } + } /** diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 5636150bca1..d5c3b1340f1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -13,21 +13,8 @@ import java.util.stream.Collectors; * @author bratseth */ public abstract class TensorAddress implements Comparable<TensorAddress> { - private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); - private static String [] createSmallIndexesAsStrings(int count) { - String [] asStrings = new String[count]; - for (int i = 0; i < count; i++) { - asStrings[i] = String.valueOf(i); - } - return asStrings; - } - private static String asString(int index) { - return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[index] : String.valueOf(index); - } - private static String asString(long index) { - return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index); - } + private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); public static TensorAddress of(String[] labels) { return new StringTensorAddress(labels); @@ -86,8 +73,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { @Override public boolean equals(Object o) { if (o == this) return true; - if ( ! (o instanceof TensorAddress)) return false; - TensorAddress other = (TensorAddress)o; + if ( ! (o instanceof TensorAddress other)) return false; if (other.size() != this.size()) return false; for (int i = 0; i < this.size(); i++) if ( ! Objects.equals(this.label(i), other.label(i))) @@ -115,6 +101,18 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return "'" + label + "'"; } + private static String[] createSmallIndexesAsStrings(int count) { + String [] asStrings = new String[count]; + for (int i = 0; i < count; i++) { + asStrings[i] = String.valueOf(i); + } + return asStrings; + } + + private static String asString(long index) { + return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index); + } + private static final class StringTensorAddress extends TensorAddress { private final String[] labels; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 36693280183..57d276f278e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -88,6 +88,7 @@ public class TensorType { private final List<Dimension> dimensions; private final TensorType mappedSubtype; + private final TensorType indexedSubtype; public TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; @@ -95,12 +96,18 @@ public class TensorType { Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); - if (dimensionList.stream().allMatch(d -> d.isIndexed())) + if (dimensionList.stream().allMatch(d -> d.isIndexed())) { mappedSubtype = empty; - else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) + indexedSubtype = this; + } + else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) { mappedSubtype = this; - else - mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).toList()); + indexedSubtype = empty; + } + else { + mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> !d.isIndexed()).toList()); + indexedSubtype = new TensorType(valueType, dimensions.stream().filter(Dimension::isIndexed).toList()); + } } static public Value combinedValueType(TensorType ... types) { @@ -135,6 +142,9 @@ public class TensorType { /** The type representing the mapped subset of dimensions of this. */ public TensorType mappedSubtype() { return mappedSubtype; } + /** The type representing the indexed subset of dimensions of this. */ + public TensorType indexedSubtype() { return indexedSubtype; } + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } |