summaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java')
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java73
1 files changed, 66 insertions, 7 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java
index 6769f05bb3e..7f4cee07e8c 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java
@@ -2,9 +2,16 @@
package com.yahoo.search.query.properties;
import com.yahoo.api.annotations.Beta;
+import com.yahoo.language.process.Embedder;
import com.yahoo.processing.request.CompoundName;
import com.yahoo.search.Query;
+import com.yahoo.search.config.SchemaInfo;
+import com.yahoo.search.config.internal.TensorConverter;
import com.yahoo.search.query.Properties;
+import com.yahoo.search.query.Ranking;
+import com.yahoo.search.query.ranking.RankFeatures;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import java.util.Map;
@@ -16,19 +23,71 @@ import java.util.Map;
@Beta
public class RankProfileInputProperties extends Properties {
+ private final SchemaInfo schemaInfo;
private final Query query;
+ private final TensorConverter tensorConverter;
- public RankProfileInputProperties(Query query) {
+ private SchemaInfo.Session session = null;
+
+ public RankProfileInputProperties(SchemaInfo schemaInfo, Query query, Map<String, Embedder> embedders) {
+ this.schemaInfo = schemaInfo;
this.query = query;
+ this.tensorConverter = new TensorConverter(embedders);
+ }
+
+ @Override
+ public void set(CompoundName name, Object value, Map<String, String> context) {
+ if (RankFeatures.isFeatureName(name.toString())) {
+ TensorType expectedType = typeOf(name);
+ if (expectedType != null) {
+ try {
+ value = tensorConverter.convertTo(expectedType,
+ name.last(),
+ value,
+ query.getModel().getLanguage());
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'", e);
+ }
+ }
+ }
+ super.set(name, value, context);
}
- /**
- * Throws IllegalInputException if the given key cannot be set to the given value.
- * This default implementation just passes to the chained properties, if any.
- */
+ @Override
public void requireSettable(CompoundName name, Object value, Map<String, String> context) {
- if (chained() != null)
- chained().requireSettable(name, value, context);
+ if (RankFeatures.isFeatureName(name.toString())) {
+ TensorType expectedType = typeOf(name);
+ if (expectedType != null)
+ verifyType(name, value, expectedType);
+ }
+ super.requireSettable(name, value, context);
+ }
+
+ private TensorType typeOf(CompoundName name) {
+ // Session is lazily resolved because order matters:
+ // model.sources+restrict must be set in the query before this is done
+ if (session == null)
+ session = schemaInfo.newSession(query);
+ // In addition, the rank profile must be set before features
+ return session.rankProfileInput(name.last(), query.getRanking().getProfile());
+ }
+
+ private void verifyType(CompoundName name, Object value, TensorType expectedType) {
+ if (value instanceof Tensor) {
+ TensorType valueType = ((Tensor)value).type();
+ if ( ! valueType.isAssignableTo(expectedType))
+ throwIllegalInput(name, value, expectedType);
+ }
+ else if (expectedType.rank() > 0) { // rank 0 tensor may also be represented as a scalar or string
+ throwIllegalInput(name, value, expectedType);
+ }
+ }
+
+ private void throwIllegalInput(CompoundName name, Object value, TensorType expectedType) {
+ throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "': " +
+ "This input is declared in rank profile '" + query.getRanking().getProfile() +
+ "' as " + expectedType);
}
}