diff options
author | Harald Musum <musum@vespa.ai> | 2023-11-01 10:51:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-01 10:51:08 +0100 |
commit | 465e1212b7153b6fec914ab91c2b06052a42ebe8 (patch) | |
tree | 75dcb1fe9d1218cafcc7ab8a97326ee4ec9b8e57 /config-model/src | |
parent | ccd99386c4e2b814bc1e608accafa98fe7b0020a (diff) |
Revert "validate for array/wset attributes"
Diffstat (limited to 'config-model/src')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/RankProfile.java | 24 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java | 28 |
2 files changed, 4 insertions, 48 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 188918e99e5..1ff85c9c89f 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1265,34 +1265,14 @@ public class RankProfile implements Cloneable { return Optional.empty(); // if this context does not contain this input } - private static class AttributeErrorType extends TensorType { - private final String attr; - private final Attribute.CollectionType collType; - AttributeErrorType(String attr, Attribute.CollectionType collType) { - super(TensorType.Value.INT8, List.of()); - this.attr = attr; - this.collType = collType; - } - private void doThrow() { - throw new IllegalArgumentException("Cannot use attribute(" + attr +") " + collType + " as ranking expression input"); - } - @Override public TensorType.Value valueType() { doThrow(); return null; } - @Override public int rank() { doThrow(); return 0; } - @Override public List<TensorType.Dimension> dimensions() { doThrow(); return null; } - } - private void addAttributeFeatureTypes(ImmutableSDField field, Map<Reference, TensorType> featureTypes) { Attribute attribute = field.getAttribute(); field.getAttributes().forEach((k, a) -> { String name = k; if (attribute == a) // this attribute should take the fields name name = field.getName(); // switch to that - it is separate for imported fields - if (a.getCollectionType().equals(Attribute.CollectionType.SINGLE)) { - featureTypes.put(FeatureNames.asAttributeFeature(name), - a.tensorType().orElse(TensorType.empty)); - } else { - featureTypes.put(FeatureNames.asAttributeFeature(name), new AttributeErrorType(name, a.getCollectionType())); - } + featureTypes.put(FeatureNames.asAttributeFeature(name), + a.tensorType().orElse(TensorType.empty)); }); } diff --git a/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java index 8c5f90e4b7f..58a0b54e6cc 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java @@ -19,14 +19,12 @@ import com.yahoo.schema.AbstractSchemaTestCase; import com.yahoo.schema.derived.AttributeFields; import com.yahoo.schema.derived.RawRankProfile; import com.yahoo.schema.parser.ParseException; -import com.yahoo.yolean.Exceptions; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import org.junit.jupiter.api.Test; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; public class TensorTransformTestCase extends AbstractSchemaTestCase { @@ -37,6 +35,8 @@ public class TensorTransformTestCase extends AbstractSchemaTestCase { "max(1.0,2.0)"); assertTransformedExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)"); + assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))", + "max(attribute(double_field),attribute(double_array_field))"); assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))"); assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)", @@ -54,30 +54,6 @@ public class TensorTransformTestCase extends AbstractSchemaTestCase { } @Test - void requireThatUsingArrayFails() throws ParseException { - Throwable e = assertThrows(IllegalArgumentException.class, () -> { - buildSearch("max(attribute(double_field),attribute(double_array_field))"); - }); - String msg = Exceptions.toMessageString(e); - assertEquals("In schema 'test', rank profile 'test':" + - " The function 'testexpression' is invalid: attribute(double_array_field) is invalid:" + - " Cannot use attribute(double_array_field) collectiontype: ARRAY as ranking expression input", - msg); - } - - @Test - void requireThatUsingWsetFails() throws ParseException { - Throwable e = assertThrows(IllegalArgumentException.class, () -> { - buildSearch("map(attribute(weightedset_field), f(x)(x+3))"); - }); - String msg = Exceptions.toMessageString(e); - assertEquals("In schema 'test', rank profile 'test':" + - " The function 'testexpression' is invalid: attribute(weightedset_field) is invalid:" + - " Cannot use attribute(weightedset_field) collectiontype: WEIGHTEDSET as ranking expression input", - msg); - } - - @Test void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)", "max(attribute(tensor_field_1),x)"); |