diff options
author | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:42 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:42 +0100 |
commit | 8a64a50ac9f0cbea18f0c1a8e1ef482d3311e873 (patch) | |
tree | 19853dccb4c68714885e7cb73d32e1f191ef306a | |
parent | c5e464f1a6da3a74113d775805187a547074a2da (diff) |
Add embedder selection argument to query parameter transformation
8 files changed, 215 insertions, 72 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 8ac5aaa127d..436a8ca18e4 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -1868,7 +1868,9 @@ "public com.yahoo.search.Query$Builder setQueryProfile(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)", "public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()", "public com.yahoo.search.Query$Builder setEmbedder(com.yahoo.language.process.Embedder)", + "public com.yahoo.search.Query$Builder setEmbedders(java.util.Map)", "public com.yahoo.language.process.Embedder getEmbedder()", + "public java.util.Map getEmbedders()", "public com.yahoo.search.Query$Builder setZoneInfo(ai.vespa.cloud.ZoneInfo)", "public ai.vespa.cloud.ZoneInfo getZoneInfo()", "public com.yahoo.search.Query build()" @@ -4351,7 +4353,7 @@ "public" ], "methods": [ - "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)", + "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.component.provider.ComponentRegistry, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)", "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)", "public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)", "public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.container.logging.AccessLog, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.search.searchchain.ExecutionFactory)", @@ -5990,6 +5992,7 @@ "methods": [ "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)", "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, com.yahoo.language.process.Embedder)", + "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, java.util.Map)", "public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()", "public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)", "public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)", @@ -6365,6 +6368,7 @@ ], "methods": [ "public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder, java.util.Map)", + "public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map, java.util.Map)", "public java.lang.String destination()", "public static com.yahoo.search.query.profile.types.ConversionContext empty()" ], @@ -6639,6 +6643,7 @@ ], "methods": [ "public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder)", + "public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map)", "public void setParentQuery(com.yahoo.search.Query)", "public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)", "public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)", diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index 83fa18d847f..b7a0fcb5dc3 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -337,7 +337,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile) { super(new QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, Embedder.throwsOnUse, ZoneInfo.defaultInfo()); + init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo()); } // TODO: Deprecate most constructors above here @@ -346,31 +346,31 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { this(builder.getRequest(), builder.getRequestMap(), builder.getQueryProfile(), - builder.getEmbedder(), + builder.getEmbedders(), builder.getZoneInfo()); } - private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Embedder embedder, + private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders, ZoneInfo zoneInfo) { super(new QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, embedder, zoneInfo); + init(requestMap, queryProfile, embedders, zoneInfo); } private void init(Map<String, String> requestMap, CompiledQueryProfile queryProfile, - Embedder embedder, + Map<String, Embedder> embedders, ZoneInfo zoneInfo) { startTime = httpRequest.getJDiscRequest().creationTime(TimeUnit.MILLISECONDS); if (queryProfile != null) { // Move all request parameters to the query profile just to validate that the parameter settings are legal - Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedder); + Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedders); properties().chain(queryProfileProperties); // TODO: Just checking legality rather than actually setting would be faster setPropertiesFromRequestMap(requestMap, properties(), true); // Adds errors to the query for illegal set attempts // Create the full chain - properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedder)). + properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedders)). chain(new ModelObjectMap()). chain(new RequestContextProperties(requestMap, zoneInfo)). chain(queryProfileProperties). @@ -389,7 +389,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { } else { // bypass these complications if there is no query profile to get values from and validate against properties(). - chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedder)). + chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedders)). chain(new PropertyMap()). chain(new DefaultProperties()); setPropertiesFromRequestMap(requestMap, properties(), false); @@ -1131,7 +1131,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { private HttpRequest request = null; private Map<String, String> requestMap = null; private CompiledQueryProfile queryProfile = null; - private Embedder embedder = Embedder.throwsOnUse; + private Map<String, Embedder> embedders = Embedder.throwsOnUse.asMap(); private ZoneInfo zoneInfo = ZoneInfo.defaultInfo(); public Builder setRequest(String query) { @@ -1171,11 +1171,22 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public CompiledQueryProfile getQueryProfile() { return queryProfile; } public Builder setEmbedder(Embedder embedder) { - this.embedder = embedder; + return setEmbedders(Map.of(Embedder.defaultEmbedderName, embedder)); + } + + public Builder setEmbedders(Map<String, Embedder> embedders) { + this.embedders = embedders; return this; } - public Embedder getEmbedder() { return embedder; } + public Embedder getEmbedder() { + if (embedders.size() != 1) { + throw new IllegalArgumentException("Attempt to get single embedder but multiple exists."); + } + return embedders.entrySet().stream().findFirst().get().getValue(); + } + + public Map<String, Embedder> getEmbedders() { return embedders; } public Builder setZoneInfo(ZoneInfo zoneInfo) { this.zoneInfo = zoneInfo; diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java index b65953935f0..86ff34f659b 100644 --- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java +++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java @@ -3,6 +3,7 @@ package com.yahoo.search.handler; import com.google.inject.Inject; import com.yahoo.collections.Tuple2; +import com.yahoo.component.ComponentId; import com.yahoo.component.ComponentSpecification; import com.yahoo.component.Vtag; import com.yahoo.component.chain.Chain; @@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.Request; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.provider.DefaultEmbedderProvider; import com.yahoo.net.HostName; import com.yahoo.net.UriTools; import com.yahoo.prelude.query.parser.ParseException; @@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Handles search request. @@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler { private final Optional<String> hostResponseHeaderKey; private final String selfHostname = HostName.getLocalhost(); - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final ExecutionFactory executionFactory; private final AtomicLong numRequestsLeftToTrace; @@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler { ContainerThreadPool threadpool, CompiledQueryProfileRegistry queryProfileRegistry, ContainerHttpConfig config, - Embedder embedder, + ComponentRegistry<Embedder> embedders, ExecutionFactory executionFactory, ZoneInfo zoneInfo) { - this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory, + this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory, config.numQueriesToTraceOnDebugAfterConstruction(), config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()), zoneInfo); @@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler { CompiledQueryProfileRegistry queryProfileRegistry, ExecutionFactory executionFactory, Optional<String> hostResponseHeaderKey) { - this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse, + this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse), executionFactory, 0, hostResponseHeaderKey, ZoneInfo.defaultInfo()); } @@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler { long numQueriesToTraceOnDebugAfterStartup, Optional<String> hostResponseHeaderKey, ZoneInfo zoneInfo) { + this(metric, executor, queryProfileRegistry, toRegistry(embedder), + executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey, + ZoneInfo.defaultInfo()); + } + + private SearchHandler(Metric metric, + Executor executor, + CompiledQueryProfileRegistry queryProfileRegistry, + ComponentRegistry<Embedder> embedders, + ExecutionFactory executionFactory, + long numQueriesToTraceOnDebugAfterStartup, + Optional<String> hostResponseHeaderKey, + ZoneInfo zoneInfo) { super(executor, metric, true); + log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this)); this.queryProfileRegistry = queryProfileRegistry; - this.embedder = embedder; + this.embedders = toMap(embedders); this.executionFactory = executionFactory; this.maxThreads = examineExecutor(executor); @@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler { Query query = new Query.Builder().setRequest(request) .setRequestMap(requestMap) .setQueryProfile(queryProfile) - .setEmbedder(embedder) + .setEmbedders(embedders) .setZoneInfo(zoneInfo) .build(); @@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler { .build(); } + private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) { + var map = embedders.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + if (map.size() > 1) { + map.remove(DefaultEmbedderProvider.class.getName()); + // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. + } + return Collections.unmodifiableMap(map); + } + + private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) { + ComponentRegistry<Embedder> emb = new ComponentRegistry<>(); + emb.register(new ComponentId(Embedder.defaultEmbedderName), embedder); + return emb; + } + } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java index f58395fd5bb..6e778a0fac6 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java @@ -30,7 +30,7 @@ import java.util.Map; public class QueryProfileProperties extends Properties { private final CompiledQueryProfile profile; - private final Embedder embedder; + private final Map<String, Embedder> embedders; // Note: The priority order is: values has precedence over references @@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties { private List<Pair<CompoundName, CompiledQueryProfile>> references = null; public QueryProfileProperties(CompiledQueryProfile profile) { - this(profile, Embedder.throwsOnUse); + this(profile, Embedder.throwsOnUse.asMap()); } - /** Creates an instance from a profile, throws an exception if the given profile is null */ public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) { + this(profile, Map.of(Embedder.defaultEmbedderName, embedder)); + } + + /** Creates an instance from a profile, throws an exception if the given profile is null */ + public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) { Validator.ensureNotNull("The profile wrapped by this cannot be null", profile); this.profile = profile; - this.embedder = embedder; + this.embedders = embedders; } /** Returns the query profile backing this, or null if none */ @@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties { if (fieldDescription != null) { if (i == name.size() - 1) { // at the end of the path, check the assignment type - var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context); + var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context); var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext); if (convertedValue == null && fieldDescription.getType() instanceof QueryProfileFieldType diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index 8dfb67a9d5f..1fc405051ac 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -14,14 +14,20 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, Map<String, String> context) { + this(destination, registry, Map.of(Embedder.defaultEmbedderName, embedder), context); + } + + public ConversionContext(String destination, CompiledQueryProfileRegistry registry, + Map<String, Embedder> embedders, + Map<String, String> context) { this.destination = destination; this.registry = registry; - this.embedder = embedder; + this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; } @@ -33,14 +39,14 @@ public class ConversionContext { CompiledQueryProfileRegistry registry() {return registry;} /** Returns the configured embedder, never null */ - Embedder embedder() { return embedder; } + Map<String, Embedder> embedders() { return embedders; } /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } /** Returns an empty context */ public static ConversionContext empty() { - return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of()); + return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index d6676db3774..6f1cfccc16b 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + /** * A tensor field type in a query profile * @@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType; */ public class TensorFieldType extends FieldType { + private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private final TensorType type; /** Creates a tensor field type with information about the kind of tensor this will hold */ @@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType { private Tensor encode(String s, ConversionContext context) { if ( ! s.endsWith(")")) throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); - String text = s.substring("embed(".length(), s.length() - 1); - return context.embedder().embed(text, toEmbedderContext(context), type); + String argument = s.substring("embed(".length(), s.length() - 1); + Embedder embedder; + + // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") + Matcher matcher = embedderArgumentRegexp.matcher(argument); + if (matcher.matches()) { + String embedderId = matcher.group(1); + argument = matcher.group(2); + if (!context.embedders().containsKey(embedderId)) { + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } + embedder = context.embedders().get(embedderId); + } else if (context.embedders().size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } else if (context.embedders().size() > 1) { + throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } else { + embedder = context.embedders().entrySet().stream().findFirst().get().getValue(); + } + + return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type); + } + + private static String removeQuotes(String s) { + if (s.startsWith("'") && s.endsWith("'")) { + return s.substring(1, s.length() - 1); + } + if (s.startsWith("\"") && s.endsWith("\"")) { + return s.substring(1, s.length() - 1); + } + return s; + } + + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); } private Embedder.Context toEmbedderContext(ConversionContext context) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 243915662d2..dc901589cde 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -34,12 +34,16 @@ public class QueryProperties extends Properties { private Query query; private final CompiledQueryProfileRegistry profileRegistry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) { + this(query, profileRegistry, Map.of(Embedder.defaultEmbedderName, embedder)); + } + + public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) { this.query = query; this.profileRegistry = profileRegistry; - this.embedder = embedder; + this.embedders = embedders; } public void setParentQuery(Query query) { @@ -394,7 +398,7 @@ public class QueryProperties extends Properties { if (type == null) return value; // no type info -> keep as string FieldDescription field = type.getField(key); if (field == null) return value; // ditto - return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context)); + return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context)); } private void throwIllegalParameter(String key,String namespace) { diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index 2e88c9fd0a4..a1556aac189 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -25,9 +25,11 @@ import org.junit.Test; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -441,42 +443,42 @@ public class QueryProfileTypeTestCase { @Test public void testUnembeddedTensorRankFeatureInRequest() { - QueryProfile profile = new QueryProfile("test"); - profile.setType(testtype); - registry.register(profile); - - CompiledQueryProfileRegistry cRegistry = registry.compile(); - String textToEmbed = "text to embed into a tensor"; - String destinationFeature = "query(myTensor4)"; - Tensor expectedTensor = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); - Query query1 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) + - "=" + urlEncode("embed(" + textToEmbed + ")"), - com.yahoo.jdisc.http.HttpRequest.Method.GET)) - .setQueryProfile(cRegistry.getComponent("test")) - .setEmbedder(new MockEmbedder(textToEmbed, - Language.UNKNOWN, - destinationFeature, - expectedTensor)) - .build(); - assertEquals(0, query1.errors().size()); - assertEquals(expectedTensor, query1.properties().get("ranking.features.query(myTensor4)")); - assertEquals(expectedTensor, query1.getRanking().getFeatures().getTensor("query(myTensor4)").get()); - - // Explicit language - Query query2 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) + - "=" + urlEncode("embed(" + textToEmbed + ")") + - "&language=en", - com.yahoo.jdisc.http.HttpRequest.Method.GET)) - .setQueryProfile(cRegistry.getComponent("test")) - .setEmbedder(new MockEmbedder(textToEmbed, - Language.ENGLISH, - destinationFeature, - expectedTensor)) - .build(); - assertEquals(0, query2.errors().size()); - assertEquals(expectedTensor, query2.properties().get("ranking.features.query(myTensor4)")); - assertEquals(expectedTensor, query2.getRanking().getFeatures().getTensor("query(myTensor4)").get()); - + String text = "text to embed into a tensor"; + Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); + Tensor embedding2 = Tensor.from("tensor<float>(x[5]):[1,2,3,4,0]]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(" + text + ")", embedding1, embedders); + assertEmbedQuery("embed('" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(\"" + text + "\")", embedding1, embedders); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(emb1, \"" + text + "\")", embedding1, embedders); + assertEmbedQueryFails("embed(emb2, \"" + text + "\")", embedding1, embedders, + "Can't find embedder 'emb2'. Valid embedders are emb1"); + + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1), + "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2) + ); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders); + assertEmbedQueryFails("embed(emb3, \"" + text + "\")", embedding1, embedders, + "Can't find embedder 'emb3'. Valid embedders are emb1,emb2"); + + // And with specified language + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1) + ); + assertEmbedQuery("embed(" + text + ")", embedding1, embedders, Language.ENGLISH.languageCode()); + + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1), + "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2) + ); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders, Language.ENGLISH.languageCode()); + assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode()); } private String urlEncode(String s) { @@ -729,20 +731,52 @@ public class QueryProfileTypeTestCase { } } + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) { + assertEmbedQuery(embed, expected, embedders, null); + } + + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) { + QueryProfile profile = new QueryProfile("test"); + profile.setType(testtype); + registry.register(profile); + CompiledQueryProfileRegistry cRegistry = registry.compile(); + + String languageParam = language == null ? "" : "&language=" + language; + String destination = "query(myTensor4)"; + + Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest( + "?" + urlEncode("ranking.features." + destination) + + "=" + urlEncode(embed) + + languageParam, + com.yahoo.jdisc.http.HttpRequest.Method.GET)) + .setQueryProfile(cRegistry.getComponent("test")) + .setEmbedders(embedders) + .build(); + assertEquals(0, query.errors().size()); + assertEquals(expected, query.properties().get("ranking.features." + destination)); + assertEquals(expected, query.getRanking().getFeatures().getTensor(destination).get()); + } + + private void assertEmbedQueryFails(String embed, Tensor expected, Map<String, Embedder> embedders, String errMsg) { + Throwable t = assertThrows(IllegalArgumentException.class, () -> assertEmbedQuery(embed, expected, embedders)); + while (t != null) { + if (t.getMessage().equals(errMsg)) return; + t = t.getCause(); + } + fail("Error '" + errMsg + "' not thrown"); + } + private static final class MockEmbedder implements Embedder { private final String expectedText; private final Language expectedLanguage; - private final String expectedDestination; private final Tensor tensorToReturn; public MockEmbedder(String expectedText, Language expectedLanguage, - String expectedDestination, Tensor tensorToReturn) { this.expectedText = expectedText; this.expectedLanguage = expectedLanguage; - this.expectedDestination = expectedDestination; this.tensorToReturn = tensorToReturn; } @@ -756,7 +790,6 @@ public class QueryProfileTypeTestCase { public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedText, text); assertEquals(expectedLanguage, context.getLanguage()); - assertEquals(expectedDestination, context.getDestination()); assertEquals(tensorToReturn.type(), tensorType); return tensorToReturn; } |