diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-02-09 11:13:21 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-02-09 11:13:21 +0100 |
commit | 17e553c079ca782a1136b35d29f8adc3b7766910 (patch) | |
tree | 00c205b49e815ad7e7e25a92723a3b5cf8c435d2 | |
parent | b2a8706ceab9b287475705a1be70feab1418f255 (diff) |
Type inference where the output type is an array
5 files changed, 110 insertions, 28 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 66d912cd987..0da9d907718 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.indexinglanguage.expressions; +import com.yahoo.document.ArrayDataType; import com.yahoo.document.DataType; import com.yahoo.document.DocumentType; import com.yahoo.document.Field; @@ -33,7 +34,7 @@ public class EmbedExpression extends Expression { @Override public void setStatementOutput(DocumentType documentType, Field field) { - targetType = ((TensorDataType)field.getDataType()).getTensorType(); + targetType = toTargetTensor(field.getDataType()); destination = documentType.getName() + "." + field.getName(); } @@ -52,11 +53,7 @@ public class EmbedExpression extends Expression { if (outputField == null) throw new VerificationException(this, "No output field in this statement: " + "Don't know what tensor type to embed into."); - DataType outputFieldType = context.getInputType(this, outputField); - if ( ! 
(outputFieldType instanceof TensorDataType) ) - throw new VerificationException(this, "The type of the output field " + outputField + - " is not a tensor but " + outputField); - targetType = ((TensorDataType) outputFieldType).getTensorType(); + targetType = toTargetTensor(context.getInputType(this, outputField)); context.setValueType(createdOutputType()); } @@ -65,6 +62,14 @@ public class EmbedExpression extends Expression { return new TensorDataType(targetType); } + private static TensorType toTargetTensor(DataType dataType) { + if (dataType instanceof ArrayDataType) return toTargetTensor(((ArrayDataType) dataType).getNestedType()); + if ( ! ( dataType instanceof TensorDataType)) + throw new IllegalArgumentException("Expected a tensor data type but got " + dataType); + return ((TensorDataType)dataType).getTensorType(); + + } + @Override public String toString() { return "embed"; } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java index e7c215383aa..ad5ecba8ff4 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ForEachExpression.java @@ -52,28 +52,35 @@ public final class ForEachExpression extends CompositeExpression { @Override protected void doVerify(VerificationContext context) { - DataType input = context.getValueType(); - if (input instanceof ArrayDataType || input instanceof WeightedSetDataType) { - context.setValueType(((CollectionDataType)input).getNestedType()).execute(exp); - if (input instanceof ArrayDataType) { + DataType valueType = context.getValueType(); + if (valueType instanceof ArrayDataType || valueType instanceof WeightedSetDataType) { + // Set type for block evaluation + context.setValueType(((CollectionDataType)valueType).getNestedType()); + + 
// Evaluate block, which sets valueType to the output of the block + context.execute(exp); + + // Value type outside block becomes the collection type having the block output type as argument + if (valueType instanceof ArrayDataType) { context.setValueType(DataType.getArray(context.getValueType())); } else { - WeightedSetDataType wset = (WeightedSetDataType)input; + WeightedSetDataType wset = (WeightedSetDataType)valueType; context.setValueType(DataType.getWeightedSet(context.getValueType(), wset.createIfNonExistent(), wset.removeIfZero())); } - } else if (input instanceof StructDataType) { - for (Field field : ((StructDataType)input).getFields()) { + } + else if (valueType instanceof StructDataType) { + for (Field field : ((StructDataType)valueType).getFields()) { DataType fieldType = field.getDataType(); - DataType valueType = context.setValueType(fieldType).execute(exp).getValueType(); - if (!fieldType.isAssignableFrom(valueType)) { + DataType structValueType = context.setValueType(fieldType).execute(exp).getValueType(); + if (!fieldType.isAssignableFrom(structValueType)) throw new VerificationException(this, "Expected " + fieldType.getName() + " output, got " + - valueType.getName() + "."); - } + structValueType.getName() + "."); } - context.setValueType(input); - } else { + context.setValueType(valueType); + } + else { throw new VerificationException(this, "Expected Array, Struct or WeightedSet input, got " + - input.getName() + "."); + valueType.getName() + "."); } } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java index 5b04720dad4..dd8aadecb33 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/HashExpression.java @@ -7,6 +7,7 @@ import 
com.yahoo.document.ArrayDataType; import com.yahoo.document.DataType; import com.yahoo.document.DocumentType; import com.yahoo.document.Field; +import com.yahoo.document.datatypes.Array; import com.yahoo.document.datatypes.IntegerFieldValue; import com.yahoo.document.datatypes.LongFieldValue; import com.yahoo.document.datatypes.StringFieldValue; @@ -22,7 +23,7 @@ public class HashExpression extends Expression { private final HashFunction hasher = Hashing.sipHash24(); - /** The target type we are hashing into. */ + /** The target *primitive* type we are hashing into. */ private DataType targetType; public HashExpression() { @@ -35,8 +36,8 @@ public class HashExpression extends Expression { throw new IllegalArgumentException("Cannot use the hash function on an indexing statement for " + field.getName() + ": The hash function can only be used when the target field " + - "is int or long, not " + field.getDataType()); - targetType = field.getDataType(); + "is int or long or an array of int or long, not " + field.getDataType()); + targetType = primitiveTypeOf(field.getDataType()); } @Override @@ -68,22 +69,26 @@ public class HashExpression extends Expression { if ( ! 
canStoreHash(outputFieldType)) throw new VerificationException(this, "The type of the output field " + outputField + " is not int or long but " + outputFieldType); - targetType = outputFieldType; + targetType = primitiveTypeOf(outputFieldType); context.setValueType(createdOutputType()); } private boolean canStoreHash(DataType type) { if (type.equals(DataType.INT)) return true; if (type.equals(DataType.LONG)) return true; + if (type instanceof ArrayDataType) return canStoreHash(((ArrayDataType)type).getNestedType()); return false; } - @Override - public DataType createdOutputType() { - return targetType; + private static DataType primitiveTypeOf(DataType type) { + if (type instanceof ArrayDataType) return ((ArrayDataType)type).getNestedType(); + return type; } @Override + public DataType createdOutputType() { return targetType; } + + @Override public String toString() { return "hash"; } @Override diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java index c266a0da430..40aa0f58413 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java @@ -48,7 +48,8 @@ public final class StatementExpression extends ExpressionList<Expression> { if (expression instanceof OutputExpression) outputField = ((OutputExpression)expression).getFieldName(); } - context.setOutputField(outputField); + if (outputField != null) + context.setOutputField(outputField); for (Expression expression : this) context.execute(expression); } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 778d95fcaef..27723c6649d 100644 --- 
a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -1,12 +1,15 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.indexinglanguage; +import com.yahoo.document.ArrayDataType; import com.yahoo.document.DataType; import com.yahoo.document.Document; import com.yahoo.document.DocumentType; import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.Array; import com.yahoo.document.datatypes.BoolFieldValue; +import com.yahoo.document.datatypes.IntegerFieldValue; import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.language.process.Embedder; @@ -120,6 +123,34 @@ public class ScriptTestCase { assertEquals(-1425622096, adapter.values.get("myInt").getWrappedValue()); } + @SuppressWarnings("unchecked") + @Test + public void testIntArrayHash() throws ParseException { + var expression = Expression.fromString("input myTextArray | for_each { hash } | attribute 'myIntArray'"); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + var intField = new Field("myIntArray", new ArrayDataType(DataType.INT)); + adapter.createField(intField); + var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("second")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), intField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new ArrayDataType(DataType.INT), expression.verify(verificationContext)); + + ExecutionContext context = new 
ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("myIntArray")); + var intArray = (Array<IntegerFieldValue>)adapter.values.get("myIntArray"); + assertEquals( 368658787, intArray.get(0).getInteger()); + assertEquals(-1382874952, intArray.get(1).getInteger()); + } + @Test public void testLongHash() throws ParseException { var expression = Expression.fromString("input myText | hash | attribute 'myLong'"); @@ -168,6 +199,39 @@ public class ScriptTestCase { ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); } + @SuppressWarnings("unchecked") + @Test + public void testArrayEmbed() throws ParseException { + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); + var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", + new SimpleLinguistics(), + new MockEmbedder("myDocument.myTensorArray")); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); + + var tensorField = new Field("myTensorArray", new ArrayDataType(new TensorDataType(tensorType))); + adapter.createField(tensorField); + + var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING)); + array.add(new StringFieldValue("first")); + array.add(new StringFieldValue("second")); + adapter.setValue("myTextArray", array); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new ArrayDataType(new TensorDataType(tensorType)), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(array); + expression.execute(context); + assertTrue(adapter.values.containsKey("myTensorArray")); + var tensorArray = 
(Array<TensorFieldValue>)adapter.values.get("myTensorArray"); + assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), tensorArray.get(0).getTensor().get()); + assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), tensorArray.get(1).getTensor().get()); + } + private static class MockEmbedder implements Embedder { private final String expectedDestination; |