diff options
author | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:42 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:42 +0100 |
commit | 8a64a50ac9f0cbea18f0c1a8e1ef482d3311e873 (patch) | |
tree | 19853dccb4c68714885e7cb73d32e1f191ef306a /container-search/src/main | |
parent | c5e464f1a6da3a74113d775805187a547074a2da (diff) |
Add embedder selection argument to query parameter transformation
Diffstat (limited to 'container-search/src/main')
6 files changed, 136 insertions, 31 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index 83fa18d847f..b7a0fcb5dc3 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -337,7 +337,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile) { super(new QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, Embedder.throwsOnUse, ZoneInfo.defaultInfo()); + init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo()); } // TODO: Deprecate most constructors above here @@ -346,31 +346,31 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { this(builder.getRequest(), builder.getRequestMap(), builder.getQueryProfile(), - builder.getEmbedder(), + builder.getEmbedders(), builder.getZoneInfo()); } - private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Embedder embedder, + private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders, ZoneInfo zoneInfo) { super(new QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, embedder, zoneInfo); + init(requestMap, queryProfile, embedders, zoneInfo); } private void init(Map<String, String> requestMap, CompiledQueryProfile queryProfile, - Embedder embedder, + Map<String, Embedder> embedders, ZoneInfo zoneInfo) { startTime = httpRequest.getJDiscRequest().creationTime(TimeUnit.MILLISECONDS); if (queryProfile != null) { // Move all request parameters to the query profile just to validate that the parameter settings are legal - Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedder); + Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedders); properties().chain(queryProfileProperties); // TODO: Just checking legality rather than actually setting would be faster setPropertiesFromRequestMap(requestMap, properties(), true); // Adds errors to the query for illegal set attempts // Create the full chain - properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedder)). + properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedders)). chain(new ModelObjectMap()). chain(new RequestContextProperties(requestMap, zoneInfo)). chain(queryProfileProperties). @@ -389,7 +389,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { } else { // bypass these complications if there is no query profile to get values from and validate against properties(). - chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedder)). + chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedders)). chain(new PropertyMap()). chain(new DefaultProperties()); setPropertiesFromRequestMap(requestMap, properties(), false); @@ -1131,7 +1131,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { private HttpRequest request = null; private Map<String, String> requestMap = null; private CompiledQueryProfile queryProfile = null; - private Embedder embedder = Embedder.throwsOnUse; + private Map<String, Embedder> embedders = Embedder.throwsOnUse.asMap(); private ZoneInfo zoneInfo = ZoneInfo.defaultInfo(); public Builder setRequest(String query) { @@ -1171,11 +1171,22 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public CompiledQueryProfile getQueryProfile() { return queryProfile; } public Builder setEmbedder(Embedder embedder) { - this.embedder = embedder; + return setEmbedders(Map.of(Embedder.defaultEmbedderName, embedder)); + } + + public Builder setEmbedders(Map<String, Embedder> embedders) { + this.embedders = embedders; return this; } - public Embedder getEmbedder() { return embedder; } + public Embedder getEmbedder() { + if (embedders.size() != 1) { + throw new IllegalArgumentException("Attempt to get single embedder but multiple exists."); + } + return embedders.entrySet().stream().findFirst().get().getValue(); + } + + public Map<String, Embedder> getEmbedders() { return embedders; } public Builder setZoneInfo(ZoneInfo zoneInfo) { this.zoneInfo = zoneInfo; diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java index b65953935f0..86ff34f659b 100644 --- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java +++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java @@ -3,6 +3,7 @@ package com.yahoo.search.handler; import com.google.inject.Inject; import com.yahoo.collections.Tuple2; +import com.yahoo.component.ComponentId; import com.yahoo.component.ComponentSpecification; import com.yahoo.component.Vtag; import com.yahoo.component.chain.Chain; @@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.Request; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.provider.DefaultEmbedderProvider; import com.yahoo.net.HostName; import com.yahoo.net.UriTools; import com.yahoo.prelude.query.parser.ParseException; @@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Handles search request. @@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler { private final Optional<String> hostResponseHeaderKey; private final String selfHostname = HostName.getLocalhost(); - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final ExecutionFactory executionFactory; private final AtomicLong numRequestsLeftToTrace; @@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler { ContainerThreadPool threadpool, CompiledQueryProfileRegistry queryProfileRegistry, ContainerHttpConfig config, - Embedder embedder, + ComponentRegistry<Embedder> embedders, ExecutionFactory executionFactory, ZoneInfo zoneInfo) { - this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory, + this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory, config.numQueriesToTraceOnDebugAfterConstruction(), config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()), zoneInfo); @@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler { CompiledQueryProfileRegistry queryProfileRegistry, ExecutionFactory executionFactory, Optional<String> hostResponseHeaderKey) { - this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse, + this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse), executionFactory, 0, hostResponseHeaderKey, ZoneInfo.defaultInfo()); } @@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler { long numQueriesToTraceOnDebugAfterStartup, Optional<String> hostResponseHeaderKey, ZoneInfo zoneInfo) { + this(metric, executor, queryProfileRegistry, toRegistry(embedder), + executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey, + ZoneInfo.defaultInfo()); + } + + private SearchHandler(Metric metric, + Executor executor, + CompiledQueryProfileRegistry queryProfileRegistry, + ComponentRegistry<Embedder> embedders, + ExecutionFactory executionFactory, + long numQueriesToTraceOnDebugAfterStartup, + Optional<String> hostResponseHeaderKey, + ZoneInfo zoneInfo) { super(executor, metric, true); + log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this)); this.queryProfileRegistry = queryProfileRegistry; - this.embedder = embedder; + this.embedders = toMap(embedders); this.executionFactory = executionFactory; this.maxThreads = examineExecutor(executor); @@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler { Query query = new Query.Builder().setRequest(request) .setRequestMap(requestMap) .setQueryProfile(queryProfile) - .setEmbedder(embedder) + .setEmbedders(embedders) .setZoneInfo(zoneInfo) .build(); @@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler { .build(); } + private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) { + var map = embedders.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + if (map.size() > 1) { + map.remove(DefaultEmbedderProvider.class.getName()); + // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. + } + return Collections.unmodifiableMap(map); + } + + private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) { + ComponentRegistry<Embedder> emb = new ComponentRegistry<>(); + emb.register(new ComponentId(Embedder.defaultEmbedderName), embedder); + return emb; + } + } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java index f58395fd5bb..6e778a0fac6 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java @@ -30,7 +30,7 @@ import java.util.Map; public class QueryProfileProperties extends Properties { private final CompiledQueryProfile profile; - private final Embedder embedder; + private final Map<String, Embedder> embedders; // Note: The priority order is: values has precedence over references @@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties { private List<Pair<CompoundName, CompiledQueryProfile>> references = null; public QueryProfileProperties(CompiledQueryProfile profile) { - this(profile, Embedder.throwsOnUse); + this(profile, Embedder.throwsOnUse.asMap()); } - /** Creates an instance from a profile, throws an exception if the given profile is null */ public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) { + this(profile, Map.of(Embedder.defaultEmbedderName, embedder)); + } + + /** Creates an instance from a profile, throws an exception if the given profile is null */ + public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) { Validator.ensureNotNull("The profile wrapped by this cannot be null", profile); this.profile = profile; - this.embedder = embedder; + this.embedders = embedders; } /** Returns the query profile backing this, or null if none */ @@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties { if (fieldDescription != null) { if (i == name.size() - 1) { // at the end of the path, check the assignment type - var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context); + var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context); var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext); if (convertedValue == null && fieldDescription.getType() instanceof QueryProfileFieldType diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index 8dfb67a9d5f..1fc405051ac 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -14,14 +14,20 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, Map<String, String> context) { + this(destination, registry, Map.of(Embedder.defaultEmbedderName, embedder), context); + } + + public ConversionContext(String destination, CompiledQueryProfileRegistry registry, + Map<String, Embedder> embedders, + Map<String, String> context) { this.destination = destination; this.registry = registry; - this.embedder = embedder; + this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; } @@ -33,14 +39,14 @@ public class ConversionContext { CompiledQueryProfileRegistry registry() {return registry;} /** Returns the configured embedder, never null */ - Embedder embedder() { return embedder; } + Map<String, Embedder> embedders() { return embedders; } /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } /** Returns an empty context */ public static ConversionContext empty() { - return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of()); + return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index d6676db3774..6f1cfccc16b 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + /** * A tensor field type in a query profile * @@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType; */ public class TensorFieldType extends FieldType { + private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private final TensorType type; /** Creates a tensor field type with information about the kind of tensor this will hold */ @@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType { private Tensor encode(String s, ConversionContext context) { if ( ! s.endsWith(")")) throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); - String text = s.substring("embed(".length(), s.length() - 1); - return context.embedder().embed(text, toEmbedderContext(context), type); + String argument = s.substring("embed(".length(), s.length() - 1); + Embedder embedder; + + // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") + Matcher matcher = embedderArgumentRegexp.matcher(argument); + if (matcher.matches()) { + String embedderId = matcher.group(1); + argument = matcher.group(2); + if (!context.embedders().containsKey(embedderId)) { + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } + embedder = context.embedders().get(embedderId); + } else if (context.embedders().size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } else if (context.embedders().size() > 1) { + throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } else { + embedder = context.embedders().entrySet().stream().findFirst().get().getValue(); + } + + return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type); + } + + private static String removeQuotes(String s) { + if (s.startsWith("'") && s.endsWith("'")) { + return s.substring(1, s.length() - 1); + } + if (s.startsWith("\"") && s.endsWith("\"")) { + return s.substring(1, s.length() - 1); + } + return s; + } + + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); } private Embedder.Context toEmbedderContext(ConversionContext context) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 243915662d2..dc901589cde 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -34,12 +34,16 @@ public class QueryProperties extends Properties { private Query query; private final CompiledQueryProfileRegistry profileRegistry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) { + this(query, profileRegistry, Map.of(Embedder.defaultEmbedderName, embedder)); + } + + public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) { this.query = query; this.profileRegistry = profileRegistry; - this.embedder = embedder; + this.embedders = embedders; } public void setParentQuery(Query query) { @@ -394,7 +398,7 @@ public class QueryProperties extends Properties { if (type == null) return value; // no type info -> keep as string FieldDescription field = type.getField(key); if (field == null) return value; // ditto - return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context)); + return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context)); } private void throwIllegalParameter(String key,String namespace) { |