aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2022-03-21 14:16:42 +0100
committerLester Solbakken <lesters@oath.com>2022-03-21 14:16:42 +0100
commit8a64a50ac9f0cbea18f0c1a8e1ef482d3311e873 (patch)
tree19853dccb4c68714885e7cb73d32e1f191ef306a /container-search/src/main/java/com/yahoo
parentc5e464f1a6da3a74113d775805187a547074a2da (diff)
Add embedder selection argument to query parameter transformation
Diffstat (limited to 'container-search/src/main/java/com/yahoo')
-rw-r--r--container-search/src/main/java/com/yahoo/search/Query.java33
-rw-r--r--container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java46
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java14
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java14
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java50
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java10
6 files changed, 136 insertions, 31 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java
index 83fa18d847f..b7a0fcb5dc3 100644
--- a/container-search/src/main/java/com/yahoo/search/Query.java
+++ b/container-search/src/main/java/com/yahoo/search/Query.java
@@ -337,7 +337,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
public Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile) {
super(new QueryPropertyAliases(propertyAliases));
this.httpRequest = request;
- init(requestMap, queryProfile, Embedder.throwsOnUse, ZoneInfo.defaultInfo());
+ init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo());
}
// TODO: Deprecate most constructors above here
@@ -346,31 +346,31 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
this(builder.getRequest(),
builder.getRequestMap(),
builder.getQueryProfile(),
- builder.getEmbedder(),
+ builder.getEmbedders(),
builder.getZoneInfo());
}
- private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Embedder embedder,
+ private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders,
ZoneInfo zoneInfo) {
super(new QueryPropertyAliases(propertyAliases));
this.httpRequest = request;
- init(requestMap, queryProfile, embedder, zoneInfo);
+ init(requestMap, queryProfile, embedders, zoneInfo);
}
private void init(Map<String, String> requestMap,
CompiledQueryProfile queryProfile,
- Embedder embedder,
+ Map<String, Embedder> embedders,
ZoneInfo zoneInfo) {
startTime = httpRequest.getJDiscRequest().creationTime(TimeUnit.MILLISECONDS);
if (queryProfile != null) {
// Move all request parameters to the query profile just to validate that the parameter settings are legal
- Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedder);
+ Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedders);
properties().chain(queryProfileProperties);
// TODO: Just checking legality rather than actually setting would be faster
setPropertiesFromRequestMap(requestMap, properties(), true); // Adds errors to the query for illegal set attempts
// Create the full chain
- properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedder)).
+ properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedders)).
chain(new ModelObjectMap()).
chain(new RequestContextProperties(requestMap, zoneInfo)).
chain(queryProfileProperties).
@@ -389,7 +389,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
}
else { // bypass these complications if there is no query profile to get values from and validate against
properties().
- chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedder)).
+ chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedders)).
chain(new PropertyMap()).
chain(new DefaultProperties());
setPropertiesFromRequestMap(requestMap, properties(), false);
@@ -1131,7 +1131,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
private HttpRequest request = null;
private Map<String, String> requestMap = null;
private CompiledQueryProfile queryProfile = null;
- private Embedder embedder = Embedder.throwsOnUse;
+ private Map<String, Embedder> embedders = Embedder.throwsOnUse.asMap();
private ZoneInfo zoneInfo = ZoneInfo.defaultInfo();
public Builder setRequest(String query) {
@@ -1171,11 +1171,22 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
public CompiledQueryProfile getQueryProfile() { return queryProfile; }
public Builder setEmbedder(Embedder embedder) {
- this.embedder = embedder;
+ return setEmbedders(Map.of(Embedder.defaultEmbedderName, embedder));
+ }
+
+ public Builder setEmbedders(Map<String, Embedder> embedders) {
+ this.embedders = embedders;
return this;
}
- public Embedder getEmbedder() { return embedder; }
+ public Embedder getEmbedder() {
+ if (embedders.size() != 1) {
+ throw new IllegalArgumentException("Attempt to get single embedder but multiple exists.");
+ }
+ return embedders.entrySet().stream().findFirst().get().getValue();
+ }
+
+ public Map<String, Embedder> getEmbedders() { return embedders; }
public Builder setZoneInfo(ZoneInfo zoneInfo) {
this.zoneInfo = zoneInfo;
diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
index b65953935f0..86ff34f659b 100644
--- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
+++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
@@ -3,6 +3,7 @@ package com.yahoo.search.handler;
import com.google.inject.Inject;
import com.yahoo.collections.Tuple2;
+import com.yahoo.component.ComponentId;
import com.yahoo.component.ComponentSpecification;
import com.yahoo.component.Vtag;
import com.yahoo.component.chain.Chain;
@@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric;
import com.yahoo.jdisc.Request;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
+import com.yahoo.language.provider.DefaultEmbedderProvider;
import com.yahoo.net.HostName;
import com.yahoo.net.UriTools;
import com.yahoo.prelude.query.parser.ParseException;
@@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
* Handles search request.
@@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler {
private final Optional<String> hostResponseHeaderKey;
private final String selfHostname = HostName.getLocalhost();
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final ExecutionFactory executionFactory;
private final AtomicLong numRequestsLeftToTrace;
@@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler {
ContainerThreadPool threadpool,
CompiledQueryProfileRegistry queryProfileRegistry,
ContainerHttpConfig config,
- Embedder embedder,
+ ComponentRegistry<Embedder> embedders,
ExecutionFactory executionFactory,
ZoneInfo zoneInfo) {
- this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory,
+ this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory,
config.numQueriesToTraceOnDebugAfterConstruction(),
config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()),
zoneInfo);
@@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler {
CompiledQueryProfileRegistry queryProfileRegistry,
ExecutionFactory executionFactory,
Optional<String> hostResponseHeaderKey) {
- this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse,
+ this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse),
executionFactory, 0, hostResponseHeaderKey,
ZoneInfo.defaultInfo());
}
@@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler {
long numQueriesToTraceOnDebugAfterStartup,
Optional<String> hostResponseHeaderKey,
ZoneInfo zoneInfo) {
+ this(metric, executor, queryProfileRegistry, toRegistry(embedder),
+ executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey,
+ ZoneInfo.defaultInfo());
+ }
+
+ private SearchHandler(Metric metric,
+ Executor executor,
+ CompiledQueryProfileRegistry queryProfileRegistry,
+ ComponentRegistry<Embedder> embedders,
+ ExecutionFactory executionFactory,
+ long numQueriesToTraceOnDebugAfterStartup,
+ Optional<String> hostResponseHeaderKey,
+ ZoneInfo zoneInfo) {
super(executor, metric, true);
+
log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this));
this.queryProfileRegistry = queryProfileRegistry;
- this.embedder = embedder;
+ this.embedders = toMap(embedders);
this.executionFactory = executionFactory;
this.maxThreads = examineExecutor(executor);
@@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler {
Query query = new Query.Builder().setRequest(request)
.setRequestMap(requestMap)
.setQueryProfile(queryProfile)
- .setEmbedder(embedder)
+ .setEmbedders(embedders)
.setZoneInfo(zoneInfo)
.build();
@@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler {
.build();
}
+ private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
+ var map = embedders.allComponentsById().entrySet().stream()
+ .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
+ if (map.size() > 1) {
+ map.remove(DefaultEmbedderProvider.class.getName());
+ // Ideally, this should be handled by dependency injection, however for now this workaround is necessary.
+ }
+ return Collections.unmodifiableMap(map);
+ }
+
+ private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) {
+ ComponentRegistry<Embedder> emb = new ComponentRegistry<>();
+ emb.register(new ComponentId(Embedder.defaultEmbedderName), embedder);
+ return emb;
+ }
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
index f58395fd5bb..6e778a0fac6 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
@@ -30,7 +30,7 @@ import java.util.Map;
public class QueryProfileProperties extends Properties {
private final CompiledQueryProfile profile;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
// Note: The priority order is: values has precedence over references
@@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties {
private List<Pair<CompoundName, CompiledQueryProfile>> references = null;
public QueryProfileProperties(CompiledQueryProfile profile) {
- this(profile, Embedder.throwsOnUse);
+ this(profile, Embedder.throwsOnUse.asMap());
}
- /** Creates an instance from a profile, throws an exception if the given profile is null */
public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) {
+ this(profile, Map.of(Embedder.defaultEmbedderName, embedder));
+ }
+
+ /** Creates an instance from a profile, throws an exception if the given profile is null */
+ public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) {
Validator.ensureNotNull("The profile wrapped by this cannot be null", profile);
this.profile = profile;
- this.embedder = embedder;
+ this.embedders = embedders;
}
/** Returns the query profile backing this, or null if none */
@@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties {
if (fieldDescription != null) {
if (i == name.size() - 1) { // at the end of the path, check the assignment type
- var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context);
+ var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context);
var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext);
if (convertedValue == null
&& fieldDescription.getType() instanceof QueryProfileFieldType
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
index 8dfb67a9d5f..1fc405051ac 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
@@ -14,14 +14,20 @@ public class ConversionContext {
private final String destination;
private final CompiledQueryProfileRegistry registry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final Language language;
public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder,
Map<String, String> context) {
+ this(destination, registry, Map.of(Embedder.defaultEmbedderName, embedder), context);
+ }
+
+ public ConversionContext(String destination, CompiledQueryProfileRegistry registry,
+ Map<String, Embedder> embedders,
+ Map<String, String> context) {
this.destination = destination;
this.registry = registry;
- this.embedder = embedder;
+ this.embedders = embedders;
this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language"))
: Language.UNKNOWN;
}
@@ -33,14 +39,14 @@ public class ConversionContext {
CompiledQueryProfileRegistry registry() {return registry;}
/** Returns the configured embedder, never null */
- Embedder embedder() { return embedder; }
+ Map<String, Embedder> embedders() { return embedders; }
/** Returns the language, which is never null but may be UNKNOWN */
Language language() { return language; }
/** Returns an empty context */
public static ConversionContext empty() {
- return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of());
+ return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of());
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index d6676db3774..6f1cfccc16b 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
/**
* A tensor field type in a query profile
*
@@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType;
*/
public class TensorFieldType extends FieldType {
+ private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])");
+
private final TensorType type;
/** Creates a tensor field type with information about the kind of tensor this will hold */
@@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType {
private Tensor encode(String s, ConversionContext context) {
if ( ! s.endsWith(")"))
throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
- String text = s.substring("embed(".length(), s.length() - 1);
- return context.embedder().embed(text, toEmbedderContext(context), type);
+ String argument = s.substring("embed(".length(), s.length() - 1);
+ Embedder embedder;
+
+ // Check if arguments specifies an embedder with the format embed(embedder, "text to encode")
+ Matcher matcher = embedderArgumentRegexp.matcher(argument);
+ if (matcher.matches()) {
+ String embedderId = matcher.group(1);
+ argument = matcher.group(2);
+ if (!context.embedders().containsKey(embedderId)) {
+ throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ }
+ embedder = context.embedders().get(embedderId);
+ } else if (context.embedders().size() == 0) {
+ throw new IllegalStateException("No embedders provided"); // should never happen
+ } else if (context.embedders().size() > 1) {
+ throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ } else {
+ embedder = context.embedders().entrySet().stream().findFirst().get().getValue();
+ }
+
+ return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type);
+ }
+
+ private static String removeQuotes(String s) {
+ if (s.startsWith("'") && s.endsWith("'")) {
+ return s.substring(1, s.length() - 1);
+ }
+ if (s.startsWith("\"") && s.endsWith("\"")) {
+ return s.substring(1, s.length() - 1);
+ }
+ return s;
+ }
+
+ private static String validEmbedders(Map<String, Embedder> embedders) {
+ List<String> embedderIds = new ArrayList<>();
+ embedders.forEach((key, value) -> embedderIds.add(key));
+ embedderIds.sort(null);
+ return String.join(",", embedderIds);
}
private Embedder.Context toEmbedderContext(ConversionContext context) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index 243915662d2..dc901589cde 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -34,12 +34,16 @@ public class QueryProperties extends Properties {
private Query query;
private final CompiledQueryProfileRegistry profileRegistry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) {
+ this(query, profileRegistry, Map.of(Embedder.defaultEmbedderName, embedder));
+ }
+
+ public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) {
this.query = query;
this.profileRegistry = profileRegistry;
- this.embedder = embedder;
+ this.embedders = embedders;
}
public void setParentQuery(Query query) {
@@ -394,7 +398,7 @@ public class QueryProperties extends Properties {
if (type == null) return value; // no type info -> keep as string
FieldDescription field = type.getField(key);
if (field == null) return value; // ditto
- return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context));
+ return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context));
}
private void throwIllegalParameter(String key,String namespace) {