aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src
diff options
context:
space:
mode:
authorHarald Musum <musum@vespa.ai>2023-11-01 10:51:08 +0100
committerGitHub <noreply@github.com>2023-11-01 10:51:08 +0100
commit465e1212b7153b6fec914ab91c2b06052a42ebe8 (patch)
tree75dcb1fe9d1218cafcc7ab8a97326ee4ec9b8e57 /config-model/src
parentccd99386c4e2b814bc1e608accafa98fe7b0020a (diff)
Revert "validate for array/wset attributes"
Diffstat (limited to 'config-model/src')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/RankProfile.java24
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java28
2 files changed, 4 insertions, 48 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index 188918e99e5..1ff85c9c89f 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -1265,34 +1265,14 @@ public class RankProfile implements Cloneable {
return Optional.empty(); // if this context does not contain this input
}
- private static class AttributeErrorType extends TensorType {
- private final String attr;
- private final Attribute.CollectionType collType;
- AttributeErrorType(String attr, Attribute.CollectionType collType) {
- super(TensorType.Value.INT8, List.of());
- this.attr = attr;
- this.collType = collType;
- }
- private void doThrow() {
- throw new IllegalArgumentException("Cannot use attribute(" + attr +") " + collType + " as ranking expression input");
- }
- @Override public TensorType.Value valueType() { doThrow(); return null; }
- @Override public int rank() { doThrow(); return 0; }
- @Override public List<TensorType.Dimension> dimensions() { doThrow(); return null; }
- }
-
private void addAttributeFeatureTypes(ImmutableSDField field, Map<Reference, TensorType> featureTypes) {
Attribute attribute = field.getAttribute();
field.getAttributes().forEach((k, a) -> {
String name = k;
if (attribute == a) // this attribute should take the fields name
name = field.getName(); // switch to that - it is separate for imported fields
- if (a.getCollectionType().equals(Attribute.CollectionType.SINGLE)) {
- featureTypes.put(FeatureNames.asAttributeFeature(name),
- a.tensorType().orElse(TensorType.empty));
- } else {
- featureTypes.put(FeatureNames.asAttributeFeature(name), new AttributeErrorType(name, a.getCollectionType()));
- }
+ featureTypes.put(FeatureNames.asAttributeFeature(name),
+ a.tensorType().orElse(TensorType.empty));
});
}
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java
index 8c5f90e4b7f..58a0b54e6cc 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java
@@ -19,14 +19,12 @@ import com.yahoo.schema.AbstractSchemaTestCase;
import com.yahoo.schema.derived.AttributeFields;
import com.yahoo.schema.derived.RawRankProfile;
import com.yahoo.schema.parser.ParseException;
-import com.yahoo.yolean.Exceptions;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
public class TensorTransformTestCase extends AbstractSchemaTestCase {
@@ -37,6 +35,8 @@ public class TensorTransformTestCase extends AbstractSchemaTestCase {
"max(1.0,2.0)");
assertTransformedExpression("min(attribute(double_field),x)",
"min(attribute(double_field),x)");
+ assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))",
+ "max(attribute(double_field),attribute(double_array_field))");
assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))",
"min(attribute(tensor_field_1),attribute(double_field))");
assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)",
@@ -54,30 +54,6 @@ public class TensorTransformTestCase extends AbstractSchemaTestCase {
}
@Test
- void requireThatUsingArrayFails() throws ParseException {
- Throwable e = assertThrows(IllegalArgumentException.class, () -> {
- buildSearch("max(attribute(double_field),attribute(double_array_field))");
- });
- String msg = Exceptions.toMessageString(e);
- assertEquals("In schema 'test', rank profile 'test':" +
- " The function 'testexpression' is invalid: attribute(double_array_field) is invalid:" +
- " Cannot use attribute(double_array_field) collectiontype: ARRAY as ranking expression input",
- msg);
- }
-
- @Test
- void requireThatUsingWsetFails() throws ParseException {
- Throwable e = assertThrows(IllegalArgumentException.class, () -> {
- buildSearch("map(attribute(weightedset_field), f(x)(x+3))");
- });
- String msg = Exceptions.toMessageString(e);
- assertEquals("In schema 'test', rank profile 'test':" +
- " The function 'testexpression' is invalid: attribute(weightedset_field) is invalid:" +
- " Cannot use attribute(weightedset_field) collectiontype: WEIGHTEDSET as ranking expression input",
- msg);
- }
-
- @Test
void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException {
assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)",
"max(attribute(tensor_field_1),x)");