diff options
author | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:42 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:42 +0100 |
commit | 8a64a50ac9f0cbea18f0c1a8e1ef482d3311e873 (patch) | |
tree | 19853dccb4c68714885e7cb73d32e1f191ef306a /container-search/src/main/java/com/yahoo/search/handler | |
parent | c5e464f1a6da3a74113d775805187a547074a2da (diff) |
Add embedder selection argument to query parameter transformation
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/handler')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java | 46 |
1 files changed, 40 insertions, 6 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java index b65953935f0..86ff34f659b 100644 --- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java +++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java @@ -3,6 +3,7 @@ package com.yahoo.search.handler; import com.google.inject.Inject; import com.yahoo.collections.Tuple2; +import com.yahoo.component.ComponentId; import com.yahoo.component.ComponentSpecification; import com.yahoo.component.Vtag; import com.yahoo.component.chain.Chain; @@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.Request; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.provider.DefaultEmbedderProvider; import com.yahoo.net.HostName; import com.yahoo.net.UriTools; import com.yahoo.prelude.query.parser.ParseException; @@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Handles search request. @@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler { private final Optional<String> hostResponseHeaderKey; private final String selfHostname = HostName.getLocalhost(); - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final ExecutionFactory executionFactory; private final AtomicLong numRequestsLeftToTrace; @@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler { ContainerThreadPool threadpool, CompiledQueryProfileRegistry queryProfileRegistry, ContainerHttpConfig config, - Embedder embedder, + ComponentRegistry<Embedder> embedders, ExecutionFactory executionFactory, ZoneInfo zoneInfo) { - this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory, + this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory, config.numQueriesToTraceOnDebugAfterConstruction(), config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()), zoneInfo); @@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler { CompiledQueryProfileRegistry queryProfileRegistry, ExecutionFactory executionFactory, Optional<String> hostResponseHeaderKey) { - this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse, + this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse), executionFactory, 0, hostResponseHeaderKey, ZoneInfo.defaultInfo()); } @@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler { long numQueriesToTraceOnDebugAfterStartup, Optional<String> hostResponseHeaderKey, ZoneInfo zoneInfo) { + this(metric, executor, queryProfileRegistry, toRegistry(embedder), + executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey, + ZoneInfo.defaultInfo()); + } + + private SearchHandler(Metric metric, + Executor executor, + CompiledQueryProfileRegistry queryProfileRegistry, + ComponentRegistry<Embedder> embedders, + ExecutionFactory executionFactory, + long numQueriesToTraceOnDebugAfterStartup, + Optional<String> hostResponseHeaderKey, + ZoneInfo zoneInfo) { super(executor, metric, true); + log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this)); this.queryProfileRegistry = queryProfileRegistry; - this.embedder = embedder; + this.embedders = toMap(embedders); this.executionFactory = executionFactory; this.maxThreads = examineExecutor(executor); @@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler { Query query = new Query.Builder().setRequest(request) .setRequestMap(requestMap) .setQueryProfile(queryProfile) - .setEmbedder(embedder) + .setEmbedders(embedders) .setZoneInfo(zoneInfo) .build(); @@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler { .build(); } + private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) { + var map = embedders.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + if (map.size() > 1) { + map.remove(DefaultEmbedderProvider.class.getName()); + // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. + } + return Collections.unmodifiableMap(map); + } + + private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) { + ComponentRegistry<Embedder> emb = new ComponentRegistry<>(); + emb.register(new ComponentId(Embedder.defaultEmbedderName), embedder); + return emb; + } + } |