summary refs log tree commit diff stats
path: root/container-search
diff options
context:
space:
mode:
author Lester Solbakken <lesters@users.noreply.github.com> 2022-03-24 09:41:25 +0100
committer GitHub <noreply@github.com> 2022-03-24 09:41:25 +0100
commit 0f0d8f84e5ff63f3fb01c650d38bbd0800150941 (patch)
tree e8bc8987174d84702e430577e83c0d71661ba958 /container-search
parent b28e1c6946df008a3cf802b0b7b8931f6b9b2f6a (diff)
parent 430aca046a032fd7f5aa2de6f6f6bb706d6de624 (diff)
Merge pull request #21761 from vespa-engine/lesters/multiple-embedders
Lesters/multiple embedders
Diffstat (limited to 'container-search')
-rw-r--r-- container-search/abi-spec.json 7
-rw-r--r-- container-search/src/main/java/com/yahoo/search/Query.java 33
-rw-r--r-- container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java 46
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java 14
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java 14
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java 50
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java 10
-rw-r--r-- container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java 113
8 files changed, 215 insertions, 72 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 5b990ca8758..b7aa1a8d0ef 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -1868,7 +1868,9 @@
"public com.yahoo.search.Query$Builder setQueryProfile(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)",
"public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()",
"public com.yahoo.search.Query$Builder setEmbedder(com.yahoo.language.process.Embedder)",
+ "public com.yahoo.search.Query$Builder setEmbedders(java.util.Map)",
"public com.yahoo.language.process.Embedder getEmbedder()",
+ "public java.util.Map getEmbedders()",
"public com.yahoo.search.Query$Builder setZoneInfo(ai.vespa.cloud.ZoneInfo)",
"public ai.vespa.cloud.ZoneInfo getZoneInfo()",
"public com.yahoo.search.Query build()"
@@ -4354,7 +4356,7 @@
"public"
],
"methods": [
- "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)",
+ "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.component.provider.ComponentRegistry, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)",
"public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)",
"public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)",
"public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.container.logging.AccessLog, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.search.searchchain.ExecutionFactory)",
@@ -5993,6 +5995,7 @@
"methods": [
"public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)",
"public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, com.yahoo.language.process.Embedder)",
+ "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, java.util.Map)",
"public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()",
"public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)",
"public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)",
@@ -6368,6 +6371,7 @@
],
"methods": [
"public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder, java.util.Map)",
+ "public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map, java.util.Map)",
"public java.lang.String destination()",
"public static com.yahoo.search.query.profile.types.ConversionContext empty()"
],
@@ -6642,6 +6646,7 @@
],
"methods": [
"public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder)",
+ "public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map)",
"public void setParentQuery(com.yahoo.search.Query)",
"public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)",
"public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)",
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java
index 83fa18d847f..553c73dac17 100644
--- a/container-search/src/main/java/com/yahoo/search/Query.java
+++ b/container-search/src/main/java/com/yahoo/search/Query.java
@@ -337,7 +337,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
public Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile) {
super(new QueryPropertyAliases(propertyAliases));
this.httpRequest = request;
- init(requestMap, queryProfile, Embedder.throwsOnUse, ZoneInfo.defaultInfo());
+ init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo());
}
// TODO: Deprecate most constructors above here
@@ -346,31 +346,31 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
this(builder.getRequest(),
builder.getRequestMap(),
builder.getQueryProfile(),
- builder.getEmbedder(),
+ builder.getEmbedders(),
builder.getZoneInfo());
}
- private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Embedder embedder,
+ private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders,
ZoneInfo zoneInfo) {
super(new QueryPropertyAliases(propertyAliases));
this.httpRequest = request;
- init(requestMap, queryProfile, embedder, zoneInfo);
+ init(requestMap, queryProfile, embedders, zoneInfo);
}
private void init(Map<String, String> requestMap,
CompiledQueryProfile queryProfile,
- Embedder embedder,
+ Map<String, Embedder> embedders,
ZoneInfo zoneInfo) {
startTime = httpRequest.getJDiscRequest().creationTime(TimeUnit.MILLISECONDS);
if (queryProfile != null) {
// Move all request parameters to the query profile just to validate that the parameter settings are legal
- Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedder);
+ Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedders);
properties().chain(queryProfileProperties);
// TODO: Just checking legality rather than actually setting would be faster
setPropertiesFromRequestMap(requestMap, properties(), true); // Adds errors to the query for illegal set attempts
// Create the full chain
- properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedder)).
+ properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedders)).
chain(new ModelObjectMap()).
chain(new RequestContextProperties(requestMap, zoneInfo)).
chain(queryProfileProperties).
@@ -389,7 +389,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
}
else { // bypass these complications if there is no query profile to get values from and validate against
properties().
- chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedder)).
+ chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedders)).
chain(new PropertyMap()).
chain(new DefaultProperties());
setPropertiesFromRequestMap(requestMap, properties(), false);
@@ -1131,7 +1131,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
private HttpRequest request = null;
private Map<String, String> requestMap = null;
private CompiledQueryProfile queryProfile = null;
- private Embedder embedder = Embedder.throwsOnUse;
+ private Map<String, Embedder> embedders = Embedder.throwsOnUse.asMap();
private ZoneInfo zoneInfo = ZoneInfo.defaultInfo();
public Builder setRequest(String query) {
@@ -1171,11 +1171,22 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
public CompiledQueryProfile getQueryProfile() { return queryProfile; }
public Builder setEmbedder(Embedder embedder) {
- this.embedder = embedder;
+ return setEmbedders(Map.of(Embedder.defaultEmbedderId, embedder));
+ }
+
+ public Builder setEmbedders(Map<String, Embedder> embedders) {
+ this.embedders = embedders;
return this;
}
- public Embedder getEmbedder() { return embedder; }
+        /**
+         * Returns the only embedder held by this builder.
+         * Use getEmbedders() when more than one embedder may have been set.
+         *
+         * @throws IllegalArgumentException if this does not hold exactly one embedder
+         */
+        public Embedder getEmbedder() {
+            if (embedders.size() != 1) {
+                // Report the actual count: this is also reached when the map is empty, not only when it has several
+                throw new IllegalArgumentException("Attempt to get single embedder but " + embedders.size() +
+                                                   " embedders exist.");
+            }
+            return embedders.values().iterator().next();
+        }
+
+ public Map<String, Embedder> getEmbedders() { return embedders; }
public Builder setZoneInfo(ZoneInfo zoneInfo) {
this.zoneInfo = zoneInfo;
diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
index b65953935f0..af6374ba245 100644
--- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
+++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
@@ -3,6 +3,7 @@ package com.yahoo.search.handler;
import com.google.inject.Inject;
import com.yahoo.collections.Tuple2;
+import com.yahoo.component.ComponentId;
import com.yahoo.component.ComponentSpecification;
import com.yahoo.component.Vtag;
import com.yahoo.component.chain.Chain;
@@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric;
import com.yahoo.jdisc.Request;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
+import com.yahoo.language.provider.DefaultEmbedderProvider;
import com.yahoo.net.HostName;
import com.yahoo.net.UriTools;
import com.yahoo.prelude.query.parser.ParseException;
@@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
* Handles search request.
@@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler {
private final Optional<String> hostResponseHeaderKey;
private final String selfHostname = HostName.getLocalhost();
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final ExecutionFactory executionFactory;
private final AtomicLong numRequestsLeftToTrace;
@@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler {
ContainerThreadPool threadpool,
CompiledQueryProfileRegistry queryProfileRegistry,
ContainerHttpConfig config,
- Embedder embedder,
+ ComponentRegistry<Embedder> embedders,
ExecutionFactory executionFactory,
ZoneInfo zoneInfo) {
- this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory,
+ this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory,
config.numQueriesToTraceOnDebugAfterConstruction(),
config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()),
zoneInfo);
@@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler {
CompiledQueryProfileRegistry queryProfileRegistry,
ExecutionFactory executionFactory,
Optional<String> hostResponseHeaderKey) {
- this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse,
+ this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse),
executionFactory, 0, hostResponseHeaderKey,
ZoneInfo.defaultInfo());
}
@@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler {
long numQueriesToTraceOnDebugAfterStartup,
Optional<String> hostResponseHeaderKey,
ZoneInfo zoneInfo) {
+        // Wrap the single embedder in a registry and delegate to the registry-based constructor.
+        // The zoneInfo argument of this constructor is forwarded as given rather than replaced by the default.
+        this(metric, executor, queryProfileRegistry, toRegistry(embedder),
+             executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey,
+             zoneInfo);
+    }
+
+ private SearchHandler(Metric metric,
+ Executor executor,
+ CompiledQueryProfileRegistry queryProfileRegistry,
+ ComponentRegistry<Embedder> embedders,
+ ExecutionFactory executionFactory,
+ long numQueriesToTraceOnDebugAfterStartup,
+ Optional<String> hostResponseHeaderKey,
+ ZoneInfo zoneInfo) {
super(executor, metric, true);
+
log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this));
this.queryProfileRegistry = queryProfileRegistry;
- this.embedder = embedder;
+ this.embedders = toMap(embedders);
this.executionFactory = executionFactory;
this.maxThreads = examineExecutor(executor);
@@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler {
Query query = new Query.Builder().setRequest(request)
.setRequestMap(requestMap)
.setQueryProfile(queryProfile)
- .setEmbedder(embedder)
+ .setEmbedders(embedders)
.setZoneInfo(zoneInfo)
.build();
@@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler {
.build();
}
+    /**
+     * Converts a registry of embedders into an unmodifiable map from component id string to embedder.
+     * When more than one embedder is registered, the entry keyed by the DefaultEmbedderProvider
+     * class name is left out of the result.
+     *
+     * @param embedders the registry to convert
+     * @return an unmodifiable map of the registered embedders keyed by the string value of their id
+     */
+    private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
+        // Copy into a HashMap before removing: Collectors.toMap gives no guarantee that its result is mutable
+        Map<String, Embedder> map = new HashMap<>(
+                embedders.allComponentsById().entrySet().stream()
+                         .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)));
+        if (map.size() > 1) {
+            // Ideally, this should be handled by dependency injection, however for now this workaround is necessary.
+            map.remove(DefaultEmbedderProvider.class.getName());
+        }
+        return Collections.unmodifiableMap(map);
+    }
+
+    /** Returns a new registry holding only the given embedder, registered under the default embedder id. */
+    private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) {
+        ComponentRegistry<Embedder> registry = new ComponentRegistry<>();
+        ComponentId defaultId = new ComponentId(Embedder.defaultEmbedderId);
+        registry.register(defaultId, embedder);
+        return registry;
+    }
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
index f58395fd5bb..5b3758f103d 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
@@ -30,7 +30,7 @@ import java.util.Map;
public class QueryProfileProperties extends Properties {
private final CompiledQueryProfile profile;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
// Note: The priority order is: values has precedence over references
@@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties {
private List<Pair<CompoundName, CompiledQueryProfile>> references = null;
public QueryProfileProperties(CompiledQueryProfile profile) {
- this(profile, Embedder.throwsOnUse);
+ this(profile, Embedder.throwsOnUse.asMap());
}
- /** Creates an instance from a profile, throws an exception if the given profile is null */
public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) {
+ this(profile, Map.of(Embedder.defaultEmbedderId, embedder));
+ }
+
+ /** Creates an instance from a profile, throws an exception if the given profile is null */
+ public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) {
Validator.ensureNotNull("The profile wrapped by this cannot be null", profile);
this.profile = profile;
- this.embedder = embedder;
+ this.embedders = embedders;
}
/** Returns the query profile backing this, or null if none */
@@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties {
if (fieldDescription != null) {
if (i == name.size() - 1) { // at the end of the path, check the assignment type
- var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context);
+ var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context);
var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext);
if (convertedValue == null
&& fieldDescription.getType() instanceof QueryProfileFieldType
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
index 8dfb67a9d5f..5c449b9645a 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
@@ -14,14 +14,20 @@ public class ConversionContext {
private final String destination;
private final CompiledQueryProfileRegistry registry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final Language language;
public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder,
Map<String, String> context) {
+ this(destination, registry, Map.of(Embedder.defaultEmbedderId, embedder), context);
+ }
+
+ public ConversionContext(String destination, CompiledQueryProfileRegistry registry,
+ Map<String, Embedder> embedders,
+ Map<String, String> context) {
this.destination = destination;
this.registry = registry;
- this.embedder = embedder;
+ this.embedders = embedders;
this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language"))
: Language.UNKNOWN;
}
@@ -33,14 +39,14 @@ public class ConversionContext {
CompiledQueryProfileRegistry registry() {return registry;}
- /** Returns the configured embedder, never null */
+ /** Returns the configured embedders keyed by embedder id, never null */
- Embedder embedder() { return embedder; }
+ Map<String, Embedder> embedders() { return embedders; }
/** Returns the language, which is never null but may be UNKNOWN */
Language language() { return language; }
/** Returns an empty context */
public static ConversionContext empty() {
- return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of());
+ return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of());
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index d6676db3774..6f1cfccc16b 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
/**
* A tensor field type in a query profile
*
@@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType;
*/
public class TensorFieldType extends FieldType {
+ private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); // group 1: embedder id, group 2: the text including its quotes; the two quote characters are not required to be the same kind
+
private final TensorType type;
/** Creates a tensor field type with information about the kind of tensor this will hold */
@@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType {
private Tensor encode(String s, ConversionContext context) { // NOTE(review): presumably s always starts with "embed(" - confirm against caller
if ( ! s.endsWith(")"))
throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
- String text = s.substring("embed(".length(), s.length() - 1);
- return context.embedder().embed(text, toEmbedderContext(context), type);
+ String argument = s.substring("embed(".length(), s.length() - 1); // drops the first "embed(".length() characters and the final ')'
+ Embedder embedder;
+
+ // Check if the argument specifies an embedder with the format embed(embedder, "text to encode")
+ Matcher matcher = embedderArgumentRegexp.matcher(argument);
+ if (matcher.matches()) { // explicit embedder id given
+ String embedderId = matcher.group(1);
+ argument = matcher.group(2); // the text to embed, still enclosed in its quotes
+ if (!context.embedders().containsKey(embedderId)) {
+ throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ }
+ embedder = context.embedders().get(embedderId);
+ } else if (context.embedders().size() == 0) {
+ throw new IllegalStateException("No embedders provided"); // should never happen
+ } else if (context.embedders().size() > 1) { // no id given, so the choice between several embedders is ambiguous
+ throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ } else { // exactly one embedder: use it without requiring an id
+ embedder = context.embedders().entrySet().stream().findFirst().get().getValue();
+ }
+
+ return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type); // surrounding quotes, if any, are stripped before embedding
+ }
+
+    /**
+     * Strips one pair of matching surrounding quotes, either single or double, from the given string.
+     * A string which is not enclosed in a matching pair, including one consisting of a lone quote
+     * character, is returned unchanged.
+     *
+     * @param s the possibly quoted string
+     * @return the string without its enclosing quotes, or the string itself if it is not quoted
+     */
+    private static String removeQuotes(String s) {
+        // A lone quote both starts and ends the string, which would lead to substring(1, 0)
+        // throwing StringIndexOutOfBoundsException, so at least two characters are required
+        if (s.length() < 2) return s;
+        char first = s.charAt(0);
+        boolean quoted = (first == '\'' || first == '"') && s.charAt(s.length() - 1) == first;
+        return quoted ? s.substring(1, s.length() - 1) : s;
+    }
+
+    /** Returns the ids of the given embedders in their natural sort order, separated by commas. */
+    private static String validEmbedders(Map<String, Embedder> embedders) {
+        List<String> sortedIds = new ArrayList<>(embedders.keySet());
+        sortedIds.sort(null); // a null comparator sorts by natural ordering
+        return String.join(",", sortedIds);
     }
private Embedder.Context toEmbedderContext(ConversionContext context) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index 243915662d2..98b65c6edd9 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -34,12 +34,16 @@ public class QueryProperties extends Properties {
private Query query;
private final CompiledQueryProfileRegistry profileRegistry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) {
+ this(query, profileRegistry, Map.of(Embedder.defaultEmbedderId, embedder));
+ }
+
+ public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) {
this.query = query;
this.profileRegistry = profileRegistry;
- this.embedder = embedder;
+ this.embedders = embedders;
}
public void setParentQuery(Query query) {
@@ -394,7 +398,7 @@ public class QueryProperties extends Properties {
if (type == null) return value; // no type info -> keep as string
FieldDescription field = type.getField(key);
if (field == null) return value; // ditto
- return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context));
+ return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context));
}
private void throwIllegalParameter(String key,String namespace) {
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
index 2e88c9fd0a4..a1556aac189 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
@@ -25,9 +25,11 @@ import org.junit.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -441,42 +443,42 @@ public class QueryProfileTypeTestCase {
@Test
public void testUnembeddedTensorRankFeatureInRequest() {
- QueryProfile profile = new QueryProfile("test");
- profile.setType(testtype);
- registry.register(profile);
-
- CompiledQueryProfileRegistry cRegistry = registry.compile();
- String textToEmbed = "text to embed into a tensor";
- String destinationFeature = "query(myTensor4)";
- Tensor expectedTensor = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]");
- Query query1 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) +
- "=" + urlEncode("embed(" + textToEmbed + ")"),
- com.yahoo.jdisc.http.HttpRequest.Method.GET))
- .setQueryProfile(cRegistry.getComponent("test"))
- .setEmbedder(new MockEmbedder(textToEmbed,
- Language.UNKNOWN,
- destinationFeature,
- expectedTensor))
- .build();
- assertEquals(0, query1.errors().size());
- assertEquals(expectedTensor, query1.properties().get("ranking.features.query(myTensor4)"));
- assertEquals(expectedTensor, query1.getRanking().getFeatures().getTensor("query(myTensor4)").get());
-
- // Explicit language
- Query query2 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) +
- "=" + urlEncode("embed(" + textToEmbed + ")") +
- "&language=en",
- com.yahoo.jdisc.http.HttpRequest.Method.GET))
- .setQueryProfile(cRegistry.getComponent("test"))
- .setEmbedder(new MockEmbedder(textToEmbed,
- Language.ENGLISH,
- destinationFeature,
- expectedTensor))
- .build();
- assertEquals(0, query2.errors().size());
- assertEquals(expectedTensor, query2.properties().get("ranking.features.query(myTensor4)"));
- assertEquals(expectedTensor, query2.getRanking().getFeatures().getTensor("query(myTensor4)").get());
-
+ String text = "text to embed into a tensor";
+ Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]");
+ Tensor embedding2 = Tensor.from("tensor<float>(x[5]):[1,2,3,4,0]]");
+
+ Map<String, Embedder> embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1)
+ );
+ assertEmbedQuery("embed(" + text + ")", embedding1, embedders);
+ assertEmbedQuery("embed('" + text + "')", embedding1, embedders);
+ assertEmbedQuery("embed(\"" + text + "\")", embedding1, embedders);
+ assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders);
+ assertEmbedQuery("embed(emb1, \"" + text + "\")", embedding1, embedders);
+ assertEmbedQueryFails("embed(emb2, \"" + text + "\")", embedding1, embedders,
+ "Can't find embedder 'emb2'. Valid embedders are emb1");
+
+ embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1),
+ "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+ );
+ assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders);
+ assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders);
+ assertEmbedQueryFails("embed(emb3, \"" + text + "\")", embedding1, embedders,
+ "Can't find embedder 'emb3'. Valid embedders are emb1,emb2");
+
+ // And with specified language
+ embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1)
+ );
+ assertEmbedQuery("embed(" + text + ")", embedding1, embedders, Language.ENGLISH.languageCode());
+
+ embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1),
+ "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+ );
+ assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders, Language.ENGLISH.languageCode());
+ assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode());
}
private String urlEncode(String s) {
@@ -729,20 +731,52 @@ public class QueryProfileTypeTestCase {
}
}
+    /** Asserts that the embed expression produces the expected tensor when the request carries no language parameter. */
+    private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) {
+        String noLanguage = null;
+        assertEmbedQuery(embed, expected, embedders, noLanguage);
+    }
+
+    /**
+     * Builds a query which sets the rank feature query(myTensor4) to the given embed expression,
+     * using the given embedders, and asserts that the query has no errors and that the feature
+     * holds the expected tensor.
+     *
+     * @param language the value of the language request parameter, or null to not add that parameter
+     */
+    private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) {
+        QueryProfile testProfile = new QueryProfile("test");
+        testProfile.setType(testtype);
+        registry.register(testProfile);
+        CompiledQueryProfileRegistry compiled = registry.compile();
+
+        String feature = "query(myTensor4)";
+        String uri = "?" + urlEncode("ranking.features." + feature) + "=" + urlEncode(embed);
+        if (language != null)
+            uri += "&language=" + language;
+
+        Query query = new Query.Builder()
+                .setRequest(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET))
+                .setQueryProfile(compiled.getComponent("test"))
+                .setEmbedders(embedders)
+                .build();
+        assertEquals(0, query.errors().size());
+        assertEquals(expected, query.properties().get("ranking.features." + feature));
+        assertEquals(expected, query.getRanking().getFeatures().getTensor(feature).get());
+    }
+
+    /**
+     * Asserts that building a query with the given embed expression throws an IllegalArgumentException
+     * where the exception itself or one of its causes carries exactly the given message.
+     *
+     * @param errMsg the message which must be found somewhere in the cause chain
+     */
+    private void assertEmbedQueryFails(String embed, Tensor expected, Map<String, Embedder> embedders, String errMsg) {
+        Throwable t = assertThrows(IllegalArgumentException.class, () -> assertEmbedQuery(embed, expected, embedders));
+        while (t != null) {
+            // Compare from errMsg: getMessage() may be null for an exception in the chain
+            if (errMsg.equals(t.getMessage())) return;
+            t = t.getCause();
+        }
+        fail("Error '" + errMsg + "' not thrown");
+    }
+
private static final class MockEmbedder implements Embedder {
private final String expectedText;
private final Language expectedLanguage;
- private final String expectedDestination;
private final Tensor tensorToReturn;
public MockEmbedder(String expectedText,
Language expectedLanguage,
- String expectedDestination,
Tensor tensorToReturn) {
this.expectedText = expectedText;
this.expectedLanguage = expectedLanguage;
- this.expectedDestination = expectedDestination;
this.tensorToReturn = tensorToReturn;
}
@@ -756,7 +790,6 @@ public class QueryProfileTypeTestCase {
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
assertEquals(expectedText, text);
assertEquals(expectedLanguage, context.getLanguage());
- assertEquals(expectedDestination, context.getDestination());
assertEquals(tensorToReturn.type(), tensorType);
return tensorToReturn;
}