summaryrefslogtreecommitdiffstats
path: root/container-search/src
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src')
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/Properties.java12
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java39
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java10
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java1
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java7
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java12
6 files changed, 63 insertions, 18 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/Properties.java b/container-search/src/main/java/com/yahoo/search/query/Properties.java
index 12a82afc7bd..d4fc4d57cd6 100644
--- a/container-search/src/main/java/com/yahoo/search/query/Properties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/Properties.java
@@ -1,8 +1,11 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.query;
+import com.yahoo.processing.request.CompoundName;
import com.yahoo.search.Query;
+import java.util.Map;
+
/**
* Object properties keyed by name which can be looked up using default values and
* with conversion to various primitive wrapper types.
@@ -50,4 +53,13 @@ public abstract class Properties extends com.yahoo.processing.request.Properties
chained().setParentQuery(query);
}
+ /**
+ * Throws IllegalInputException if the given key cannot be set to the given value.
+ * This default implementation just passes to the chained properties, if any.
+ */
+ public void requireSettable(CompoundName name, Object value, Map<String, String> context) {
+ if (chained() != null)
+ chained().requireSettable(name, value, context);
+ }
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
index 5b3758f103d..19e0e441359 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
@@ -14,6 +14,7 @@ import com.yahoo.search.query.profile.types.ConversionContext;
import com.yahoo.search.query.profile.types.FieldDescription;
import com.yahoo.search.query.profile.types.QueryProfileFieldType;
import com.yahoo.search.query.profile.types.QueryProfileType;
+import com.yahoo.tensor.Tensor;
import java.util.ArrayList;
import java.util.Collections;
@@ -91,6 +92,15 @@ public class QueryProfileProperties extends Properties {
*/
@Override
public void set(CompoundName name, Object value, Map<String, String> context) {
+ setOrCheckSettable(name, value, context, true);
+ }
+
+ @Override
+ public void requireSettable(CompoundName name, Object value, Map<String, String> context) {
+ setOrCheckSettable(name, value, context, false);
+ }
+
+ private void setOrCheckSettable(CompoundName name, Object value, Map<String, String> context, boolean set) {
try {
name = unalias(name, context);
@@ -110,29 +120,36 @@ public class QueryProfileProperties extends Properties {
if (value instanceof String && value.toString().startsWith("ref:")) {
if (profile.getRegistry() == null)
throw new IllegalInputException("Runtime query profile references does not work when the " +
- "QueryProfileProperties are constructed without a registry");
+ "QueryProfileProperties are constructed without a registry");
String queryProfileId = value.toString().substring(4);
value = profile.getRegistry().findQueryProfile(queryProfileId);
if (value == null)
throw new IllegalInputException("Query profile '" + queryProfileId + "' is not found");
}
- if (value instanceof CompiledQueryProfile) { // this will be due to one of the two clauses above
- if (references == null)
- references = new ArrayList<>();
- references.add(0, new Pair<>(name, (CompiledQueryProfile)value)); // references set later has precedence - put first
- }
- else {
- if (values == null)
- values = new HashMap<>();
- values.put(name, value);
+ if (set) {
+ if (value instanceof CompiledQueryProfile) { // this will be due to one of the two clauses above
+ if (references == null)
+ references = new ArrayList<>();
+ // references set later has precedence - put first
+ references.add(0, new Pair<>(name, (CompiledQueryProfile) value));
+ } else {
+ if (values == null)
+ values = new HashMap<>();
+ values.put(name, value);
+ }
}
}
catch (IllegalArgumentException e) {
- throw new IllegalInputException("Could not set '" + name + "' to '" + value + "'", e);
+ throw new IllegalInputException("Could not set '" + name + "' to '" + toShortString(value) + "'", e);
}
}
+ private String toShortString(Object value) {
+ if ( ! (value instanceof Tensor)) return value.toString();
+ return ((Tensor)value).toShortString();
+ }
+
private Object convertByType(CompoundName name, Object value, Map<String, String> context) {
QueryProfileType type;
QueryProfileType explicitTypeFromField = null;
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index 6f1cfccc16b..db6a58a4dd3 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -51,7 +51,15 @@ public class TensorFieldType extends FieldType {
@Override
public Object convertFrom(Object o, ConversionContext context) {
- if (o instanceof Tensor) return o;
+ Tensor tensor = toTensor(o, context);
+ if (tensor == null) return null;
+ if (! tensor.type().isAssignableTo(type))
+ throw new IllegalArgumentException("Require a tensor of type " + type);
+ return tensor;
+ }
+
+ private Tensor toTensor(Object o, ConversionContext context) {
+ if (o instanceof Tensor) return (Tensor)o;
if (o instanceof String && ((String)o).startsWith("embed(")) return encode((String)o, context);
if (o instanceof String) return Tensor.from(type, (String)o);
return null;
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index 98b65c6edd9..2c0f5dc8bea 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -322,6 +322,7 @@ public class QueryProperties extends Properties {
}
}
else if (key.first().equals("rankfeature") || key.first().equals("featureoverride") ) { // featureoverride is deprecated
+ chained().requireSettable(key, value, context);
setRankingFeature(query, key.rest().toString(), toSpecifiedType(key.rest().toString(),
value,
profileRegistry.getTypeRegistry().getComponent("features"),
diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java
index 1a4ecb4ecd8..807d70739cc 100644
--- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java
+++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java
@@ -2,8 +2,10 @@
package com.yahoo.search.query.ranking;
import com.yahoo.fs4.MapEncoder;
+import com.yahoo.processing.request.CompoundName;
import com.yahoo.search.Query;
import com.yahoo.search.query.Ranking;
+import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.tensor.Tensor;
import com.yahoo.text.JSON;
@@ -47,9 +49,14 @@ public class RankFeatures implements Cloneable {
/** Sets a tensor rank feature */
public void put(String name, Tensor value) {
+ verifyType(name, value);
features.put(name, value);
}
+ private void verifyType(String name, Object value) {
+ parent.getParent().properties().requireSettable(new CompoundName(List.of("ranking", "features", name)), value, Map.of());
+ }
+
/**
* Sets a rank feature to a value represented as a string.
*
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
index 20678f3b7bb..a77de954b3a 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
@@ -20,7 +20,6 @@ import com.yahoo.search.query.profile.types.FieldType;
import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry;
import org.junit.Before;
-import org.junit.Ignore;
import org.junit.Test;
import java.net.URLEncoder;
@@ -439,7 +438,6 @@ public class QueryProfileTypeTestCase {
}
@Test
- @Ignore
public void testTensorRankFeatureSetProgrammaticallyWithWrongType() {
QueryProfile profile = new QueryProfile("test");
profile.setType(testtype);
@@ -454,16 +452,18 @@ public class QueryProfileTypeTestCase {
fail("Expected exception");
}
catch (IllegalArgumentException e) {
- assertEquals("'query(myTensor1)' must be of type tensor(a{},b{}) but was of type tensor(x[3])",
- e.getMessage());
+ assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " +
+ "Require a tensor of type tensor(a{},b{})",
+ Exceptions.toMessageString(e));
}
try {
query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString));
fail("Expected exception");
}
catch (IllegalArgumentException e) {
- assertEquals("'query(myTensor1)' must be of type tensor(a{},b{}) but was of type tensor(x[3])",
- e.getMessage());
+ assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " +
+ "Require a tensor of type tensor(a{},b{})",
+ Exceptions.toMessageString(e));
}
}