summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- document/src/main/java/com/yahoo/document/datatypes/Array.java | 3
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java | 68
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java | 34
-rw-r--r-- vespajlib/abi-spec.json | 3
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java | 8
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java | 30
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 18
7 files changed, 133 insertions, 31 deletions
diff --git a/document/src/main/java/com/yahoo/document/datatypes/Array.java b/document/src/main/java/com/yahoo/document/datatypes/Array.java
index 790cc5d4cde..dba1c0783cf 100644
--- a/document/src/main/java/com/yahoo/document/datatypes/Array.java
+++ b/document/src/main/java/com/yahoo/document/datatypes/Array.java
@@ -140,9 +140,8 @@ public final class Array<T extends FieldValue> extends CollectionFieldValue<T> i
@Override
public boolean equals(Object o) {
if (this == o) return true;
- if (!(o instanceof Array)) return false;
+ if (!(o instanceof Array a)) return false;
if (!super.equals(o)) return false;
- Array a = (Array) o;
if (values.size() != a.values.size()) return false;
if (values instanceof ListWrapper && !(a.values instanceof ListWrapper)) {
return equalsWithListWrapper(a.values, (ListWrapper<? extends FieldValue>) values);
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 2e4bb701454..328cd00742f 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -6,6 +6,7 @@ import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
+import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.process.Embedder;
@@ -13,6 +14,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -33,7 +35,7 @@ public class EmbedExpression extends Expression {
private TensorType targetType;
public EmbedExpression(Map<String, Embedder> embedders, String embedderId) {
- super(DataType.STRING);
+ super(null);
this.embedderId = embedderId;
boolean embedderIdProvided = embedderId != null && embedderId.length() > 0;
@@ -43,14 +45,14 @@ public class EmbedExpression extends Expression {
}
else if (embedders.size() > 1 && ! embedderIdProvided) {
this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " +
- "Valid embedders are " + validEmbedders(embedders));
+ "Valid embedders are " + validEmbedders(embedders));
}
else if (embedders.size() == 1 && ! embedderIdProvided) {
this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
}
else if ( ! embedders.containsKey(embedderId)) {
this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " +
- "Valid embedders are " + validEmbedders(embedders));
+ "Valid embedders are " + validEmbedders(embedders));
} else {
this.embedder = embedders.get(embedderId);
}
@@ -64,11 +66,48 @@ public class EmbedExpression extends Expression {
@Override
protected void doExecute(ExecutionContext context) {
- StringFieldValue input = (StringFieldValue) context.getValue();
- Tensor tensor = embedder.embed(input.getString(),
- new Embedder.Context(destination).setLanguage(context.getLanguage()),
- targetType);
- context.setValue(new TensorFieldValue(tensor));
+ Tensor output;
+ if (context.getValue().getDataType() == DataType.STRING) {
+ output = embedSingleValue(context);
+ }
+ else if (context.getValue().getDataType() instanceof ArrayDataType &&
+ ((ArrayDataType)context.getValue().getDataType()).getNestedType() == DataType.STRING) {
+ output = embedArrayValue(context);
+ }
+ else {
+ throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " +
+ context.getValue().getDataType());
+ }
+ context.setValue(new TensorFieldValue(output));
+ }
+
+ private Tensor embedSingleValue(ExecutionContext context) {
+ StringFieldValue input = (StringFieldValue)context.getValue();
+ return embed(input.getString(), targetType, context);
+ }
+
+ @SuppressWarnings("unchecked")
+ private Tensor embedArrayValue(ExecutionContext context) {
+ var input = (Array<StringFieldValue>)context.getValue();
+ var builder = Tensor.Builder.of(targetType);
+ for (int i = 0; i < input.size(); i++) {
+ Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context);
+ for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
+ Tensor.Cell cell = cells.next();
+ builder.cell()
+ .label(targetType.mappedSubtype().dimensions().get(0).name(), i)
+ .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().label(0))
+ .value(cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
+ return embedder.embed(input,
+ new Embedder.Context(destination).setLanguage(context.getLanguage()),
+ targetType);
+
}
@Override
@@ -78,6 +117,9 @@ public class EmbedExpression extends Expression {
throw new VerificationException(this, "No output field in this statement: " +
"Don't know what tensor type to embed into.");
targetType = toTargetTensor(context.getInputType(this, outputField));
+ if ( ! validTarget(targetType))
+ throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " +
+ "an array of dense 1d tensors, or a mixed 2d tensor");
context.setValueType(createdOutputType());
}
@@ -94,6 +136,14 @@ public class EmbedExpression extends Expression {
}
+ private boolean validTarget(TensorType target) {
+ if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1)
+ return true;
+ if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1)
+ return true;
+ return false;
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
@@ -105,7 +155,7 @@ public class EmbedExpression extends Expression {
}
@Override
- public int hashCode() { return 1; }
+ public int hashCode() { return 98857339; }
@Override
public boolean equals(Object o) { return o instanceof EmbedExpression; }
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index e6744c010f4..c446c04065a 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -268,6 +268,40 @@ public class ScriptTestCase {
assertEquals(Tensor.from(tensorType, "[115, 101, 99, 111]"), tensorArray.get(1).getTensor().get());
}
+ @Test
+ public void testArrayEmbedToSparseTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
+ var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("second"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ assertEquals(Tensor.from(tensorType, "{ '0':[102, 105, 114, 115], '1':[115, 101, 99, 111]}"),
+ sparseTensor.getTensor().get());
+ }
+
// An embedder which returns the char value of each letter in the input. */
private static class MockEmbedder implements Embedder {
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 7f4a19b029d..418f3ed5911 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -976,6 +976,7 @@
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
"public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])",
"public com.yahoo.tensor.MixedTensor build()",
+ "public static com.yahoo.tensor.MixedTensor$BoundBuilder of(com.yahoo.tensor.TensorType)",
"public bridge synthetic com.yahoo.tensor.Tensor build()"
],
"fields" : [ ]
@@ -1026,6 +1027,7 @@
"public com.yahoo.tensor.MixedTensor build()",
"public void trackBounds(com.yahoo.tensor.TensorAddress)",
"public com.yahoo.tensor.TensorType createBoundType()",
+ "public static com.yahoo.tensor.MixedTensor$UnboundBuilder of(com.yahoo.tensor.TensorType)",
"public bridge synthetic com.yahoo.tensor.Tensor build()"
],
"fields" : [ ]
@@ -1466,6 +1468,7 @@
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
"public com.yahoo.tensor.TensorType$Value valueType()",
"public com.yahoo.tensor.TensorType mappedSubtype()",
+ "public com.yahoo.tensor.TensorType indexedSubtype()",
"public int rank()",
"public java.util.List dimensions()",
"public java.util.Set dimensionNames()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 2027dcfb60f..33e83c00e74 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -305,6 +305,10 @@ public class MixedTensor implements Tensor {
return new MixedTensor(type, builder, indexBuilder.build());
}
+ public static BoundBuilder of(TensorType type) {
+ return new BoundBuilder(type);
+ }
+
}
/**
@@ -371,6 +375,10 @@ public class MixedTensor implements Tensor {
return typeBuilder.build();
}
+ public static UnboundBuilder of(TensorType type) {
+ return new UnboundBuilder(type);
+ }
+
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 5636150bca1..d5c3b1340f1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -13,21 +13,8 @@ import java.util.stream.Collectors;
* @author bratseth
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
- private static String [] createSmallIndexesAsStrings(int count) {
- String [] asStrings = new String[count];
- for (int i = 0; i < count; i++) {
- asStrings[i] = String.valueOf(i);
- }
- return asStrings;
- }
- private static String asString(int index) {
- return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[index] : String.valueOf(index);
- }
- private static String asString(long index) {
- return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
- }
+ private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
public static TensorAddress of(String[] labels) {
return new StringTensorAddress(labels);
@@ -86,8 +73,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
@Override
public boolean equals(Object o) {
if (o == this) return true;
- if ( ! (o instanceof TensorAddress)) return false;
- TensorAddress other = (TensorAddress)o;
+ if ( ! (o instanceof TensorAddress other)) return false;
if (other.size() != this.size()) return false;
for (int i = 0; i < this.size(); i++)
if ( ! Objects.equals(this.label(i), other.label(i)))
@@ -115,6 +101,18 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return "'" + label + "'";
}
+ private static String[] createSmallIndexesAsStrings(int count) {
+ String [] asStrings = new String[count];
+ for (int i = 0; i < count; i++) {
+ asStrings[i] = String.valueOf(i);
+ }
+ return asStrings;
+ }
+
+ private static String asString(long index) {
+ return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
+ }
+
private static final class StringTensorAddress extends TensorAddress {
private final String[] labels;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 36693280183..57d276f278e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -88,6 +88,7 @@ public class TensorType {
private final List<Dimension> dimensions;
private final TensorType mappedSubtype;
+ private final TensorType indexedSubtype;
public TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
@@ -95,12 +96,18 @@ public class TensorType {
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
- if (dimensionList.stream().allMatch(d -> d.isIndexed()))
+ if (dimensionList.stream().allMatch(d -> d.isIndexed())) {
mappedSubtype = empty;
- else if (dimensionList.stream().noneMatch(d -> d.isIndexed()))
+ indexedSubtype = this;
+ }
+ else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) {
mappedSubtype = this;
- else
- mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).toList());
+ indexedSubtype = empty;
+ }
+ else {
+ mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> !d.isIndexed()).toList());
+ indexedSubtype = new TensorType(valueType, dimensions.stream().filter(Dimension::isIndexed).toList());
+ }
}
static public Value combinedValueType(TensorType ... types) {
@@ -135,6 +142,9 @@ public class TensorType {
/** The type representing the mapped subset of dimensions of this. */
public TensorType mappedSubtype() { return mappedSubtype; }
+ /** The type representing the indexed subset of dimensions of this. */
+ public TensorType indexedSubtype() { return indexedSubtype; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }