diff options
Diffstat (limited to 'container-search/src')
26 files changed, 1185 insertions, 267 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/IndexFacts.java b/container-search/src/main/java/com/yahoo/prelude/IndexFacts.java index a1613dccdd5..c9a855c2f34 100644 --- a/container-search/src/main/java/com/yahoo/prelude/IndexFacts.java +++ b/container-search/src/main/java/com/yahoo/prelude/IndexFacts.java @@ -323,11 +323,11 @@ public class IndexFacts { private Session(Collection<String> sources, Collection<String> restrict) { // Assumption: Search definition name equals document name. - documentTypes = ImmutableList.copyOf(resolveDocumentTypes(sources, restrict, searchDefinitions.keySet())); + documentTypes = List.copyOf(resolveDocumentTypes(sources, restrict, searchDefinitions.keySet())); } private Session(Collection<String> sources, Collection<String> restrict, Set<String> candidateDocumentTypes) { - documentTypes = ImmutableList.copyOf(resolveDocumentTypes(sources, restrict, candidateDocumentTypes)); + documentTypes = List.copyOf(resolveDocumentTypes(sources, restrict, candidateDocumentTypes)); } /** diff --git a/container-search/src/main/java/com/yahoo/prelude/IndexModel.java b/container-search/src/main/java/com/yahoo/prelude/IndexModel.java index e7018f81de1..57a8d518ed2 100644 --- a/container-search/src/main/java/com/yahoo/prelude/IndexModel.java +++ b/container-search/src/main/java/com/yahoo/prelude/IndexModel.java @@ -24,9 +24,9 @@ public final class IndexModel { private static final Logger log = Logger.getLogger(IndexModel.class.getName()); - private Map<String, List<String>> masterClusters; - private Map<String, SearchDefinition> searchDefinitions; - private SearchDefinition unionSearchDefinition; + private final Map<String, List<String>> masterClusters; + private final Map<String, SearchDefinition> searchDefinitions; + private final SearchDefinition unionSearchDefinition; /** Create an index model for a single search definition */ public IndexModel(SearchDefinition searchDefinition) { @@ -83,7 +83,6 @@ public final class IndexModel { 
return clusters; } - @SuppressWarnings("deprecation") private static Map<String, SearchDefinition> toSearchDefinitions(IndexInfoConfig c) { Map<String, SearchDefinition> searchDefinitions = new HashMap<>(); diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocumentDatabase.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocumentDatabase.java index d5a3e8c2786..f35559ad2f4 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocumentDatabase.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocumentDatabase.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.prelude.fastsearch; +import com.yahoo.search.config.RankProfile; import com.yahoo.tensor.TensorType; import java.util.ArrayList; @@ -10,7 +11,7 @@ import java.util.Map; import java.util.stream.Collectors; /** - * Representation of a back-end document database. + * Representation of a document database realizing a schema in a content cluster. 
* * @author geirst */ @@ -33,7 +34,7 @@ public class DocumentDatabase { public DocumentDatabase(String name, DocsumDefinitionSet docsumDefinitionSet, Collection<RankProfile> rankProfiles) { this.name = name; this.docsumDefSet = docsumDefinitionSet; - this.rankProfiles = Map.copyOf(rankProfiles.stream().collect(Collectors.toMap(RankProfile::getName, p -> p))); + this.rankProfiles = Map.copyOf(rankProfiles.stream().collect(Collectors.toMap(RankProfile::name, p -> p))); } public String getName() { @@ -49,13 +50,15 @@ public class DocumentDatabase { private static Collection<RankProfile> toRankProfiles(Collection<DocumentdbInfoConfig.Documentdb.Rankprofile> rankProfileConfigList) { List<RankProfile> rankProfiles = new ArrayList<>(); - for (DocumentdbInfoConfig.Documentdb.Rankprofile c : rankProfileConfigList) - rankProfiles.add(new RankProfile(c.name(), c.hasSummaryFeatures(), c.hasRankFeatures(), inputs(c))); + for (var profileConfig : rankProfileConfigList) { + var builder = new RankProfile.Builder(profileConfig.name()); + builder.setHasSummaryFeatures(profileConfig.hasSummaryFeatures()); + builder.setHasRankFeatures(profileConfig.hasRankFeatures()); + for (var inputConfig : profileConfig.input()) + builder.addInput(inputConfig.name(), TensorType.fromSpec(inputConfig.type())); + rankProfiles.add(builder.build()); + } return rankProfiles; } - private static Map<String, TensorType> inputs(DocumentdbInfoConfig.Documentdb.Rankprofile c) { - return c.input().stream().collect(Collectors.toMap(i -> i.name(), i -> TensorType.fromSpec(i.type()))); - } - } diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/RankProfile.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/RankProfile.java deleted file mode 100644 index a4248245f2a..00000000000 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/RankProfile.java +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. 
See LICENSE in the project root. -package com.yahoo.prelude.fastsearch; - -import com.yahoo.tensor.TensorType; - -import java.util.Map; - -/** - * Information about a rank profile - * - * @author bratseth - */ -class RankProfile { - - private final String name; - private final boolean hasSummaryFeatures; - private final boolean hasRankFeatures; - private final Map<String, TensorType> inputs; - - public RankProfile(String name, - boolean hasSummaryFeatures, - boolean hasRankFeatures, - Map<String, TensorType> inputs) { - this.name = name; - this.hasSummaryFeatures = hasSummaryFeatures; - this.hasRankFeatures = hasRankFeatures; - this.inputs = Map.copyOf(inputs); - } - - public String getName() { return name; } - - /** Returns true if this rank profile has summary features. */ - public boolean hasSummaryFeatures() { return hasSummaryFeatures; } - - /** Returns true if this rank profile has rank features. */ - public boolean hasRankFeatures() { return hasRankFeatures; } - - /** Returns the inputs explicitly declared in this rank profile. 
*/ - public Map<String, TensorType> inputs() { return inputs; } - -} diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java index d26791411c5..a6da823d990 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java @@ -14,8 +14,8 @@ import com.yahoo.protect.Validator; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.cluster.PingableSearcher; +import com.yahoo.search.config.RankProfile; import com.yahoo.search.grouping.vespa.GroupingExecutor; -import com.yahoo.search.result.ErrorHit; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.Hit; import com.yahoo.search.searchchain.Execution; @@ -33,7 +33,7 @@ import java.util.logging.Logger; /** * Superclass for backend searchers. 
* - * @author baldersheim + * @author baldersheim */ public abstract class VespaBackEndSearcher extends PingableSearcher { diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index 7def8d3ab3c..6998826cdfd 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -13,6 +13,7 @@ import com.yahoo.prelude.fastsearch.DocumentDatabase; import com.yahoo.prelude.query.Highlight; import com.yahoo.prelude.query.textualrepresentation.TextualQueryRepresentation; import com.yahoo.processing.request.CompoundName; +import com.yahoo.search.config.SchemaInfo; import com.yahoo.search.dispatch.Dispatcher; import com.yahoo.search.dispatch.rpc.ProtobufSerialization; import com.yahoo.search.federation.FederationSearcher; @@ -45,6 +46,7 @@ import com.yahoo.search.query.properties.QueryProperties; import com.yahoo.search.query.properties.QueryPropertyAliases; import com.yahoo.search.query.properties.RankProfileInputProperties; import com.yahoo.search.query.properties.RequestContextProperties; +import com.yahoo.search.query.ranking.RankFeatures; import com.yahoo.search.yql.NullItemException; import com.yahoo.search.yql.VespaSerializer; import com.yahoo.search.yql.YqlParser; @@ -177,7 +179,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { /** The ranking requested in this query */ private Ranking ranking = new Ranking(this); - /** The query query and/or query program declaration */ + /** The query and/or query program declaration */ private Model model = new Model(this); /** How results of this query should be presented */ @@ -212,6 +214,8 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { argumentType = new QueryProfileType("native"); argumentType.setBuiltin(true); + // Note: Order here matters as fields are set in this order, and rank feature conversion depends + 
// on other fields already being set (see RankProfileInputProperties) argumentType.addField(new FieldDescription(OFFSET.toString(), "integer", "offset start")); argumentType.addField(new FieldDescription(HITS.toString(), "integer", "hits count")); argumentType.addField(new FieldDescription(QUERY_PROFILE.toString(), "string")); @@ -223,11 +227,11 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { argumentType.addField(new FieldDescription(TIMEOUT.toString(), "string", "timeout")); argumentType.addField(new FieldDescription(FederationSearcher.SOURCENAME.toString(),"string")); argumentType.addField(new FieldDescription(FederationSearcher.PROVIDERNAME.toString(),"string")); - argumentType.addField(new FieldDescription(Presentation.PRESENTATION, new QueryProfileFieldType(Presentation.getArgumentType()))); - argumentType.addField(new FieldDescription(Ranking.RANKING, new QueryProfileFieldType(Ranking.getArgumentType()))); argumentType.addField(new FieldDescription(Model.MODEL, new QueryProfileFieldType(Model.getArgumentType()))); argumentType.addField(new FieldDescription(Select.SELECT, new QueryProfileFieldType(Select.getArgumentType()))); argumentType.addField(new FieldDescription(Dispatcher.DISPATCH, new QueryProfileFieldType(Dispatcher.getArgumentType()))); + argumentType.addField(new FieldDescription(Ranking.RANKING, new QueryProfileFieldType(Ranking.getArgumentType()))); + argumentType.addField(new FieldDescription(Presentation.PRESENTATION, new QueryProfileFieldType(Presentation.getArgumentType()))); argumentType.freeze(); } public static QueryProfileType getArgumentType() { return argumentType; } @@ -259,9 +263,9 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public static void addNativeQueryProfileTypesTo(QueryProfileTypeRegistry registry) { // Add modifiable copies to allow query profile types in this to add to these registry.register(Query.getArgumentType().unfrozen()); - 
registry.register(Ranking.getArgumentType().unfrozen()); registry.register(Model.getArgumentType().unfrozen()); registry.register(Select.getArgumentType().unfrozen()); + registry.register(Ranking.getArgumentType().unfrozen()); registry.register(Presentation.getArgumentType().unfrozen()); registry.register(DefaultProperties.argumentType.unfrozen()); } @@ -271,7 +275,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { ImmutableList.copyOf(namesUnder(CompoundName.empty, Query.getArgumentType())); private static List<CompoundName> namesUnder(CompoundName prefix, QueryProfileType type) { - if ( type == null) return Collections.emptyList(); // Names not known statically + if (type == null) return Collections.emptyList(); // Names not known statically List<CompoundName> names = new ArrayList<>(); for (Map.Entry<String, FieldDescription> field : type.fields().entrySet()) { if (field.getValue().getType() instanceof QueryProfileFieldType) { @@ -339,7 +343,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile) { super(new QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo()); + init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo(), SchemaInfo.empty()); } // TODO: Deprecate most constructors above here @@ -349,23 +353,26 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { builder.getRequestMap(), builder.getQueryProfile(), builder.getEmbedders(), - builder.getZoneInfo()); + builder.getZoneInfo(), + builder.getSchemaInfo()); } private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders, - ZoneInfo zoneInfo) { + ZoneInfo zoneInfo, + SchemaInfo schemaInfo) { super(new 
QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, embedders, zoneInfo); + init(requestMap, queryProfile, embedders, zoneInfo, schemaInfo); } private void init(Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders, - ZoneInfo zoneInfo) { + ZoneInfo zoneInfo, + SchemaInfo schemaInfo) { startTime = httpRequest.getJDiscRequest().creationTime(TimeUnit.MILLISECONDS); if (queryProfile != null) { // Move all request parameters to the query profile @@ -374,7 +381,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { setPropertiesFromRequestMap(requestMap, properties(), true); // Create the full chain - properties().chain(new RankProfileInputProperties(this)) + properties().chain(new RankProfileInputProperties(schemaInfo, this, embedders)) .chain(new QueryProperties(this, queryProfile.getRegistry(), embedders)) .chain(new ModelObjectMap()) .chain(new RequestContextProperties(requestMap, zoneInfo)) @@ -394,7 +401,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { } else { // bypass these complications if there is no query profile to get values from and validate against properties(). - chain(new RankProfileInputProperties(this)). + chain(new RankProfileInputProperties(schemaInfo, this, embedders)). chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedders)). chain(new PropertyMap()). 
chain(new DefaultProperties()); @@ -459,32 +466,17 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { } } - private static final List<CompoundName> fieldsToSetFirst = List.of(CompoundName.from("ranking.profile"), - CompoundName.from("model.sources"), - CompoundName.from("model.restrict")); - /** Calls properties.set on all entries in requestMap */ private void setPropertiesFromRequestMap(Map<String, String> requestMap, Properties properties, boolean ignoreSelect) { - // Set these first because they contain type information in inputs which impacts other values set - for (var fieldName : fieldsToSetFirst) - lookupIn(requestMap, fieldName).ifPresent(value -> properties.set(fieldName, value)); - for (var entry : requestMap.entrySet()) { if (ignoreSelect && entry.getKey().equals(Select.SELECT)) continue; + if (RankFeatures.isFeatureName(entry.getKey())) continue; // Set these last properties.set(entry.getKey(), entry.getValue(), requestMap); } - } - - private Optional<String> lookupIn(Map<String, String> requestMap, CompoundName fieldName) { - String value = requestMap.get(fieldName.toString()); - if (value != null) return Optional.of(value); - FieldDescription field = argumentType.getField(fieldName); - for (String alias : field.getAliases()) { - value = requestMap.get(alias); - if (value != null) - return Optional.of(value); + for (var entry : requestMap.entrySet()) { + if ( ! RankFeatures.isFeatureName(entry.getKey())) continue; + properties.set(entry.getKey(), entry.getValue(), requestMap); } - return Optional.empty(); } /** Returns the properties of this query. 
The properties are modifiable */ @@ -1156,6 +1148,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { private CompiledQueryProfile queryProfile = null; private Map<String, Embedder> embedders = Embedder.throwsOnUse.asMap(); private ZoneInfo zoneInfo = ZoneInfo.defaultInfo(); + private SchemaInfo schemaInfo = SchemaInfo.empty(); public Builder setRequest(String query) { request = HttpRequest.createTestRequest(query, com.yahoo.jdisc.http.HttpRequest.Method.GET); @@ -1218,6 +1211,13 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public ZoneInfo getZoneInfo() { return zoneInfo; } + public Builder setSchemaInfo(SchemaInfo schemaInfo) { + this.schemaInfo = schemaInfo; + return this; + } + + public SchemaInfo getSchemaInfo() { return schemaInfo; } + /** Creates a new query from this builder. No properties are required to before calling this. */ public Query build() { return new Query(this); } diff --git a/container-search/src/main/java/com/yahoo/search/config/RankProfile.java b/container-search/src/main/java/com/yahoo/search/config/RankProfile.java new file mode 100644 index 00000000000..944a23f2964 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/config/RankProfile.java @@ -0,0 +1,94 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.search.config; + +import com.yahoo.tensor.TensorType; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Information about a rank profile + * + * @author bratseth + */ +public class RankProfile { + + private final String name; + private final boolean hasSummaryFeatures; + private final boolean hasRankFeatures; + private final Map<String, TensorType> inputs; + + private RankProfile(Builder builder) { + this.name = builder.name; + this.hasSummaryFeatures = builder.hasSummaryFeatures; + this.hasRankFeatures = builder.hasRankFeatures; + this.inputs = Map.copyOf(builder.inputs); + } + + public String name() { return name; } + + /** Returns true if this rank profile has summary features. */ + public boolean hasSummaryFeatures() { return hasSummaryFeatures; } + + /** Returns true if this rank profile has rank features. */ + public boolean hasRankFeatures() { return hasRankFeatures; } + + /** Returns the inputs explicitly declared in this rank profile. */ + public Map<String, TensorType> inputs() { return inputs; } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if ( ! (o instanceof RankProfile)) return false; + RankProfile other = (RankProfile)o; + if ( ! other.name.equals(this.name)) return false; + if ( other.hasSummaryFeatures != this.hasSummaryFeatures) return false; + if ( other.hasRankFeatures != this.hasRankFeatures) return false; + if ( ! 
other.inputs.equals(this.inputs)) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hash(name, hasSummaryFeatures, hasRankFeatures, inputs); + } + + @Override + public String toString() { + return "rank profile '" + name + "'"; + } + + public static class Builder { + + private final String name; + private boolean hasSummaryFeatures = true; + private boolean hasRankFeatures = true; + private final Map<String, TensorType> inputs = new HashMap<>(); + + public Builder(String name) { + this.name = Objects.requireNonNull(name); + } + + public Builder setHasSummaryFeatures(boolean hasSummaryFeatures) { + this.hasSummaryFeatures = hasSummaryFeatures; + return this; + } + + public Builder setHasRankFeatures(boolean hasRankFeatures) { + this.hasRankFeatures = hasRankFeatures; + return this; + } + + public Builder addInput(String name, TensorType type) { + inputs.put(name, type); + return this; + } + + public RankProfile build() { + return new RankProfile(this); + } + + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/config/Schema.java b/container-search/src/main/java/com/yahoo/search/config/Schema.java new file mode 100644 index 00000000000..57712c731f4 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/config/Schema.java @@ -0,0 +1,71 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.config; + +import com.yahoo.api.annotations.Beta; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Information about a schema which is part of the application running this. + * + * This is immutable. 
+ * + * @author bratseth + */ +@Beta +public class Schema { + + private final String name; + private final Map<String, RankProfile> rankProfiles; + + private Schema(Builder builder) { + this.name = builder.name; + this.rankProfiles = Map.copyOf(builder.rankProfiles); + } + + public String name() { return name; } + public Map<String, RankProfile> rankProfiles() { return rankProfiles; } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if ( ! (o instanceof Schema)) return false; + Schema other = (Schema)o; + if ( ! other.name.equals(this.name)) return false; + if ( ! other.rankProfiles.equals(this.rankProfiles)) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hash(name, rankProfiles); + } + + @Override + public String toString() { + return "schema '" + name + "'"; + } + + public static class Builder { + + private final String name; + private final Map<String, RankProfile> rankProfiles = new HashMap<>(); + + public Builder(String name) { + this.name = Objects.requireNonNull(name); + } + + public Builder add(RankProfile profile) { + rankProfiles.put(profile.name(), Objects.requireNonNull(profile)); + return this; + } + + public Schema build() { + return new Schema(this); + } + + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/config/SchemaInfo.java b/container-search/src/main/java/com/yahoo/search/config/SchemaInfo.java new file mode 100644 index 00000000000..746f1c340f2 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/config/SchemaInfo.java @@ -0,0 +1,149 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.search.config; + +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.QrSearchersConfig; +import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; +import com.yahoo.search.Query; +import com.yahoo.tensor.TensorType; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Information about all the schemas configured in the application this container is a part of. + * + * Usage: + * <code> + * SchemaInfo.Session session = schemaInfo.newSession(query); // once when starting to process a query + * session.get(...) // access information about the schema(s) relevant to the query + * </code> + * + * This is immutable. + * + * @author bratseth + */ +// NOTES: +// This should replace IndexFacts, and probably DocumentDatabase. +// It replicates the schema resolution mechanism in IndexFacts, but does not yet contain any field information. +// To replace IndexFacts, this must accept IndexInfo and expose that information, as well as consolidation +// given a set of possible schemas: The session mechanism is present here to make that efficient when added +// (resolving schema subsets for every field lookup is too expensive). 
+@Beta +public class SchemaInfo { + + private static final SchemaInfo empty = new SchemaInfo(List.of(), Map.of()); + + private final List<Schema> schemas; + + /** The schemas contained in each content cluster indexed by cluster name */ + private final Map<String, List<String>> clusters; + + @Inject + public SchemaInfo(IndexInfoConfig indexInfo, // will be used in the future + DocumentdbInfoConfig documentdbInfoConfig, + QrSearchersConfig qrSearchersConfig) { + this(SchemaInfoConfigurer.toSchemas(documentdbInfoConfig), SchemaInfoConfigurer.toClusters(qrSearchersConfig)); + } + + public SchemaInfo(List<Schema> schemas, Map<String, List<String>> clusters) { + this.schemas = List.copyOf(schemas); + this.clusters = Map.copyOf(clusters); + } + + public Session newSession(Query query) { + return new Session(query.getModel().getSources(), query.getModel().getRestrict(), clusters, schemas); + } + + public static SchemaInfo empty() { return empty; } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if ( ! (o instanceof SchemaInfo)) return false; + SchemaInfo other = (SchemaInfo)o; + if ( ! other.schemas.equals(this.schemas)) return false; + if ( ! other.clusters.equals(this.clusters)) return false; + return true; + } + + @Override + public int hashCode() { return Objects.hash(schemas, clusters); } + + /** The schema information resolved to be relevant to this session. */ + public static class Session { + + private final List<Schema> schemas; + + private Session(Set<String> sources, + Set<String> restrict, + Map<String, List<String>> clusters, + List<Schema> candidates) { + this.schemas = resolveSchemas(sources, restrict, clusters, candidates); + } + + /** + * Given a search list which is a mixture of schemas and cluster + * names, and a restrict list which is a list of schemas, return a + * set of all valid schemas for this combination. 
+ * + * @return the possibly empty list of schemas matching the arguments + */ + private static List<Schema> resolveSchemas(Set<String> sources, + Set<String> restrict, + Map<String, List<String>> clusters, + List<Schema> candidates) { + if (sources.isEmpty()) + return restrict.isEmpty() ? candidates : keep(restrict, candidates); + + Set<String> schemaNames = new HashSet<>(); + for (String source : sources) { + if (clusters.containsKey(source)) // source is a cluster + schemaNames.addAll(clusters.get(source)); + else // source is a schema + schemaNames.add(source); + } + candidates = keep(schemaNames, candidates); + return restrict.isEmpty() ? candidates : keep(restrict, candidates); + } + + private static List<Schema> keep(Set<String> names, List<Schema> schemas) { + return schemas.stream().filter(schema -> names.contains(schema.name())).collect(Collectors.toList()); + } + + /** + * Returns the type of the given rank feature name in the given profile, + * if it can be uniquely determined. + * + * @param rankFeature the rank feature name, a string on the form "query(name)" + * @param rankProfile the name of the rank profile in which to locate the input declaration + * @return the type of the declared input, or null if it is not declared or the rank profile is not found + * @throws IllegalArgumentException if the feature is declared in this rank profile in multiple schemas + * of this session with conflicting types + */ + public TensorType rankProfileInput(String rankFeature, String rankProfile) { + TensorType foundType = null; + Schema declaringSchema = null; + for (Schema schema : schemas) { + RankProfile profile = schema.rankProfiles().get(rankProfile); + if (profile == null) continue; + TensorType newlyFoundType = profile.inputs().get(rankFeature); + if (newlyFoundType == null) continue; + if (foundType != null && ! 
newlyFoundType.equals(foundType)) + throw new IllegalArgumentException("Conflicting input type declarations for '" + rankFeature + "': " + + "Declared as " + foundType + " in " + profile + " in " + declaringSchema + + ", and as " + newlyFoundType + " in " + profile + " in " + schema); + foundType = newlyFoundType; + declaringSchema = schema; + } + return foundType; + } + + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/config/SchemaInfoConfigurer.java b/container-search/src/main/java/com/yahoo/search/config/SchemaInfoConfigurer.java new file mode 100644 index 00000000000..ae06babda66 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/config/SchemaInfoConfigurer.java @@ -0,0 +1,51 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.config; + +import com.yahoo.container.QrSearchersConfig; +import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Translation between schema info configuration and schema objects. 
+ * + * @author bratseth + */ +class SchemaInfoConfigurer { + + static List<Schema> toSchemas(DocumentdbInfoConfig documentdbInfoConfig) { + return documentdbInfoConfig.documentdb().stream().map(config -> toSchema(config)).collect(Collectors.toList()); + } + + static Schema toSchema(DocumentdbInfoConfig.Documentdb documentDbConfig) { + Schema.Builder builder = new Schema.Builder(documentDbConfig.name()); + for (var profileConfig : documentDbConfig.rankprofile()) { + RankProfile.Builder profileBuilder = new RankProfile.Builder(profileConfig.name()); + profileBuilder.setHasSummaryFeatures(profileConfig.hasSummaryFeatures()); + profileBuilder.setHasRankFeatures(profileConfig.hasRankFeatures()); + for (var inputConfig : profileConfig.input()) + profileBuilder.addInput(inputConfig.name(), TensorType.fromSpec(inputConfig.type())); + builder.add(profileBuilder.build()); + } + return builder.build(); + } + + static Map<String, List<String>> toClusters(QrSearchersConfig config) { + Map<String, List<String>> clusters = new HashMap<>(); + for (int i = 0; i < config.searchcluster().size(); ++i) { + List<String> schemas = new ArrayList<>(); + String clusterName = config.searchcluster(i).name(); + for (int j = 0; j < config.searchcluster(i).searchdef().size(); ++j) + schemas.add(config.searchcluster(i).searchdef(j)); + clusters.put(clusterName, schemas); + } + return clusters; + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/config/dispatchprototype/package-info.java b/container-search/src/main/java/com/yahoo/search/config/dispatchprototype/package-info.java deleted file mode 100644 index 217fbe80888..00000000000 --- a/container-search/src/main/java/com/yahoo/search/config/dispatchprototype/package-info.java +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * Package for dispatchprototype config. 
- * @author Tony Vaagenes - */ -@ExportPackage -package com.yahoo.search.config.dispatchprototype; - -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/com/yahoo/search/config/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/config/internal/TensorConverter.java new file mode 100644 index 00000000000..fbe2ffb8984 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/config/internal/TensorConverter.java @@ -0,0 +1,95 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.config.internal; + +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * A class which knows how to convert an Object value to a tensor of a given type. + * + * @author bratseth + */ +public class TensorConverter { + + private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + + private final Map<String, Embedder> embedders; + + public TensorConverter(Map<String, Embedder> embedders) { + this.embedders = embedders; + } + + public Tensor convertTo(TensorType type, String key, Object value, Language language) { + var context = new Embedder.Context(key).setLanguage(language); + Tensor tensor = toTensor(type, value, context); + if (tensor == null) return null; + if (! 
tensor.type().isAssignableTo(type)) + throw new IllegalArgumentException("Require a tensor of type " + type); + return tensor; + } + + private Tensor toTensor(TensorType type, Object value, Embedder.Context context) { + if (value instanceof Tensor) return (Tensor)value; + if (value instanceof String && isEmbed((String)value)) return embed((String)value, type, context); + if (value instanceof String) return Tensor.from(type, (String)value); + return null; + } + + static boolean isEmbed(String value) { + return value.startsWith("embed("); + } + + private Tensor embed(String s, TensorType type, Embedder.Context embedderContext) { + if ( ! s.endsWith(")")) + throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); + String argument = s.substring("embed(".length(), s.length() - 1); + Embedder embedder; + + // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") + Matcher matcher = embedderArgumentRegexp.matcher(argument); + if (matcher.matches()) { + String embedderId = matcher.group(1); + argument = matcher.group(2); + if ( ! embedders.containsKey(embedderId)) { + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + } + embedder = embedders.get(embedderId); + } else if (embedders.size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } else if (embedders.size() > 1) { + throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. 
" + + "Valid embedders are " + validEmbedders(embedders)); + } else { + embedder = embedders.entrySet().stream().findFirst().get().getValue(); + } + + return embedder.embed(removeQuotes(argument), embedderContext, type); + } + + private static String removeQuotes(String s) { + if (s.startsWith("'") && s.endsWith("'")) { + return s.substring(1, s.length() - 1); + } + if (s.startsWith("\"") && s.endsWith("\"")) { + return s.substring(1, s.length() - 1); + } + return s; + } + + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/config/package-info.java b/container-search/src/main/java/com/yahoo/search/config/package-info.java index aec055a0972..dd9c7bfcf04 100644 --- a/container-search/src/main/java/com/yahoo/search/config/package-info.java +++ b/container-search/src/main/java/com/yahoo/search/config/package-info.java @@ -1,5 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage +@PublicApi package com.yahoo.search.config; +import com.yahoo.api.annotations.PublicApi; import com.yahoo.osgi.annotation.ExportPackage; + +/** + * Information about the current configuration this is running as a part of. + */
\ No newline at end of file diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java index af6374ba245..0b9885e6bb7 100644 --- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java +++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java @@ -360,6 +360,7 @@ public class SearchHandler extends LoggingRequestHandler { .setQueryProfile(queryProfile) .setEmbedders(embedders) .setZoneInfo(zoneInfo) + .setSchemaInfo(executionFactory.schemaInfo()) .build(); boolean benchmarking = VespaHeaders.benchmarkOutput(request); diff --git a/container-search/src/main/java/com/yahoo/search/query/Model.java b/container-search/src/main/java/com/yahoo/search/query/Model.java index 0cc7a50141f..82cda9e8a1b 100644 --- a/container-search/src/main/java/com/yahoo/search/query/Model.java +++ b/container-search/src/main/java/com/yahoo/search/query/Model.java @@ -10,6 +10,7 @@ import com.yahoo.prelude.query.TaggableItem; import com.yahoo.processing.IllegalInputException; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; +import com.yahoo.search.config.SchemaInfo; import com.yahoo.search.query.parser.Parsable; import com.yahoo.search.query.parser.Parser; import com.yahoo.search.query.parser.ParserEnvironment; @@ -91,8 +92,13 @@ public class Model implements Cloneable { private Set<String> restrict = new LinkedHashSet<>(); private String searchPath; private String documentDbName = null; - private Execution execution = new Execution(new Execution.Context(null, null, - null, null, null, Runnable::run)); + private Execution execution = new Execution(new Execution.Context(null, + null, + SchemaInfo.empty(), + null, + null, + null, + Runnable::run)); public Model(Query query) { setParent(query); diff --git a/container-search/src/main/java/com/yahoo/search/query/Ranking.java 
b/container-search/src/main/java/com/yahoo/search/query/Ranking.java index 94b8b5f63f1..fd0cd5a85b7 100644 --- a/container-search/src/main/java/com/yahoo/search/query/Ranking.java +++ b/container-search/src/main/java/com/yahoo/search/query/Ranking.java @@ -52,18 +52,14 @@ public class Ranking implements Cloneable { public static final String FEATURES = "features"; public static final String PROPERTIES = "properties"; - /** For internal use only. */ - public static Optional<String> lookupRankProfileIn(Map<String, String> properties) { - return Optional.ofNullable(Optional.ofNullable(properties.get(RANKING + "." + PROFILE)) - .orElse(properties.get("ranking"))); - } - static { argumentType = new QueryProfileType(RANKING); argumentType.setStrict(true); argumentType.setBuiltin(true); + // Note: Order here matters as fields are set in this order, and rank feature conversion depends + // on other fields already being set (see RankProfileInputProperties) + argumentType.addField(new FieldDescription(PROFILE, "string", "ranking")); argumentType.addField(new FieldDescription(LOCATION, "string", "location")); - argumentType.addField(new FieldDescription(PROFILE, "string", "ranking")); // Alias repeated in lookupRankProfileIn argumentType.addField(new FieldDescription(SORTING, "string", "sorting sortspec")); argumentType.addField(new FieldDescription(LIST_FEATURES, "string", RANKFEATURES.toString())); argumentType.addField(new FieldDescription(FRESHNESS, "string", "datetime")); @@ -73,7 +69,7 @@ public class Ranking implements Cloneable { argumentType.addField(new FieldDescription(DIVERSITY, new QueryProfileFieldType(Diversity.getArgumentType()))); argumentType.addField(new FieldDescription(SOFTTIMEOUT, new QueryProfileFieldType(SoftTimeout.getArgumentType()))); argumentType.addField(new FieldDescription(MATCHING, new QueryProfileFieldType(Matching.getArgumentType()))); - argumentType.addField(new FieldDescription(FEATURES, "query-profile", "rankfeature")); + 
argumentType.addField(new FieldDescription(FEATURES, "query-profile", "rankfeature")); // Repeated at the end of RankFeatures argumentType.addField(new FieldDescription(PROPERTIES, "query-profile", "rankproperty")); argumentType.freeze(); argumentTypeName = new CompoundName(argumentType.getId().getName()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java index 165ec460822..02a4199d32e 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/QueryProfileType.java @@ -12,6 +12,7 @@ import com.yahoo.search.query.profile.QueryProfile; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -48,7 +49,7 @@ public class QueryProfileType extends FreezableSimpleComponent { } public QueryProfileType(ComponentId id) { - this(id, new HashMap<>(), new ArrayList<>()); + this(id, new LinkedHashMap<>(), new ArrayList<>()); } private QueryProfileType(ComponentId id, Map<String, FieldDescription> fields, List<QueryProfileType> inherited) { @@ -61,7 +62,7 @@ public class QueryProfileType extends FreezableSimpleComponent { private QueryProfileType(ComponentId id, Map<String, FieldDescription> fields, List<QueryProfileType> inherited, boolean strict, boolean matchAsPath, boolean builtin, Map<String,String> aliases) { - this(id, new HashMap<>(fields), new ArrayList<>(inherited)); + this(id, new LinkedHashMap<>(fields), new ArrayList<>(inherited)); this.strict = strict; this.matchAsPath = matchAsPath; this.builtin = builtin; @@ -79,7 +80,7 @@ public class QueryProfileType extends FreezableSimpleComponent { } // Unfreeze nested query profile references - Map<String, FieldDescription> unfrozenFields = new HashMap<>(); + 
Map<String, FieldDescription> unfrozenFields = new LinkedHashMap<>(); for (Map.Entry<String, FieldDescription> field : fields.entrySet()) { FieldDescription unfrozenFieldValue = field.getValue(); if (field.getValue().getType() instanceof QueryProfileFieldType) { @@ -196,8 +197,8 @@ public class QueryProfileType extends FreezableSimpleComponent { * Default: true (so all non-declared fields returns true) */ public boolean isOverridable(String fieldName) { - FieldDescription field=getField(fieldName); - if (field==null) return true; + FieldDescription field = getField(fieldName); + if (field == null) return true; return field.isOverridable(); } @@ -208,8 +209,8 @@ public class QueryProfileType extends FreezableSimpleComponent { * null if no types are legal (i.e if the name is not legal) */ public Class<?> getValueClass(String name) { - FieldDescription fieldDescription=getField(name); - if (fieldDescription==null) { + FieldDescription fieldDescription = getField(name); + if (fieldDescription == null) { if (strict) return null; // Undefined -> Not legal else @@ -335,7 +336,7 @@ public class QueryProfileType extends FreezableSimpleComponent { // found in registry but not already added in *this* type (getField also checks parents): extend it if (type != null && ! 
fields.containsKey(name)) { type = new QueryProfileType(registry.createAnonymousId(type.getIdString()), - new HashMap<>(), + new LinkedHashMap<>(), List.of(type)); } @@ -368,7 +369,7 @@ public class QueryProfileType extends FreezableSimpleComponent { if (inherited().size() == 0) return Collections.unmodifiableMap(fields); // Collapse inherited - Map<String, FieldDescription> allFields = new HashMap<>(); + Map<String, FieldDescription> allFields = new LinkedHashMap<>(); for (QueryProfileType inheritedType : inherited) allFields.putAll(inheritedType.fields()); allFields.putAll(fields); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index cc6b18af820..e0dea744075 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -1,18 +1,14 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.search.query.profile.types; -import com.yahoo.language.process.Embedder; import com.yahoo.processing.request.Properties; +import com.yahoo.search.config.internal.TensorConverter; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.query.profile.SubstituteString; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.ArrayList; -import java.util.List; import java.util.Map; -import java.util.regex.Matcher; -import java.util.regex.Pattern; /** * A tensor field type in a query profile @@ -21,8 +17,6 @@ import java.util.regex.Pattern; */ public class TensorFieldType extends FieldType { - private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); - private final TensorType type; /** Creates a tensor field type with information about the kind of tensor this will hold */ @@ -54,71 +48,7 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - Tensor tensor = toTensor(o, context); - if (tensor == null) return null; - if (! tensor.type().isAssignableTo(type)) - throw new IllegalArgumentException("Require a tensor of type " + type); - return tensor; - } - - private Tensor toTensor(Object o, ConversionContext context) { - if (o instanceof Tensor) return (Tensor)o; - if (o instanceof String && isEmbed((String)o)) return embed((String)o, type, context); - if (o instanceof String) return Tensor.from(type, (String)o); - return null; - } - - static boolean isEmbed(String value) { - return value.startsWith("embed("); - } - - static Tensor embed(String s, TensorType type, ConversionContext context) { - if ( ! 
s.endsWith(")")) - throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); - String argument = s.substring("embed(".length(), s.length() - 1); - Embedder embedder; - - // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher = embedderArgumentRegexp.matcher(argument); - if (matcher.matches()) { - String embedderId = matcher.group(1); - argument = matcher.group(2); - if (!context.embedders().containsKey(embedderId)) { - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(context.embedders())); - } - embedder = context.embedders().get(embedderId); - } else if (context.embedders().size() == 0) { - throw new IllegalStateException("No embedders provided"); // should never happen - } else if (context.embedders().size() > 1) { - throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. 
" + - "Valid embedders are " + validEmbedders(context.embedders())); - } else { - embedder = context.embedders().entrySet().stream().findFirst().get().getValue(); - } - - return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type); - } - - private static String removeQuotes(String s) { - if (s.startsWith("'") && s.endsWith("'")) { - return s.substring(1, s.length() - 1); - } - if (s.startsWith("\"") && s.endsWith("\"")) { - return s.substring(1, s.length() - 1); - } - return s; - } - - private static String validEmbedders(Map<String, Embedder> embedders) { - List<String> embedderIds = new ArrayList<>(); - embedders.forEach((key, value) -> embedderIds.add(key)); - embedderIds.sort(null); - return String.join(",", embedderIds); - } - - private static Embedder.Context toEmbedderContext(ConversionContext context) { - return new Embedder.Context(context.destination()).setLanguage(context.language()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index 6769f05bb3e..7f4cee07e8c 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -2,9 +2,16 @@ package com.yahoo.search.query.properties; import com.yahoo.api.annotations.Beta; +import com.yahoo.language.process.Embedder; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; +import com.yahoo.search.config.SchemaInfo; +import com.yahoo.search.config.internal.TensorConverter; import com.yahoo.search.query.Properties; +import com.yahoo.search.query.Ranking; +import 
com.yahoo.search.query.ranking.RankFeatures; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.Map; @@ -16,19 +23,71 @@ import java.util.Map; @Beta public class RankProfileInputProperties extends Properties { + private final SchemaInfo schemaInfo; private final Query query; + private final TensorConverter tensorConverter; - public RankProfileInputProperties(Query query) { + private SchemaInfo.Session session = null; + + public RankProfileInputProperties(SchemaInfo schemaInfo, Query query, Map<String, Embedder> embedders) { + this.schemaInfo = schemaInfo; this.query = query; + this.tensorConverter = new TensorConverter(embedders); + } + + @Override + public void set(CompoundName name, Object value, Map<String, String> context) { + if (RankFeatures.isFeatureName(name.toString())) { + TensorType expectedType = typeOf(name); + if (expectedType != null) { + try { + value = tensorConverter.convertTo(expectedType, + name.last(), + value, + query.getModel().getLanguage()); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "'", e); + } + } + } + super.set(name, value, context); } - /** - * Throws IllegalInputException if the given key cannot be set to the given value. - * This default implementation just passes to the chained properties, if any. 
- */ + @Override public void requireSettable(CompoundName name, Object value, Map<String, String> context) { - if (chained() != null) - chained().requireSettable(name, value, context); + if (RankFeatures.isFeatureName(name.toString())) { + TensorType expectedType = typeOf(name); + if (expectedType != null) + verifyType(name, value, expectedType); + } + super.requireSettable(name, value, context); + } + + private TensorType typeOf(CompoundName name) { + // Session is lazily resolved because order matters: + // model.sources+restrict must be set in the query before this is done + if (session == null) + session = schemaInfo.newSession(query); + // In addition, the rank profile must be set before features + return session.rankProfileInput(name.last(), query.getRanking().getProfile()); + } + + private void verifyType(CompoundName name, Object value, TensorType expectedType) { + if (value instanceof Tensor) { + TensorType valueType = ((Tensor)value).type(); + if ( ! valueType.isAssignableTo(expectedType)) + throwIllegalInput(name, value, expectedType); + } + else if (expectedType.rank() > 0) { // rank 0 tensor may also be represented as a scalar or string + throwIllegalInput(name, value, expectedType); + } + } + + private void throwIllegalInput(CompoundName name, Object value, TensorType expectedType) { + throw new IllegalArgumentException("Could not set '" + name + "' to '" + value + "': " + + "This input is declared in rank profile '" + query.getRanking().getProfile() + + "' as " + expectedType); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java index ff2e3949ec4..dab824a6fef 100644 --- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java +++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java @@ -191,4 +191,8 @@ public class RankFeatures implements Cloneable { return 
JSON.encode(features); } + public static boolean isFeatureName(String fullPropertyName) { + return fullPropertyName.startsWith("ranking.features.") || fullPropertyName.startsWith("rankfeature."); + } + } diff --git a/container-search/src/main/java/com/yahoo/search/searchchain/Execution.java b/container-search/src/main/java/com/yahoo/search/searchchain/Execution.java index 9374027504e..baf9f35c72b 100644 --- a/container-search/src/main/java/com/yahoo/search/searchchain/Execution.java +++ b/container-search/src/main/java/com/yahoo/search/searchchain/Execution.java @@ -16,6 +16,7 @@ import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; import com.yahoo.search.cluster.PingableSearcher; +import com.yahoo.search.config.SchemaInfo; import com.yahoo.search.rendering.Renderer; import com.yahoo.search.rendering.RendererRegistry; import com.yahoo.search.statistics.TimeTracker; @@ -78,6 +79,8 @@ public class Execution extends com.yahoo.processing.execution.Execution { private IndexFacts indexFacts = null; + private SchemaInfo schemaInfo = SchemaInfo.empty(); + /** The current set of special tokens */ private SpecialTokenRegistry tokenRegistry = null; @@ -116,7 +119,7 @@ public class Execution extends com.yahoo.processing.execution.Execution { * This context is never attached to an execution but is used to carry state into * another context. */ - public Context(SearchChainRegistry searchChainRegistry, IndexFacts indexFacts, + public Context(SearchChainRegistry searchChainRegistry, IndexFacts indexFacts, SchemaInfo schemaInfo, SpecialTokenRegistry tokenRegistry, RendererRegistry rendererRegistry, Linguistics linguistics, Executor executor) { owner = null; @@ -125,50 +128,67 @@ public class Execution extends com.yahoo.processing.execution.Execution { // obviously, the most complete constructor. 
this.searchChainRegistry = searchChainRegistry; this.indexFacts = indexFacts; + this.schemaInfo = Objects.requireNonNull(schemaInfo); this.tokenRegistry = tokenRegistry; this.rendererRegistry = rendererRegistry; this.linguistics = linguistics; this.executor = Objects.requireNonNull(executor, "The executor cannot be null"); } + /** @deprecated pass schemaInfo */ + @Deprecated + public Context(SearchChainRegistry searchChainRegistry, IndexFacts indexFacts, + SpecialTokenRegistry tokenRegistry, RendererRegistry rendererRegistry, Linguistics linguistics, + Executor executor) { + this(searchChainRegistry, indexFacts, SchemaInfo.empty(), tokenRegistry, rendererRegistry, linguistics, Runnable::run); + } + /** @deprecated pass an executor */ @Deprecated // TODO: Remove on Vespa 8 public Context(SearchChainRegistry searchChainRegistry, IndexFacts indexFacts, SpecialTokenRegistry tokenRegistry, RendererRegistry rendererRegistry, Linguistics linguistics) { - this(searchChainRegistry, indexFacts, tokenRegistry, rendererRegistry, linguistics, Runnable::run); + this(searchChainRegistry, indexFacts, SchemaInfo.empty(), tokenRegistry, rendererRegistry, linguistics, Runnable::run); } /** Creates a Context instance where everything except the given arguments is empty. This is for unit testing.*/ public static Context createContextStub() { - return createContextStub(null, null, null); + return createContextStub(null, null, SchemaInfo.empty(), null); } /** Creates a Context instance where everything except the given arguments is empty. This is for unit testing.*/ public static Context createContextStub(SearchChainRegistry searchChainRegistry) { - return createContextStub(searchChainRegistry, null, null); + return createContextStub(searchChainRegistry, null, SchemaInfo.empty(), null); } /** Creates a Context instance where everything except the given arguments is empty. 
This is for unit testing.*/ public static Context createContextStub(IndexFacts indexFacts) { - return createContextStub(null, indexFacts, null); + return createContextStub(null, indexFacts, SchemaInfo.empty(), null); } /** Creates a Context instance where everything except the given arguments is empty. This is for unit testing.*/ public static Context createContextStub(SearchChainRegistry searchChainRegistry, IndexFacts indexFacts) { - return createContextStub(searchChainRegistry, indexFacts, null); + return createContextStub(searchChainRegistry, indexFacts, SchemaInfo.empty(), null); } /** Creates a Context instance where everything except the given arguments is empty. This is for unit testing.*/ public static Context createContextStub(IndexFacts indexFacts, Linguistics linguistics) { - return createContextStub(null, indexFacts, linguistics); + return createContextStub(null, indexFacts, SchemaInfo.empty(), linguistics); + } + + public static Context createContextStub(SearchChainRegistry searchChainRegistry, + IndexFacts indexFacts, + Linguistics linguistics) { + return createContextStub(searchChainRegistry, indexFacts, SchemaInfo.empty(), linguistics); } /** Creates a Context instance where everything except the given arguments is empty. This is for unit testing.*/ public static Context createContextStub(SearchChainRegistry searchChainRegistry, IndexFacts indexFacts, + SchemaInfo schemaInfo, Linguistics linguistics) { return new Context(searchChainRegistry != null ? searchChainRegistry : new SearchChainRegistry(), indexFacts != null ? indexFacts : new IndexFacts(), + schemaInfo, null, new RendererRegistry(Runnable::run), linguistics != null ? 
linguistics : new SimpleLinguistics(), @@ -188,6 +208,7 @@ public class Execution extends com.yahoo.processing.execution.Execution { breakdown = sourceContext.breakdown; if (indexFacts == null) indexFacts = sourceContext.indexFacts; + schemaInfo = sourceContext.schemaInfo; if (tokenRegistry == null) tokenRegistry = sourceContext.tokenRegistry; if (searchChainRegistry == null) @@ -207,6 +228,7 @@ public class Execution extends com.yahoo.processing.execution.Execution { void fill(Context other) { searchChainRegistry = other.searchChainRegistry; indexFacts = other.indexFacts; + schemaInfo = other.schemaInfo; tokenRegistry = other.tokenRegistry; rendererRegistry = other.rendererRegistry; detailedDiagnostics = other.detailedDiagnostics; @@ -219,18 +241,20 @@ public class Execution extends com.yahoo.processing.execution.Execution { // equals() needs to be cheap, that's yet another reason we can only // allow immutables and frozen objects in the context return other.indexFacts == indexFacts - && other.rendererRegistry == rendererRegistry - && other.tokenRegistry == tokenRegistry - && other.searchChainRegistry == searchChainRegistry - && other.detailedDiagnostics == detailedDiagnostics - && other.breakdown == breakdown - && other.linguistics == linguistics - && other.executor == executor; + && other.schemaInfo == schemaInfo + && other.rendererRegistry == rendererRegistry + && other.tokenRegistry == tokenRegistry + && other.searchChainRegistry == searchChainRegistry + && other.detailedDiagnostics == detailedDiagnostics + && other.breakdown == breakdown + && other.linguistics == linguistics + && other.executor == executor; } @Override public int hashCode() { return java.util.Objects.hash(indexFacts, + schemaInfo, rendererRegistry, tokenRegistry, searchChainRegistry, detailedDiagnostics, breakdown, linguistics, @@ -293,28 +317,25 @@ public class Execution extends com.yahoo.processing.execution.Execution { this.indexFacts = indexFacts; } + /** Returns information about the 
schemas specified in this application. This is never null. */ + public SchemaInfo schemaInfo() { return schemaInfo; } + /** * Returns the search chain registry to use with this execution. This is * a snapshot taken at creation of this execution, use * Context.shallowCopy() to get a correctly instantiated Context if * making a custom Context instance. */ - public SearchChainRegistry searchChainRegistry() { - return searchChainRegistry; - } + public SearchChainRegistry searchChainRegistry() { return searchChainRegistry; } /** * Returns the template registry to use with this execution. This is * a snapshot taken at creation of this execution. */ - public RendererRegistry rendererRegistry() { - return rendererRegistry; - } + public RendererRegistry rendererRegistry() { return rendererRegistry; } /** Returns the current set of special strings for the query tokenizer */ - public SpecialTokenRegistry getTokenRegistry() { - return tokenRegistry; - } + public SpecialTokenRegistry getTokenRegistry() { return tokenRegistry; } /** * Wrapping the incoming special token registry and then setting the @@ -324,13 +345,9 @@ public class Execution extends com.yahoo.processing.execution.Execution { * * @param tokenRegistry a new registry for overriding behavior of following searchers */ - public void setTokenRegistry(SpecialTokenRegistry tokenRegistry) { - this.tokenRegistry = tokenRegistry; - } + public void setTokenRegistry(SpecialTokenRegistry tokenRegistry) { this.tokenRegistry = tokenRegistry; } - public void setDetailedDiagnostics(boolean breakdown) { - this.detailedDiagnostics = breakdown; - } + public void setDetailedDiagnostics(boolean breakdown) { this.detailedDiagnostics = breakdown; } /** * The container has some internal diagnostics mechanisms which may be @@ -342,9 +359,7 @@ public class Execution extends com.yahoo.processing.execution.Execution { * @return whether components exposing different level of diagnostics * should go for the most detailed level */ - public 
boolean getDetailedDiagnostics() { - return detailedDiagnostics; - } + public boolean getDetailedDiagnostics() { return detailedDiagnostics; } /** * If too many queries time out, the search handler will assume the @@ -352,27 +367,17 @@ public class Execution extends com.yahoo.processing.execution.Execution { * * @return whether the system is assumed to be in a breakdown state */ - public boolean getBreakdown() { - return breakdown; - } + public boolean getBreakdown() { return breakdown; } - public void setBreakdown(boolean breakdown) { - this.breakdown = breakdown; - } + public void setBreakdown(boolean breakdown) { this.breakdown = breakdown; } /** * Returns the {@link Linguistics} object assigned to this Context. This object provides access to all the * linguistic-related APIs, and comes pre-configured with the Execution given. - * - * @return The current Linguistics. */ - public Linguistics getLinguistics() { - return linguistics; - } + public Linguistics getLinguistics() { return linguistics; } - public void setLinguistics(Linguistics linguistics) { - this.linguistics = linguistics; - } + public void setLinguistics(Linguistics linguistics) { this.linguistics = linguistics; } /** * Returns the executor that should be used to execute tasks as part of this execution. @@ -474,15 +479,10 @@ public class Execution extends com.yahoo.processing.execution.Execution { * to ensure only searchChain or searcher is null (and because it's long and * cumbersome). 
* - * @param searchChain - * the search chain to execute, must be null if searcher is set - * @param context - * execution context for the search - * @param searcherIndex - * index of the first searcher to invoke, see - * Execution(Execution) - * @throws IllegalArgumentException - * if searchChain is null + * @param searchChain the search chain to execute, must be null if searcher is set + * @param context execution context for the search + * @param searcherIndex index of the first searcher to invoke, see Execution(Execution) + * @throws IllegalArgumentException if searchChain is null */ @SuppressWarnings("unchecked") private Execution(Chain<? extends Processor> searchChain, Context context, int searcherIndex) { @@ -493,7 +493,7 @@ public class Execution extends com.yahoo.processing.execution.Execution { super(searchChain, searcherIndex, context.createChildTrace(), context.createChildEnvironment()); this.context.fill(context); contextCache = new Context[searchChain.components().size()]; - entryIndex=searcherIndex; + entryIndex = searcherIndex; timer = new TimeTracker(searchChain, searcherIndex); } diff --git a/container-search/src/main/java/com/yahoo/search/searchchain/ExecutionFactory.java b/container-search/src/main/java/com/yahoo/search/searchchain/ExecutionFactory.java index 28c8ed8f3cf..06814a4c436 100644 --- a/container-search/src/main/java/com/yahoo/search/searchchain/ExecutionFactory.java +++ b/container-search/src/main/java/com/yahoo/search/searchchain/ExecutionFactory.java @@ -16,9 +16,11 @@ import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.prelude.IndexFacts; import com.yahoo.prelude.IndexModel; import com.yahoo.language.process.SpecialTokenRegistry; +import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; import com.yahoo.processing.rendering.Renderer; import com.yahoo.search.Searcher; import com.yahoo.search.config.IndexInfoConfig; +import com.yahoo.search.config.SchemaInfo; import com.yahoo.search.rendering.RendererRegistry; 
import com.yahoo.vespa.configdefinition.SpecialtokensConfig; @@ -39,24 +41,17 @@ public class ExecutionFactory extends AbstractComponent { private final SearchChainRegistry searchChainRegistry; private final IndexFacts indexFacts; + private final SchemaInfo schemaInfo; private final SpecialTokenRegistry specialTokens; private final Linguistics linguistics; private final ThreadPoolExecutor renderingExecutor; private final RendererRegistry rendererRegistry; private final Executor executor; - private static ThreadPoolExecutor createRenderingExecutor() { - int threadCount = Runtime.getRuntime().availableProcessors(); - ThreadPoolExecutor executor = new ThreadPoolExecutor(threadCount, threadCount, 1L, TimeUnit.SECONDS, - new LinkedBlockingQueue<>(), - ThreadFactoryFactory.getThreadFactory("common-rendering")); - executor.prestartAllCoreThreads(); - return executor; - } - @Inject public ExecutionFactory(ChainsConfig chainsConfig, IndexInfoConfig indexInfo, + DocumentdbInfoConfig documentdbInfo, QrSearchersConfig clusters, ComponentRegistry<Searcher> searchers, SpecialtokensConfig specialTokens, @@ -65,6 +60,7 @@ public class ExecutionFactory extends AbstractComponent { Executor executor) { this.searchChainRegistry = createSearchChainRegistry(searchers, chainsConfig); this.indexFacts = new IndexFacts(new IndexModel(indexInfo, clusters)).freeze(); + this.schemaInfo = new SchemaInfo(indexInfo, documentdbInfo, clusters); this.specialTokens = new SpecialTokenRegistry(specialTokens); this.linguistics = linguistics; this.renderingExecutor = createRenderingExecutor(); @@ -72,6 +68,19 @@ public class ExecutionFactory extends AbstractComponent { this.executor = executor != null ? 
executor : Executors.newSingleThreadExecutor(); } + /** @deprecated pass documentDbInfo */ + @Deprecated + public ExecutionFactory(ChainsConfig chainsConfig, + IndexInfoConfig indexInfo, + QrSearchersConfig clusters, + ComponentRegistry<Searcher> searchers, + SpecialtokensConfig specialTokens, + Linguistics linguistics, + ComponentRegistry<Renderer> renderers, + Executor executor) { + this(chainsConfig, indexInfo, new DocumentdbInfoConfig.Builder().build(), clusters, searchers, specialTokens, linguistics, renderers, executor); + } + /** @deprecated pass the container threadpool */ @Deprecated // TODO: Remove on Vespa 8 public ExecutionFactory(ChainsConfig chainsConfig, @@ -81,10 +90,11 @@ public class ExecutionFactory extends AbstractComponent { SpecialtokensConfig specialTokens, Linguistics linguistics, ComponentRegistry<Renderer> renderers) { - this(chainsConfig, indexInfo, clusters, searchers, specialTokens, linguistics, renderers, null); + this(chainsConfig, indexInfo, new DocumentdbInfoConfig.Builder().build(), clusters, searchers, specialTokens, linguistics, renderers, null); } - private SearchChainRegistry createSearchChainRegistry(ComponentRegistry<Searcher> searchers, ChainsConfig chainsConfig) { + private SearchChainRegistry createSearchChainRegistry(ComponentRegistry<Searcher> searchers, + ChainsConfig chainsConfig) { SearchChainRegistry searchChainRegistry = new SearchChainRegistry(searchers); ChainsModel chainsModel = ChainsModelBuilder.buildFromConfig(chainsConfig); ChainsConfigurer.prepareChainRegistry(searchChainRegistry, chainsModel, searchers); @@ -98,7 +108,7 @@ public class ExecutionFactory extends AbstractComponent { */ public Execution newExecution(Chain<? 
extends Searcher> searchChain) { return new Execution(searchChain, - new Execution.Context(searchChainRegistry, indexFacts, specialTokens, rendererRegistry, linguistics, executor)); + new Execution.Context(searchChainRegistry, indexFacts, schemaInfo, specialTokens, rendererRegistry, linguistics, executor)); } /** @@ -107,7 +117,7 @@ public class ExecutionFactory extends AbstractComponent { */ public Execution newExecution(String searchChainId) { return new Execution(searchChainRegistry().getChain(searchChainId), - new Execution.Context(searchChainRegistry, indexFacts, specialTokens, rendererRegistry, linguistics, executor)); + new Execution.Context(searchChainRegistry, indexFacts, schemaInfo, specialTokens, rendererRegistry, linguistics, executor)); } /** Returns the search chain registry used by this */ @@ -116,6 +126,8 @@ public class ExecutionFactory extends AbstractComponent { /** Returns the renderers known to this */ public RendererRegistry rendererRegistry() { return rendererRegistry; } + public SchemaInfo schemaInfo() { return schemaInfo; } + @Override public void deconstruct() { rendererRegistry.deconstruct(); @@ -132,6 +144,7 @@ public class ExecutionFactory extends AbstractComponent { public static ExecutionFactory empty() { return new ExecutionFactory(new ChainsConfig.Builder().build(), new IndexInfoConfig.Builder().build(), + new DocumentdbInfoConfig.Builder().build(), new QrSearchersConfig.Builder().build(), new ComponentRegistry<>(), new SpecialtokensConfig.Builder().build(), @@ -140,4 +153,13 @@ public class ExecutionFactory extends AbstractComponent { null); } + private static ThreadPoolExecutor createRenderingExecutor() { + int threadCount = Runtime.getRuntime().availableProcessors(); + ThreadPoolExecutor executor = new ThreadPoolExecutor(threadCount, threadCount, 1L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(), + ThreadFactoryFactory.getThreadFactory("common-rendering")); + executor.prestartAllCoreThreads(); + return executor; + } + } diff 
--git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index 94b7c140b0f..25d0e184588 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -2,8 +2,6 @@ package com.yahoo.search.searchers; -import com.yahoo.api.annotations.Beta; - import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.NearestNeighborItem; import com.yahoo.prelude.query.ToolBox; @@ -38,7 +36,7 @@ public class ValidateNearestNeighborSearcher extends Searcher { public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) { for (AttributesConfig.Attribute a : attributesConfig.attribute()) { if (! validAttributes.containsKey(a.name())) { - validAttributes.put(a.name(), new ArrayList<TensorType>()); + validAttributes.put(a.name(), new ArrayList<>()); } if (a.datatype() == AttributesConfig.Attribute.Datatype.TENSOR) { TensorType tt = TensorType.fromSpec(a.tensortype()); diff --git a/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTest.java b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTest.java new file mode 100644 index 00000000000..728ebbf8f7f --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTest.java @@ -0,0 +1,50 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.search.config; + +import com.yahoo.tensor.TensorType; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class SchemaInfoTest { + + @Test + public void testSchemaInfoConfiguration() { + assertEquals(SchemaInfoTester.createSchemaInfoFromConfig(), SchemaInfoTester.createSchemaInfo()); + } + + @Test + public void testInputResolution() { + var tester = new SchemaInfoTester(); + tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"), + "", "", "commonProfile", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"), + "ab", "", "commonProfile", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"), + "a", "", "commonProfile", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"), + "b", "", "commonProfile", "query(myTensor1)"); + + tester.assertInputConflict(TensorType.fromSpec("tensor(a{},b{})"), + "", "", "inconsistent", "query(myTensor1)"); + tester.assertInputConflict(TensorType.fromSpec("tensor(a{},b{})"), + "ab", "", "inconsistent", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"), + "ab", "a", "inconsistent", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(x[10])"), + "ab", "b", "inconsistent", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"), + "a", "", "inconsistent", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(x[10])"), + "b", "", "inconsistent", "query(myTensor1)"); + tester.assertInput(null, + "a", "", "bOnly", "query(myTensor1)"); + tester.assertInput(TensorType.fromSpec("tensor(a{},b{})"), + "ab", "", "bOnly", "query(myTensor1)"); + } + +} diff --git a/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTester.java b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTester.java new file mode 100644 index 
00000000000..d5b4522f3aa --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/config/SchemaInfoTester.java @@ -0,0 +1,133 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.config; + +import com.yahoo.container.QrSearchersConfig; +import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; +import com.yahoo.search.Query; +import com.yahoo.tensor.TensorType; +import com.yahoo.yolean.Exceptions; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class SchemaInfoTester { + + private final SchemaInfo schemaInfo; + + SchemaInfoTester() { + this.schemaInfo = createSchemaInfo(); + } + + SchemaInfo schemaInfo() { return schemaInfo; } + + Query query(String sources, String restrict) { + Map<String, String> params = new HashMap<>(); + if ( ! sources.isEmpty()) + params.put("sources", sources); + if ( ! 
restrict.isEmpty()) + params.put("restrict", restrict); + return new Query.Builder().setSchemaInfo(schemaInfo) + .setRequestMap(params) + .build(); + } + + void assertInput(TensorType expectedType, String sources, String restrict, String rankProfile, String feature) { + assertEquals(expectedType, + schemaInfo.newSession(query(sources, restrict)).rankProfileInput(feature, rankProfile)); + } + + void assertInputConflict(TensorType expectedType, String sources, String restrict, String rankProfile, String feature) { + try { + assertInput(expectedType, sources, restrict, rankProfile, feature); + } + catch (IllegalArgumentException e) { + assertEquals("Conflicting input type declarations for '" + feature + "'", + e.getMessage().split(":")[0]); + } + } + + static SchemaInfo createSchemaInfo() { + List<Schema> schemas = new ArrayList<>(); + RankProfile common = new RankProfile.Builder("commonProfile") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor2)", TensorType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor3)", TensorType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor4)", TensorType.fromSpec("tensor<float>(x[5])")) + .build(); + schemas.add(new Schema.Builder("a") + .add(common) + .add(new RankProfile.Builder("inconsistent") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .build()) + .build()); + schemas.add(new Schema.Builder("b") + .add(common) + .add(new RankProfile.Builder("inconsistent") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(x[10])")) + .build()) + .add(new RankProfile.Builder("bOnly") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .build()) + .build()); + Map<String, List<String>> clusters = new HashMap<>(); + clusters.put("ab", List.of("a", "b")); + clusters.put("a", List.of("a")); + return new SchemaInfo(schemas, clusters); + } + + /** Creates the same schema info as createSchemaInfo from config objects. 
*/ + static SchemaInfo createSchemaInfoFromConfig() { + var indexInfoConfig = new IndexInfoConfig.Builder(); + + var rankProfileCommon = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder(); + rankProfileCommon.name("commonProfile"); + rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(a{},b{})")); + rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor2)").type("tensor(x[2],y[2])")); + rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor3)").type("tensor(x[2],y[2])")); + rankProfileCommon.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor4)").type("tensor<float>(x[5])")); + + var documentDbInfoInfoConfig = new DocumentdbInfoConfig.Builder(); + + var documentDbA = new DocumentdbInfoConfig.Documentdb.Builder(); + documentDbA.name("a"); + documentDbA.rankprofile(rankProfileCommon); + var rankProfileInconsistentA = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder(); + rankProfileInconsistentA.name("inconsistent"); + rankProfileInconsistentA.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(a{},b{})")); + documentDbA.rankprofile(rankProfileInconsistentA); + documentDbInfoInfoConfig.documentdb(documentDbA); + + var documentDbB = new DocumentdbInfoConfig.Documentdb.Builder(); + documentDbB.name("b"); + documentDbB.rankprofile(rankProfileCommon); + var rankProfileInconsistentB = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder(); + rankProfileInconsistentB.name("inconsistent"); + rankProfileInconsistentB.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(x[10])")); + documentDbB.rankprofile(rankProfileInconsistentB); + var rankProfileBOnly = new DocumentdbInfoConfig.Documentdb.Rankprofile.Builder(); + 
rankProfileBOnly.name("bOnly"); + rankProfileBOnly.input(new DocumentdbInfoConfig.Documentdb.Rankprofile.Input.Builder().name("query(myTensor1)").type("tensor(a{},b{})")); + documentDbB.rankprofile(rankProfileBOnly); + documentDbInfoInfoConfig.documentdb(documentDbB); + + var qrSearchersConfig = new QrSearchersConfig.Builder(); + var clusterAB = new QrSearchersConfig.Searchcluster.Builder(); + clusterAB.name("ab"); + clusterAB.searchdef("a").searchdef("b"); + qrSearchersConfig.searchcluster(clusterAB); + var clusterA = new QrSearchersConfig.Searchcluster.Builder(); + clusterA.name("a"); + clusterA.searchdef("a"); + qrSearchersConfig.searchcluster(clusterA); + + return new SchemaInfo(indexInfoConfig.build(), documentDbInfoInfoConfig.build(), qrSearchersConfig.build()); + } + +} diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java new file mode 100644 index 00000000000..1b10e4cd0ba --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java @@ -0,0 +1,300 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.search.query; + +import com.yahoo.container.jdisc.HttpRequest; +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.search.Query; +import com.yahoo.search.config.RankProfile; +import com.yahoo.search.config.Schema; +import com.yahoo.search.config.SchemaInfo; +import com.yahoo.search.query.profile.QueryProfile; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.compiled.CompiledQueryProfile; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +/** + * Tests queries towards rank profiles using input declarations. 
+ * + * @author bratseth + */ +public class RankProfileInputTest { + + @Test + public void testTensorRankFeatureInRequest() { + String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}"; + + { + Query query = createTensor1Query(tensorString, "commonProfile", ""); + assertEquals(0, query.errors().size()); + assertEquals(Tensor.from(tensorString), query.properties().get("ranking.features.query(myTensor1)")); + assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get()); + } + + { // Partial resolution is sufficient + Query query = createTensor1Query(tensorString, "bOnly", ""); + assertEquals(0, query.errors().size()); + assertEquals(Tensor.from(tensorString), query.properties().get("ranking.features.query(myTensor1)")); + assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get()); + } + + { // Resolution is limited to the correct sources + Query query = createTensor1Query(tensorString, "bOnly", "sources=a"); + assertEquals(0, query.errors().size()); + assertEquals("Not converted to tensor", + tensorString, query.properties().get("ranking.features.query(myTensor1)")); + } + } + + @Test + public void testTensorRankFeatureInRequestInconsistentInput() { + String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}"; + try { + createTensor1Query(tensorString, "inconsistent", ""); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Conflicting input type declarations for 'query(myTensor1)': " + + "Declared as tensor(a{},b{}) in rank profile 'inconsistent' in schema 'a', " + + "and as tensor(x[10]) in rank profile 'inconsistent' in schema 'b'", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testTensorRankFeatureWithSourceResolution() { + String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}"; + + { + createTensor1Query(tensorString, "inconsistent", "sources=a"); + // Success: No exception + } + + try { + 
createTensor1Query(tensorString, "inconsistent", "sources=ab"); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + // success + } + + { + createTensor1Query(tensorString, "inconsistent", "sources=a&restrict=a"); + // Success: No exception + } + } + + @Test + public void testTensorRankFeatureSetProgrammatically() { + String tensorString = "{{a:a1, b:b1}:1.0, {a:a2, b:b1}:2.0}}"; + Query query = new Query.Builder() + .setSchemaInfo(createSchemaInfo()) + .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles + .setRequest(HttpRequest.createTestRequest("?" + + "&ranking=commonProfile", + com.yahoo.jdisc.http.HttpRequest.Method.GET)) + .build(); + + query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString)); + assertEquals(Tensor.from(tensorString), query.getRanking().getFeatures().getTensor("query(myTensor1)").get()); + } + + @Test + public void testTensorRankFeatureSetProgrammaticallyWithWrongType() { + Query query = new Query.Builder() + .setSchemaInfo(createSchemaInfo()) + .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles + .setRequest(HttpRequest.createTestRequest("?" 
+ + "&ranking=commonProfile", + com.yahoo.jdisc.http.HttpRequest.Method.GET)) + .build(); + + String tensorString = "tensor(x[3]):[0.1, 0.2, 0.3]"; + try { + query.getRanking().getFeatures().put("query(myTensor1)",Tensor.from(tensorString)); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " + + "This input is declared in rank profile 'commonProfile' as tensor(a{},b{})", + Exceptions.toMessageString(e)); + } + try { + query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString)); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " + + "Require a tensor of type tensor(a{},b{})", + Exceptions.toMessageString(e)); + } + } + + @Test + public void testUnembeddedTensorRankFeatureInRequest() { + String text = "text to embed into a tensor"; + Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); + Tensor embedding2 = Tensor.from("tensor<float>(x[5]):[1,2,3,4,0]]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(" + text + ")", embedding1, embedders); + assertEmbedQuery("embed('" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(\"" + text + "\")", embedding1, embedders); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(emb1, \"" + text + "\")", embedding1, embedders); + assertEmbedQueryFails("embed(emb2, \"" + text + "\")", embedding1, embedders, + "Can't find embedder 'emb2'. 
Valid embedders are emb1"); + + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1), + "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2) + ); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders); + assertEmbedQueryFails("embed(emb3, \"" + text + "\")", embedding1, embedders, + "Can't find embedder 'emb3'. Valid embedders are emb1,emb2"); + + // And with specified language + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1) + ); + assertEmbedQuery("embed(" + text + ")", embedding1, embedders, Language.ENGLISH.languageCode()); + + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1), + "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2) + ); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders, Language.ENGLISH.languageCode()); + assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode()); + } + + private Query createTensor1Query(String tensorString, String profile, String additionalParams) { + return new Query.Builder() + .setSchemaInfo(createSchemaInfo()) + .setQueryProfile(createQueryProfile()) // Use the instantiation path with query profiles + .setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features.query(myTensor1)") + + "=" + urlEncode(tensorString) + + "&ranking=" + profile + + "&" + additionalParams, + com.yahoo.jdisc.http.HttpRequest.Method.GET)) + .build(); + } + + private String urlEncode(String s) { + return URLEncoder.encode(s, StandardCharsets.UTF_8); + } + + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) { + assertEmbedQuery(embed, expected, embedders, null); + } + + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) { + String languageParam = language == null ? 
"" : "&language=" + language; + String destination = "query(myTensor4)"; + + Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest( + "?" + urlEncode("ranking.features." + destination) + + "=" + urlEncode(embed) + + "&ranking=commonProfile" + + languageParam, + com.yahoo.jdisc.http.HttpRequest.Method.GET)) + .setSchemaInfo(createSchemaInfo()) + .setQueryProfile(createQueryProfile()) + .setEmbedders(embedders) + .build(); + assertEquals(0, query.errors().size()); + assertEquals(expected, query.properties().get("ranking.features." + destination)); + assertEquals(expected, query.getRanking().getFeatures().getTensor(destination).get()); + } + + private void assertEmbedQueryFails(String embed, Tensor expected, Map<String, Embedder> embedders, String errMsg) { + Throwable t = assertThrows(IllegalArgumentException.class, () -> assertEmbedQuery(embed, expected, embedders)); + while (t != null) { + if (t.getMessage().equals(errMsg)) return; + t = t.getCause(); + } + fail("Error '" + errMsg + "' not thrown"); + } + + private CompiledQueryProfile createQueryProfile() { + var registry = new QueryProfileRegistry(); + registry.register(new QueryProfile("test")); + return registry.compile().findQueryProfile("test"); + } + + private SchemaInfo createSchemaInfo() { + List<Schema> schemas = new ArrayList<>(); + RankProfile common = new RankProfile.Builder("commonProfile") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor2)", TensorType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor3)", TensorType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor4)", TensorType.fromSpec("tensor<float>(x[5])")) + .build(); + schemas.add(new Schema.Builder("a") + .add(common) + .add(new RankProfile.Builder("inconsistent") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .build()) + .build()); + schemas.add(new Schema.Builder("b") + .add(common) + .add(new 
RankProfile.Builder("inconsistent") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(x[10])")) + .build()) + .add(new RankProfile.Builder("bOnly") + .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .build()) + .build()); + Map<String, List<String>> clusters = new HashMap<>(); + clusters.put("ab", List.of("a", "b")); + clusters.put("a", List.of("a")); + return new SchemaInfo(schemas, clusters); + } + + private static final class MockEmbedder implements Embedder { + + private final String expectedText; + private final Language expectedLanguage; + private final Tensor tensorToReturn; + + public MockEmbedder(String expectedText, + Language expectedLanguage, + Tensor tensorToReturn) { + this.expectedText = expectedText; + this.expectedLanguage = expectedLanguage; + this.tensorToReturn = tensorToReturn; + } + + @Override + public List<Integer> embed(String text, Embedder.Context context) { + fail("Unexpected call"); + return null; + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + assertEquals(expectedText, text); + assertEquals(expectedLanguage, context.getLanguage()); + assertEquals(tensorToReturn.type(), tensorType); + return tensorToReturn; + } + + } + +} |