diff options
Diffstat (limited to 'container-search/src')
6 files changed, 63 insertions, 18 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/Properties.java b/container-search/src/main/java/com/yahoo/search/query/Properties.java index 12a82afc7bd..d4fc4d57cd6 100644 --- a/container-search/src/main/java/com/yahoo/search/query/Properties.java +++ b/container-search/src/main/java/com/yahoo/search/query/Properties.java @@ -1,8 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.query; +import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; +import java.util.Map; + /** * Object properties keyed by name which can be looked up using default values and * with conversion to various primitive wrapper types. @@ -50,4 +53,13 @@ public abstract class Properties extends com.yahoo.processing.request.Properties chained().setParentQuery(query); } + /** + * Throws IllegalInputException if the given key cannot be set to the given value. + * This default implementation just passes to the chained properties, if any. + */ + public void requireSettable(CompoundName name, Object value, Map<String, String> context) { + if (chained() != null) + chained().requireSettable(name, value, context); + } + } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java index 5b3758f103d..19e0e441359 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java @@ -14,6 +14,7 @@ import com.yahoo.search.query.profile.types.ConversionContext; import com.yahoo.search.query.profile.types.FieldDescription; import com.yahoo.search.query.profile.types.QueryProfileFieldType; import com.yahoo.search.query.profile.types.QueryProfileType; +import com.yahoo.tensor.Tensor; import java.util.ArrayList; import java.util.Collections; @@ -91,6 +92,15 @@ public class QueryProfileProperties extends Properties { */ @Override public void set(CompoundName name, Object value, Map<String, String> context) { + setOrCheckSettable(name, value, context, true); + } + + @Override + public void requireSettable(CompoundName name, Object value, Map<String, String> context) { + setOrCheckSettable(name, value, context, false); + } + + private void setOrCheckSettable(CompoundName name, Object value, Map<String, String> context, boolean set) { try { name = unalias(name, context); @@ -110,29 +120,36 @@ public class QueryProfileProperties extends Properties { if (value instanceof String && value.toString().startsWith("ref:")) { if (profile.getRegistry() == null) throw new IllegalInputException("Runtime query profile references does not work when the " + - "QueryProfileProperties are constructed without a registry"); + "QueryProfileProperties are constructed without a registry"); String queryProfileId = value.toString().substring(4); value = profile.getRegistry().findQueryProfile(queryProfileId); if (value == null) throw new IllegalInputException("Query profile '" + queryProfileId + "' is not found"); } - if (value instanceof CompiledQueryProfile) { // this will be due to one of the two clauses above - if (references == null) - references = new ArrayList<>(); - references.add(0, new Pair<>(name, (CompiledQueryProfile)value)); // references set later has precedence - put first - } - else { - if (values == null) - values = new HashMap<>(); - values.put(name, value); + if (set) { + if (value instanceof CompiledQueryProfile) { // this will be due to one of the two clauses above + if (references == null) + references = new ArrayList<>(); + // references set later has precedence - put first + references.add(0, new Pair<>(name, (CompiledQueryProfile) value)); + } else { + if (values == null) + values = new HashMap<>(); + values.put(name, value); + } } } catch (IllegalArgumentException e) { - throw new IllegalInputException("Could not set '" + name + "' to '" + value + "'", e); + throw new IllegalInputException("Could not set '" + name + "' to '" + toShortString(value) + "'", e); } } + private String toShortString(Object value) { + if ( ! (value instanceof Tensor)) return value.toString(); + return ((Tensor)value).toShortString(); + } + private Object convertByType(CompoundName name, Object value, Map<String, String> context) { QueryProfileType type; QueryProfileType explicitTypeFromField = null; diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index 6f1cfccc16b..db6a58a4dd3 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -51,7 +51,15 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { - if (o instanceof Tensor) return o; + Tensor tensor = toTensor(o, context); + if (tensor == null) return null; + if (! tensor.type().isAssignableTo(type)) + throw new IllegalArgumentException("Require a tensor of type " + type); + return tensor; + } + + private Tensor toTensor(Object o, ConversionContext context) { + if (o instanceof Tensor) return (Tensor)o; if (o instanceof String && ((String)o).startsWith("embed(")) return encode((String)o, context); if (o instanceof String) return Tensor.from(type, (String)o); return null; diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 98b65c6edd9..2c0f5dc8bea 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -322,6 +322,7 @@ public class QueryProperties extends Properties { } } else if (key.first().equals("rankfeature") || key.first().equals("featureoverride") ) { // featureoverride is deprecated + chained().requireSettable(key, value, context); setRankingFeature(query, key.rest().toString(), toSpecifiedType(key.rest().toString(), value, profileRegistry.getTypeRegistry().getComponent("features"), diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java index 1a4ecb4ecd8..807d70739cc 100644 --- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java +++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java @@ -2,8 +2,10 @@ package com.yahoo.search.query.ranking; import com.yahoo.fs4.MapEncoder; +import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; import com.yahoo.search.query.Ranking; +import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.tensor.Tensor; import com.yahoo.text.JSON; @@ -47,9 +49,14 @@ public class RankFeatures implements Cloneable { /** Sets a tensor rank feature */ public void put(String name, Tensor value) { + verifyType(name, value); features.put(name, value); } + private void verifyType(String name, Object value) { + parent.getParent().properties().requireSettable(new CompoundName(List.of("ranking", "features", name)), value, Map.of()); + } + /** * Sets a rank feature to a value represented as a string. * diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index 20678f3b7bb..a77de954b3a 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -20,7 +20,6 @@ import com.yahoo.search.query.profile.types.FieldType; import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry; import org.junit.Before; -import org.junit.Ignore; import org.junit.Test; import java.net.URLEncoder; @@ -439,7 +438,6 @@ public class QueryProfileTypeTestCase { } @Test - @Ignore public void testTensorRankFeatureSetProgrammaticallyWithWrongType() { QueryProfile profile = new QueryProfile("test"); profile.setType(testtype); @@ -454,16 +452,18 @@ public class QueryProfileTypeTestCase { fail("Expected exception"); } catch (IllegalArgumentException e) { - assertEquals("'query(myTensor1)' must be of type tensor(a{},b{}) but was of type tensor(x[3])", - e.getMessage()); + assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " + + "Require a tensor of type tensor(a{},b{})", + Exceptions.toMessageString(e)); } try { query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString)); fail("Expected exception"); } catch (IllegalArgumentException e) { - assertEquals("'query(myTensor1)' must be of type tensor(a{},b{}) but was of type tensor(x[3])", - e.getMessage()); + assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " + + "Require a tensor of type tensor(a{},b{})", + Exceptions.toMessageString(e)); } } |