summaryrefslogtreecommitdiffstats
path: root/indexinglanguage
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-02-09 11:13:21 +0100
committerJon Bratseth <bratseth@gmail.com>2022-02-09 11:13:21 +0100
commit17e553c079ca782a1136b35d29f8adc3b7766910 (patch)
tree00c205b49e815ad7e7e25a92723a3b5cf8c435d2 /indexinglanguage
parentb2a8706ceab9b287475705a1be70feab1418f255 (diff)
Type inference where the output type is an array
Diffstat (limited to 'indexinglanguage')
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java17
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java35
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java19
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java3
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java64
5 files changed, 110 insertions, 28 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 66d912cd987..0da9d907718 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.indexinglanguage.expressions;
+import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
@@ -33,7 +34,7 @@ public class EmbedExpression extends Expression {
@Override
public void setStatementOutput(DocumentType documentType, Field field) {
- targetType = ((TensorDataType)field.getDataType()).getTensorType();
+ targetType = toTargetTensor(field.getDataType());
destination = documentType.getName() + "." + field.getName();
}
@@ -52,11 +53,7 @@ public class EmbedExpression extends Expression {
if (outputField == null)
throw new VerificationException(this, "No output field in this statement: " +
"Don't know what tensor type to embed into.");
- DataType outputFieldType = context.getInputType(this, outputField);
- if ( ! (outputFieldType instanceof TensorDataType) )
- throw new VerificationException(this, "The type of the output field " + outputField +
- " is not a tensor but " + outputField);
- targetType = ((TensorDataType) outputFieldType).getTensorType();
+ targetType = toTargetTensor(context.getInputType(this, outputField));
context.setValueType(createdOutputType());
}
@@ -65,6 +62,14 @@ public class EmbedExpression extends Expression {
return new TensorDataType(targetType);
}
+ private static TensorType toTargetTensor(DataType dataType) {
+ if (dataType instanceof ArrayDataType) return toTargetTensor(((ArrayDataType) dataType).getNestedType());
+ if ( ! ( dataType instanceof TensorDataType))
+ throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
+ return ((TensorDataType)dataType).getTensorType();
+
+ }
+
@Override
public String toString() { return "embed"; }
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java
index e7c215383aa..ad5ecba8ff4 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java
@@ -52,28 +52,35 @@ public final class ForEachExpression extends CompositeExpression {
@Override
protected void doVerify(VerificationContext context) {
- DataType input = context.getValueType();
- if (input instanceof ArrayDataType || input instanceof WeightedSetDataType) {
- context.setValueType(((CollectionDataType)input).getNestedType()).execute(exp);
- if (input instanceof ArrayDataType) {
+ DataType valueType = context.getValueType();
+ if (valueType instanceof ArrayDataType || valueType instanceof WeightedSetDataType) {
+ // Set type for block evaluation
+ context.setValueType(((CollectionDataType)valueType).getNestedType());
+
+ // Evaluate block, which sets value>Type to the output of the block
+ context.execute(exp);
+
+ // Value type outside block becomes the collection type having the block output type as argument
+ if (valueType instanceof ArrayDataType) {
context.setValueType(DataType.getArray(context.getValueType()));
} else {
- WeightedSetDataType wset = (WeightedSetDataType)input;
+ WeightedSetDataType wset = (WeightedSetDataType)valueType;
context.setValueType(DataType.getWeightedSet(context.getValueType(), wset.createIfNonExistent(), wset.removeIfZero()));
}
- } else if (input instanceof StructDataType) {
- for (Field field : ((StructDataType)input).getFields()) {
+ }
+ else if (valueType instanceof StructDataType) {
+ for (Field field : ((StructDataType)valueType).getFields()) {
DataType fieldType = field.getDataType();
- DataType valueType = context.setValueType(fieldType).execute(exp).getValueType();
- if (!fieldType.isAssignableFrom(valueType)) {
+ DataType structValueType = context.setValueType(fieldType).execute(exp).getValueType();
+ if (!fieldType.isAssignableFrom(structValueType))
throw new VerificationException(this, "Expected " + fieldType.getName() + " output, got " +
- valueType.getName() + ".");
- }
+ structValueType.getName() + ".");
}
- context.setValueType(input);
- } else {
+ context.setValueType(valueType);
+ }
+ else {
throw new VerificationException(this, "Expected Array, Struct or WeightedSet input, got " +
- input.getName() + ".");
+ valueType.getName() + ".");
}
}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java
index 5b04720dad4..dd8aadecb33 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java
@@ -7,6 +7,7 @@ import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
+import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.IntegerFieldValue;
import com.yahoo.document.datatypes.LongFieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
@@ -22,7 +23,7 @@ public class HashExpression extends Expression {
private final HashFunction hasher = Hashing.sipHash24();
- /** The target type we are hashing into. */
+ /** The target *primitive* type we are hashing into. */
private DataType targetType;
public HashExpression() {
@@ -35,8 +36,8 @@ public class HashExpression extends Expression {
throw new IllegalArgumentException("Cannot use the hash function on an indexing statement for " +
field.getName() +
": The hash function can only be used when the target field " +
- "is int or long, not " + field.getDataType());
- targetType = field.getDataType();
+ "is int or long or an array of int or long, not " + field.getDataType());
+ targetType = primitiveTypeOf(field.getDataType());
}
@Override
@@ -68,22 +69,26 @@ public class HashExpression extends Expression {
if ( ! canStoreHash(outputFieldType))
throw new VerificationException(this, "The type of the output field " + outputField +
" is not int or long but " + outputFieldType);
- targetType = outputFieldType;
+ targetType = primitiveTypeOf(outputFieldType);
context.setValueType(createdOutputType());
}
private boolean canStoreHash(DataType type) {
if (type.equals(DataType.INT)) return true;
if (type.equals(DataType.LONG)) return true;
+ if (type instanceof ArrayDataType) return canStoreHash(((ArrayDataType)type).getNestedType());
return false;
}
- @Override
- public DataType createdOutputType() {
- return targetType;
+ private static DataType primitiveTypeOf(DataType type) {
+ if (type instanceof ArrayDataType) return ((ArrayDataType)type).getNestedType();
+ return type;
}
@Override
+ public DataType createdOutputType() { return targetType; }
+
+ @Override
public String toString() { return "hash"; }
@Override
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
index c266a0da430..40aa0f58413 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
@@ -48,7 +48,8 @@ public final class StatementExpression extends ExpressionList<Expression> {
if (expression instanceof OutputExpression)
outputField = ((OutputExpression)expression).getFieldName();
}
- context.setOutputField(outputField);
+ if (outputField != null)
+ context.setOutputField(outputField);
for (Expression expression : this)
context.execute(expression);
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 778d95fcaef..27723c6649d 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -1,12 +1,15 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.indexinglanguage;
+import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.Document;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
+import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.BoolFieldValue;
+import com.yahoo.document.datatypes.IntegerFieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.process.Embedder;
@@ -120,6 +123,34 @@ public class ScriptTestCase {
assertEquals(-1425622096, adapter.values.get("myInt").getWrappedValue());
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testIntArrayHash() throws ParseException {
+ var expression = Expression.fromString("input myTextArray | for_each { hash } | attribute 'myIntArray'");
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ var intField = new Field("myIntArray", new ArrayDataType(DataType.INT));
+ adapter.createField(intField);
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("second"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), intField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(new ArrayDataType(DataType.INT), expression.verify(verificationContext));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("myIntArray"));
+ var intArray = (Array<IntegerFieldValue>)adapter.values.get("myIntArray");
+ assertEquals( 368658787, intArray.get(0).getInteger());
+ assertEquals(-1382874952, intArray.get(1).getInteger());
+ }
+
@Test
public void testLongHash() throws ParseException {
var expression = Expression.fromString("input myText | hash | attribute 'myLong'");
@@ -168,6 +199,39 @@ public class ScriptTestCase {
((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get());
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testArrayEmbed() throws ParseException {
+ TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
+ var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'",
+ new SimpleLinguistics(),
+ new MockEmbedder("myDocument.myTensorArray"));
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+
+ var tensorField = new Field("myTensorArray", new ArrayDataType(new TensorDataType(tensorType)));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("second"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(new ArrayDataType(new TensorDataType(tensorType)), expression.verify(verificationContext));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("myTensorArray"));
+ var tensorArray = (Array<TensorFieldValue>)adapter.values.get("myTensorArray");
+ assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), tensorArray.get(0).getTensor().get());
+ assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), tensorArray.get(1).getTensor().get());
+ }
+
private static class MockEmbedder implements Embedder {
private final String expectedDestination;