diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java | 73 |
1 files changed, 66 insertions, 7 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index 6769f05bb3e..7f4cee07e8c 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -2,9 +2,16 @@ package com.yahoo.search.query.properties; import com.yahoo.api.annotations.Beta; +import com.yahoo.language.process.Embedder; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; +import com.yahoo.search.config.SchemaInfo; +import com.yahoo.search.config.internal.TensorConverter; import com.yahoo.search.query.Properties; +import com.yahoo.search.query.Ranking; +import com.yahoo.search.query.ranking.RankFeatures; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.Map; @@ -16,19 +23,71 @@ import java.util.Map; @Beta public class RankProfileInputProperties extends Properties { + private final SchemaInfo schemaInfo; private final Query query; + private final TensorConverter tensorConverter; - public RankProfileInputProperties(Query query) { + private SchemaInfo.Session session = null; + + public RankProfileInputProperties(SchemaInfo schemaInfo, Query query, Map<String, Embedder> embedders) { + this.schemaInfo = schemaInfo; this.query = query; + this.tensorConverter = new TensorConverter(embedders); + } + + @Override + public void set(CompoundName name, Object value, Map<String, String> context) { + if (RankFeatures.isFeatureName(name.toString())) { + TensorType expectedType = typeOf(name); + if (expectedType != null) { + try { + value = tensorConverter.convertTo(expectedType, + name.last(), + value, + query.getModel().getLanguage()); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'", e); + } + } + } + super.set(name, value, context); } - /** - * Throws IllegalInputException if the given key cannot be set to the given value. - * This default implementation just passes to the chained properties, if any. - */ + @Override public void requireSettable(CompoundName name, Object value, Map<String, String> context) { - if (chained() != null) - chained().requireSettable(name, value, context); + if (RankFeatures.isFeatureName(name.toString())) { + TensorType expectedType = typeOf(name); + if (expectedType != null) + verifyType(name, value, expectedType); + } + super.requireSettable(name, value, context); + } + + private TensorType typeOf(CompoundName name) { + // Session is lazily resolved because order matters: + // model.sources+restrict must be set in the query before this is done + if (session == null) + session = schemaInfo.newSession(query); + // In addition, the rank profile must be set before features + return session.rankProfileInput(name.last(), query.getRanking().getProfile()); + } + + private void verifyType(CompoundName name, Object value, TensorType expectedType) { + if (value instanceof Tensor) { + TensorType valueType = ((Tensor)value).type(); + if ( ! valueType.isAssignableTo(expectedType)) + throwIllegalInput(name, value, expectedType); + } + else if (expectedType.rank() > 0) { // rank 0 tensor may also be represented as a scalar or string + throwIllegalInput(name, value, expectedType); + } + } + + private void throwIllegalInput(CompoundName name, Object value, TensorType expectedType) { + throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "': " + + "This input is declared in rank profile '" + query.getRanking().getProfile() + + "' as " + expectedType); } } |