summaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/handler
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2022-03-21 14:16:42 +0100
committerLester Solbakken <lesters@oath.com>2022-03-21 14:16:42 +0100
commit8a64a50ac9f0cbea18f0c1a8e1ef482d3311e873 (patch)
tree19853dccb4c68714885e7cb73d32e1f191ef306a /container-search/src/main/java/com/yahoo/search/handler
parentc5e464f1a6da3a74113d775805187a547074a2da (diff)
Add embedder selection argument to query parameter transformation
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/handler')
-rw-r--r--container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java46
1 files changed, 40 insertions, 6 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
index b65953935f0..86ff34f659b 100644
--- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
+++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
@@ -3,6 +3,7 @@ package com.yahoo.search.handler;
import com.google.inject.Inject;
import com.yahoo.collections.Tuple2;
+import com.yahoo.component.ComponentId;
import com.yahoo.component.ComponentSpecification;
import com.yahoo.component.Vtag;
import com.yahoo.component.chain.Chain;
@@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric;
import com.yahoo.jdisc.Request;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
+import com.yahoo.language.provider.DefaultEmbedderProvider;
import com.yahoo.net.HostName;
import com.yahoo.net.UriTools;
import com.yahoo.prelude.query.parser.ParseException;
@@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
* Handles search request.
@@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler {
private final Optional<String> hostResponseHeaderKey;
private final String selfHostname = HostName.getLocalhost();
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final ExecutionFactory executionFactory;
private final AtomicLong numRequestsLeftToTrace;
@@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler {
ContainerThreadPool threadpool,
CompiledQueryProfileRegistry queryProfileRegistry,
ContainerHttpConfig config,
- Embedder embedder,
+ ComponentRegistry<Embedder> embedders,
ExecutionFactory executionFactory,
ZoneInfo zoneInfo) {
- this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory,
+ this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory,
config.numQueriesToTraceOnDebugAfterConstruction(),
config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()),
zoneInfo);
@@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler {
CompiledQueryProfileRegistry queryProfileRegistry,
ExecutionFactory executionFactory,
Optional<String> hostResponseHeaderKey) {
- this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse,
+ this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse),
executionFactory, 0, hostResponseHeaderKey,
ZoneInfo.defaultInfo());
}
@@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler {
long numQueriesToTraceOnDebugAfterStartup,
Optional<String> hostResponseHeaderKey,
ZoneInfo zoneInfo) {
+ this(metric, executor, queryProfileRegistry, toRegistry(embedder),
+ executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey,
+ ZoneInfo.defaultInfo());
+ }
+
+ private SearchHandler(Metric metric,
+ Executor executor,
+ CompiledQueryProfileRegistry queryProfileRegistry,
+ ComponentRegistry<Embedder> embedders,
+ ExecutionFactory executionFactory,
+ long numQueriesToTraceOnDebugAfterStartup,
+ Optional<String> hostResponseHeaderKey,
+ ZoneInfo zoneInfo) {
super(executor, metric, true);
+
log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this));
this.queryProfileRegistry = queryProfileRegistry;
- this.embedder = embedder;
+ this.embedders = toMap(embedders);
this.executionFactory = executionFactory;
this.maxThreads = examineExecutor(executor);
@@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler {
Query query = new Query.Builder().setRequest(request)
.setRequestMap(requestMap)
.setQueryProfile(queryProfile)
- .setEmbedder(embedder)
+ .setEmbedders(embedders)
.setZoneInfo(zoneInfo)
.build();
@@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler {
.build();
}
+ private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
+ var map = embedders.allComponentsById().entrySet().stream()
+ .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
+ if (map.size() > 1) {
+ map.remove(DefaultEmbedderProvider.class.getName());
+ // Ideally, this should be handled by dependency injection, however for now this workaround is necessary.
+ }
+ return Collections.unmodifiableMap(map);
+ }
+
+ private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) {
+ ComponentRegistry<Embedder> emb = new ComponentRegistry<>();
+ emb.register(new ComponentId(Embedder.defaultEmbedderName), embedder);
+ return emb;
+ }
+
}