diff options
29 files changed, 432 insertions, 156 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java index 49ae00d0663..256c628a1cb 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java @@ -415,14 +415,14 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, return wasConfiguredToDoAttributing; } - /** Parse an indexing expression which will use the simple linguistics implementatino suitable for testing */ + /** Parse an indexing expression which will use the simple linguistics implementation suitable for testing */ public void parseIndexingScript(String script) { - parseIndexingScript(script, new SimpleLinguistics(), Embedder.throwsOnUse); + parseIndexingScript(script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public void parseIndexingScript(String script, Linguistics linguistics, Embedder embedder) { + public void parseIndexingScript(String script, Linguistics linguistics, Map<String, Embedder> embedders) { try { - ScriptParserContext config = new ScriptParserContext(linguistics, embedder); + ScriptParserContext config = new ScriptParserContext(linguistics, embedders); config.setInputStream(new IndexingInput(script)); setIndexingScript(ScriptExpression.newInstance(config)); } catch (ParseException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java index a5f5f961ab5..cdd3cc386a4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java @@ -13,6 +13,8 @@ import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.yolean.Exceptions; +import java.util.Map; + /** * @author Einar M R Rosenvinge */ @@ -32,13 +34,13 @@ public class IndexingOperation implements FieldOperation { /** Creates an indexing operation which will use the simple linguistics implementation suitable for testing */ public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine) throws ParseException { - return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine, - Linguistics linguistics, Embedder embedder) + Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { - ScriptParserContext config = new ScriptParserContext(linguistics, embedder); + ScriptParserContext config = new ScriptParserContext(linguistics, embedders); config.setAnnotatorConfig(new AnnotatorConfig()); config.setInputStream(input); ScriptExpression exp; diff --git a/config-model/src/main/javacc/IntermediateParser.jj b/config-model/src/main/javacc/IntermediateParser.jj index ba955f071b2..8a4798d6f74 100644 --- a/config-model/src/main/javacc/IntermediateParser.jj +++ b/config-model/src/main/javacc/IntermediateParser.jj @@ -81,7 +81,7 @@ public class IntermediateParser { */ @SuppressWarnings("deprecation") private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException { - return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse); + return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } /** @@ -90,13 +90,13 @@ public class IntermediateParser { * @param multiline Whether or not to allow multi-line expressions. * @param linguistics What to use for tokenizing. */ - private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Embedder embedder) throws ParseException { + private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { SimpleCharStream input = (SimpleCharStream)token_source.input_stream; if (token.next != null) { input.backup(token.next.image.length()); } try { - return IndexingOperation.fromStream(input, multiline, linguistics, embedder); + return IndexingOperation.fromStream(input, multiline, linguistics, embedders); } finally { token.next = null; jj_ntk = -1; diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj index ab0cdefc355..aeffe6e5c39 100644 --- a/config-model/src/main/javacc/SDParser.jj +++ b/config-model/src/main/javacc/SDParser.jj @@ -112,7 +112,7 @@ public class SDParser { */ @SuppressWarnings("deprecation") private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException { - return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse); + return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } /** @@ -121,13 +121,13 @@ public class SDParser { * @param multiline Whether or not to allow multi-line expressions. * @param linguistics What to use for tokenizing. */ - private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Embedder embedder) throws ParseException { + private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { SimpleCharStream input = (SimpleCharStream)token_source.input_stream; if (token.next != null) { input.backup(token.next.image.length()); } try { - return IndexingOperation.fromStream(input, multiline, linguistics, embedder); + return IndexingOperation.fromStream(input, multiline, linguistics, embedders); } finally { token.next = null; jj_ntk = -1; diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 5b990ca8758..b7aa1a8d0ef 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -1868,7 +1868,9 @@ "public com.yahoo.search.Query$Builder setQueryProfile(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)", "public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()", "public com.yahoo.search.Query$Builder setEmbedder(com.yahoo.language.process.Embedder)", + "public com.yahoo.search.Query$Builder setEmbedders(java.util.Map)", "public com.yahoo.language.process.Embedder getEmbedder()", + "public java.util.Map getEmbedders()", "public com.yahoo.search.Query$Builder setZoneInfo(ai.vespa.cloud.ZoneInfo)", "public ai.vespa.cloud.ZoneInfo getZoneInfo()", "public com.yahoo.search.Query build()" @@ -4354,7 +4356,7 @@ "public" ], "methods": [ - "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)", + "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.component.provider.ComponentRegistry, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)", "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)", "public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)", "public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.container.logging.AccessLog, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.search.searchchain.ExecutionFactory)", @@ -5993,6 +5995,7 @@ "methods": [ "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)", "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, com.yahoo.language.process.Embedder)", + "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, java.util.Map)", "public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()", "public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)", "public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)", @@ -6368,6 +6371,7 @@ ], "methods": [ "public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder, java.util.Map)", + "public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map, java.util.Map)", "public java.lang.String destination()", "public static com.yahoo.search.query.profile.types.ConversionContext empty()" ], @@ -6642,6 +6646,7 @@ ], "methods": [ "public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder)", + "public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map)", "public void setParentQuery(com.yahoo.search.Query)", "public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)", "public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)", diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index 83fa18d847f..553c73dac17 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -337,7 +337,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile) { super(new QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, Embedder.throwsOnUse, ZoneInfo.defaultInfo()); + init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo()); } // TODO: Deprecate most constructors above here @@ -346,31 +346,31 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { this(builder.getRequest(), builder.getRequestMap(), builder.getQueryProfile(), - builder.getEmbedder(), + builder.getEmbedders(), builder.getZoneInfo()); } - private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Embedder embedder, + private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders, ZoneInfo zoneInfo) { super(new QueryPropertyAliases(propertyAliases)); this.httpRequest = request; - init(requestMap, queryProfile, embedder, zoneInfo); + init(requestMap, queryProfile, embedders, zoneInfo); } private void init(Map<String, String> requestMap, CompiledQueryProfile queryProfile, - Embedder embedder, + Map<String, Embedder> embedders, ZoneInfo zoneInfo) { startTime = httpRequest.getJDiscRequest().creationTime(TimeUnit.MILLISECONDS); if (queryProfile != null) { // Move all request parameters to the query profile just to validate that the parameter settings are legal - Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedder); + Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedders); properties().chain(queryProfileProperties); // TODO: Just checking legality rather than actually setting would be faster setPropertiesFromRequestMap(requestMap, properties(), true); // Adds errors to the query for illegal set attempts // Create the full chain - properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedder)). + properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedders)). chain(new ModelObjectMap()). chain(new RequestContextProperties(requestMap, zoneInfo)). chain(queryProfileProperties). @@ -389,7 +389,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { } else { // bypass these complications if there is no query profile to get values from and validate against properties(). - chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedder)). + chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedders)). chain(new PropertyMap()). chain(new DefaultProperties()); setPropertiesFromRequestMap(requestMap, properties(), false); @@ -1131,7 +1131,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { private HttpRequest request = null; private Map<String, String> requestMap = null; private CompiledQueryProfile queryProfile = null; - private Embedder embedder = Embedder.throwsOnUse; + private Map<String, Embedder> embedders = Embedder.throwsOnUse.asMap(); private ZoneInfo zoneInfo = ZoneInfo.defaultInfo(); public Builder setRequest(String query) { @@ -1171,11 +1171,22 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { public CompiledQueryProfile getQueryProfile() { return queryProfile; } public Builder setEmbedder(Embedder embedder) { - this.embedder = embedder; + return setEmbedders(Map.of(Embedder.defaultEmbedderId, embedder)); + } + + public Builder setEmbedders(Map<String, Embedder> embedders) { + this.embedders = embedders; return this; } - public Embedder getEmbedder() { return embedder; } + public Embedder getEmbedder() { + if (embedders.size() != 1) { + throw new IllegalArgumentException("Attempt to get single embedder but multiple exists."); + } + return embedders.entrySet().stream().findFirst().get().getValue(); + } + + public Map<String, Embedder> getEmbedders() { return embedders; } public Builder setZoneInfo(ZoneInfo zoneInfo) { this.zoneInfo = zoneInfo; diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java index b65953935f0..af6374ba245 100644 --- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java +++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java @@ -3,6 +3,7 @@ package com.yahoo.search.handler; import com.google.inject.Inject; import com.yahoo.collections.Tuple2; +import com.yahoo.component.ComponentId; import com.yahoo.component.ComponentSpecification; import com.yahoo.component.Vtag; import com.yahoo.component.chain.Chain; @@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.Request; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.provider.DefaultEmbedderProvider; import com.yahoo.net.HostName; import com.yahoo.net.UriTools; import com.yahoo.prelude.query.parser.ParseException; @@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Handles search request. @@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler { private final Optional<String> hostResponseHeaderKey; private final String selfHostname = HostName.getLocalhost(); - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final ExecutionFactory executionFactory; private final AtomicLong numRequestsLeftToTrace; @@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler { ContainerThreadPool threadpool, CompiledQueryProfileRegistry queryProfileRegistry, ContainerHttpConfig config, - Embedder embedder, + ComponentRegistry<Embedder> embedders, ExecutionFactory executionFactory, ZoneInfo zoneInfo) { - this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory, + this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory, config.numQueriesToTraceOnDebugAfterConstruction(), config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()), zoneInfo); @@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler { CompiledQueryProfileRegistry queryProfileRegistry, ExecutionFactory executionFactory, Optional<String> hostResponseHeaderKey) { - this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse, + this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse), executionFactory, 0, hostResponseHeaderKey, ZoneInfo.defaultInfo()); } @@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler { long numQueriesToTraceOnDebugAfterStartup, Optional<String> hostResponseHeaderKey, ZoneInfo zoneInfo) { + this(metric, executor, queryProfileRegistry, toRegistry(embedder), + executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey, + ZoneInfo.defaultInfo()); + } + + private SearchHandler(Metric metric, + Executor executor, + CompiledQueryProfileRegistry queryProfileRegistry, + ComponentRegistry<Embedder> embedders, + ExecutionFactory executionFactory, + long numQueriesToTraceOnDebugAfterStartup, + Optional<String> hostResponseHeaderKey, + ZoneInfo zoneInfo) { super(executor, metric, true); + log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this)); this.queryProfileRegistry = queryProfileRegistry; - this.embedder = embedder; + this.embedders = toMap(embedders); this.executionFactory = executionFactory; this.maxThreads = examineExecutor(executor); @@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler { Query query = new Query.Builder().setRequest(request) .setRequestMap(requestMap) .setQueryProfile(queryProfile) - .setEmbedder(embedder) + .setEmbedders(embedders) .setZoneInfo(zoneInfo) .build(); @@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler { .build(); } + private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) { + var map = embedders.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + if (map.size() > 1) { + map.remove(DefaultEmbedderProvider.class.getName()); + // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. + } + return Collections.unmodifiableMap(map); + } + + private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) { + ComponentRegistry<Embedder> emb = new ComponentRegistry<>(); + emb.register(new ComponentId(Embedder.defaultEmbedderId), embedder); + return emb; + } + } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java index f58395fd5bb..5b3758f103d 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java @@ -30,7 +30,7 @@ import java.util.Map; public class QueryProfileProperties extends Properties { private final CompiledQueryProfile profile; - private final Embedder embedder; + private final Map<String, Embedder> embedders; // Note: The priority order is: values has precedence over references @@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties { private List<Pair<CompoundName, CompiledQueryProfile>> references = null; public QueryProfileProperties(CompiledQueryProfile profile) { - this(profile, Embedder.throwsOnUse); + this(profile, Embedder.throwsOnUse.asMap()); } - /** Creates an instance from a profile, throws an exception if the given profile is null */ public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) { + this(profile, Map.of(Embedder.defaultEmbedderId, embedder)); + } + + /** Creates an instance from a profile, throws an exception if the given profile is null */ + public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) { Validator.ensureNotNull("The profile wrapped by this cannot be null", profile); this.profile = profile; - this.embedder = embedder; + this.embedders = embedders; } /** Returns the query profile backing this, or null if none */ @@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties { if (fieldDescription != null) { if (i == name.size() - 1) { // at the end of the path, check the assignment type - var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context); + var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context); var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext); if (convertedValue == null && fieldDescription.getType() instanceof QueryProfileFieldType diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index 8dfb67a9d5f..5c449b9645a 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -14,14 +14,20 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, Map<String, String> context) { + this(destination, registry, Map.of(Embedder.defaultEmbedderId, embedder), context); + } + + public ConversionContext(String destination, CompiledQueryProfileRegistry registry, + Map<String, Embedder> embedders, + Map<String, String> context) { this.destination = destination; this.registry = registry; - this.embedder = embedder; + this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; } @@ -33,14 +39,14 @@ public class ConversionContext { CompiledQueryProfileRegistry registry() {return registry;} /** Returns the configured embedder, never null */ - Embedder embedder() { return embedder; } + Map<String, Embedder> embedders() { return embedders; } /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } /** Returns an empty context */ public static ConversionContext empty() { - return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of()); + return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index d6676db3774..6f1cfccc16b 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + /** * A tensor field type in a query profile * @@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType; */ public class TensorFieldType extends FieldType { + private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private final TensorType type; /** Creates a tensor field type with information about the kind of tensor this will hold */ @@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType { private Tensor encode(String s, ConversionContext context) { if ( ! s.endsWith(")")) throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); - String text = s.substring("embed(".length(), s.length() - 1); - return context.embedder().embed(text, toEmbedderContext(context), type); + String argument = s.substring("embed(".length(), s.length() - 1); + Embedder embedder; + + // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") + Matcher matcher = embedderArgumentRegexp.matcher(argument); + if (matcher.matches()) { + String embedderId = matcher.group(1); + argument = matcher.group(2); + if (!context.embedders().containsKey(embedderId)) { + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } + embedder = context.embedders().get(embedderId); + } else if (context.embedders().size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } else if (context.embedders().size() > 1) { + throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } else { + embedder = context.embedders().entrySet().stream().findFirst().get().getValue(); + } + + return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type); + } + + private static String removeQuotes(String s) { + if (s.startsWith("'") && s.endsWith("'")) { + return s.substring(1, s.length() - 1); + } + if (s.startsWith("\"") && s.endsWith("\"")) { + return s.substring(1, s.length() - 1); + } + return s; + } + + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); } private Embedder.Context toEmbedderContext(ConversionContext context) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 243915662d2..98b65c6edd9 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -34,12 +34,16 @@ public class QueryProperties extends Properties { private Query query; private final CompiledQueryProfileRegistry profileRegistry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) { + this(query, profileRegistry, Map.of(Embedder.defaultEmbedderId, embedder)); + } + + public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) { this.query = query; this.profileRegistry = profileRegistry; - this.embedder = embedder; + this.embedders = embedders; } public void setParentQuery(Query query) { @@ -394,7 +398,7 @@ public class QueryProperties extends Properties { if (type == null) return value; // no type info -> keep as string FieldDescription field = type.getField(key); if (field == null) return value; // ditto - return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context)); + return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context)); } private void throwIllegalParameter(String key,String namespace) { diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index 2e88c9fd0a4..a1556aac189 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -25,9 +25,11 @@ import org.junit.Test; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -441,42 +443,42 @@ public class QueryProfileTypeTestCase { @Test public void testUnembeddedTensorRankFeatureInRequest() { - QueryProfile profile = new QueryProfile("test"); - profile.setType(testtype); - registry.register(profile); - - CompiledQueryProfileRegistry cRegistry = registry.compile(); - String textToEmbed = "text to embed into a tensor"; - String destinationFeature = "query(myTensor4)"; - Tensor expectedTensor = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); - Query query1 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) + - "=" + urlEncode("embed(" + textToEmbed + ")"), - com.yahoo.jdisc.http.HttpRequest.Method.GET)) - .setQueryProfile(cRegistry.getComponent("test")) - .setEmbedder(new MockEmbedder(textToEmbed, - Language.UNKNOWN, - destinationFeature, - expectedTensor)) - .build(); - assertEquals(0, query1.errors().size()); - assertEquals(expectedTensor, query1.properties().get("ranking.features.query(myTensor4)")); - assertEquals(expectedTensor, query1.getRanking().getFeatures().getTensor("query(myTensor4)").get()); - - // Explicit language - Query query2 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) + - "=" + urlEncode("embed(" + textToEmbed + ")") + - "&language=en", - com.yahoo.jdisc.http.HttpRequest.Method.GET)) - .setQueryProfile(cRegistry.getComponent("test")) - .setEmbedder(new MockEmbedder(textToEmbed, - Language.ENGLISH, - destinationFeature, - expectedTensor)) - .build(); - assertEquals(0, query2.errors().size()); - assertEquals(expectedTensor, query2.properties().get("ranking.features.query(myTensor4)")); - assertEquals(expectedTensor, query2.getRanking().getFeatures().getTensor("query(myTensor4)").get()); - + String text = "text to embed into a tensor"; + Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); + Tensor embedding2 = Tensor.from("tensor<float>(x[5]):[1,2,3,4,0]]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(" + text + ")", embedding1, embedders); + assertEmbedQuery("embed('" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(\"" + text + "\")", embedding1, embedders); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(emb1, \"" + text + "\")", embedding1, embedders); + assertEmbedQueryFails("embed(emb2, \"" + text + "\")", embedding1, embedders, + "Can't find embedder 'emb2'. Valid embedders are emb1"); + + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1), + "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2) + ); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders); + assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders); + assertEmbedQueryFails("embed(emb3, \"" + text + "\")", embedding1, embedders, + "Can't find embedder 'emb3'. Valid embedders are emb1,emb2"); + + // And with specified language + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1) + ); + assertEmbedQuery("embed(" + text + ")", embedding1, embedders, Language.ENGLISH.languageCode()); + + embedders = Map.of( + "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1), + "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2) + ); + assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders, Language.ENGLISH.languageCode()); + assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode()); } private String urlEncode(String s) { @@ -729,20 +731,52 @@ public class QueryProfileTypeTestCase { } } + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) { + assertEmbedQuery(embed, expected, embedders, null); + } + + private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) { + QueryProfile profile = new QueryProfile("test"); + profile.setType(testtype); + registry.register(profile); + CompiledQueryProfileRegistry cRegistry = registry.compile(); + + String languageParam = language == null ? "" : "&language=" + language; + String destination = "query(myTensor4)"; + + Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest( + "?" + urlEncode("ranking.features." + destination) + + "=" + urlEncode(embed) + + languageParam, + com.yahoo.jdisc.http.HttpRequest.Method.GET)) + .setQueryProfile(cRegistry.getComponent("test")) + .setEmbedders(embedders) + .build(); + assertEquals(0, query.errors().size()); + assertEquals(expected, query.properties().get("ranking.features." + destination)); + assertEquals(expected, query.getRanking().getFeatures().getTensor(destination).get()); + } + + private void assertEmbedQueryFails(String embed, Tensor expected, Map<String, Embedder> embedders, String errMsg) { + Throwable t = assertThrows(IllegalArgumentException.class, () -> assertEmbedQuery(embed, expected, embedders)); + while (t != null) { + if (t.getMessage().equals(errMsg)) return; + t = t.getCause(); + } + fail("Error '" + errMsg + "' not thrown"); + } + private static final class MockEmbedder implements Embedder { private final String expectedText; private final Language expectedLanguage; - private final String expectedDestination; private final Tensor tensorToReturn; public MockEmbedder(String expectedText, Language expectedLanguage, - String expectedDestination, Tensor tensorToReturn) { this.expectedText = expectedText; this.expectedLanguage = expectedLanguage; - this.expectedDestination = expectedDestination; this.tensorToReturn = tensorToReturn; } @@ -756,7 +790,6 @@ public class QueryProfileTypeTestCase { public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedText, text); assertEquals(expectedLanguage, context.getLanguage()); - assertEquals(expectedDestination, context.getDestination()); assertEquals(tensorToReturn.type(), tensorType); return tensorToReturn; } diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java index 7b553383daf..87c78445b13 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java @@ -7,6 +7,7 @@ import com.google.inject.Inject; import com.yahoo.component.chain.dependencies.After; import com.yahoo.component.chain.dependencies.Before; import com.yahoo.component.chain.dependencies.Provides; +import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.docproc.DocumentProcessor; import com.yahoo.docproc.Processing; import com.yahoo.document.Document; @@ -15,18 +16,20 @@ import com.yahoo.document.DocumentPut; import com.yahoo.document.DocumentRemove; import com.yahoo.document.DocumentType; import com.yahoo.document.DocumentTypeManager; -import com.yahoo.document.DocumentTypeManagerConfigurer; import com.yahoo.document.DocumentUpdate; -import com.yahoo.document.config.DocumentmanagerConfig; import com.yahoo.language.Linguistics; -import java.util.logging.Level; - import com.yahoo.language.process.Embedder; +import com.yahoo.language.provider.DefaultEmbedderProvider; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.indexinglanguage.AdapterFactory; import com.yahoo.vespa.indexinglanguage.SimpleAdapterFactory; import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import java.util.Map; +import java.util.logging.Level; +import java.util.stream.Collectors; + + /** * @author Simon Thoresen Hult */ @@ -55,9 +58,9 @@ public class IndexingProcessor extends DocumentProcessor { public IndexingProcessor(DocumentTypeManager documentTypeManager, IlscriptsConfig ilscriptsConfig, Linguistics linguistics, - Embedder embedder) { + ComponentRegistry<Embedder> embedders) { docTypeMgr = documentTypeManager; - scriptMgr = new ScriptManager(docTypeMgr, ilscriptsConfig, linguistics, embedder); + scriptMgr = new ScriptManager(docTypeMgr, ilscriptsConfig, linguistics, toMap(embedders)); adapterFactory = new SimpleAdapterFactory(new ExpressionSelector()); } @@ -128,4 +131,14 @@ public class IndexingProcessor extends DocumentProcessor { out.add(prev); } + private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) { + var map = embedders.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + if (map.size() > 1) { + map.remove(DefaultEmbedderProvider.class.getName()); + // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. + } + return map; + } + } diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java index 63c6d6c4bb5..de3a429e357 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java @@ -28,12 +28,12 @@ public class ScriptManager { private final Map<String, Map<String, DocumentScript>> documentFieldScripts; private final DocumentTypeManager docTypeMgr; - public ScriptManager(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, Embedder embedder) { + public ScriptManager(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, + Map<String, Embedder> embedders) { this.docTypeMgr = docTypeMgr; - documentFieldScripts = createScriptsMap(docTypeMgr, config, linguistics, embedder); + documentFieldScripts = createScriptsMap(docTypeMgr, config, linguistics, embedders); } - private Map<String, DocumentScript> getScripts(DocumentType inputType) { Map<String, DocumentScript> scripts = documentFieldScripts.get(inputType.getName()); if (scripts != null) { @@ -75,9 +75,9 @@ public class ScriptManager { private static Map<String, Map<String, DocumentScript>> createScriptsMap(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, - Embedder embedder) { + Map<String, Embedder> embedders) { Map<String, Map<String, DocumentScript>> documentFieldScripts = new HashMap<>(config.ilscript().size()); - ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedder); + ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders); parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences()); parserContext.getAnnotatorConfig().setMaxTokenLength(config.fieldmatchmaxlength()); diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java index 13f9ea1a8c8..76f4578ac87 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.docprocs.indexing; +import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.docproc.Processing; import com.yahoo.document.Document; @@ -128,6 +129,6 @@ public class IndexingProcessorTestCase { return new IndexingProcessor(new DocumentTypeManager(ConfigGetter.getConfig(DocumentmanagerConfig.class, configId)), ConfigGetter.getConfig(IlscriptsConfig.class, configId), new SimpleLinguistics(), - Embedder.throwsOnUse); + new ComponentRegistry<Embedder>()); } } diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java index a35dd0da4f3..4a7e643fb0a 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException; import org.junit.Test; import java.util.Iterator; +import java.util.Map; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -28,7 +29,7 @@ public class ScriptManagerTestCase { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newssummary") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newsarticle"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -42,7 +43,7 @@ public class ScriptManagerTestCase { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newsarticle") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newssummary"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -50,14 +51,14 @@ public class ScriptManagerTestCase { @Test public void requireThatEmptyConfigurationDoesNotThrow() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @Test public void requireThatUnknownDocumentTypeReturnsNull() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); for (Iterator<DocumentType> it = typeMgr.documentTypeIterator(); it.hasNext(); ) { assertNull(scriptMgr.getScript(it.next())); } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java index 11756ae0907..2b4e0db699b 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java @@ -62,7 +62,7 @@ public final class ScriptParser { parser.setAnnotatorConfig(context.getAnnotatorConfig()); parser.setDefaultFieldName(context.getDefaultFieldName()); parser.setLinguistics(context.getLinguistcs()); - parser.setEmbedder(context.getEmbedder()); + parser.setEmbedders(context.getEmbedders()); try { return method.call(parser); } catch (ParseException e) { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java index 91c24a10e27..9edbed68871 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java @@ -6,6 +6,9 @@ import com.yahoo.language.process.Embedder; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.vespa.indexinglanguage.parser.CharStream; +import java.util.Collections; +import java.util.Map; + /** * @author Simon Thoresen Hult */ @@ -13,13 +16,13 @@ public class ScriptParserContext { private AnnotatorConfig annotatorConfig = new AnnotatorConfig(); private Linguistics linguistics; - private final Embedder embedder; + private final Map<String, Embedder> embedders; private String defaultFieldName = null; private CharStream inputStream = null; - public ScriptParserContext(Linguistics linguistics, Embedder embedder) { + public ScriptParserContext(Linguistics linguistics, Map<String, Embedder> embedders) { this.linguistics = linguistics; - this.embedder = embedder; + this.embedders = embedders; } public AnnotatorConfig getAnnotatorConfig() { @@ -40,8 +43,8 @@ public class ScriptParserContext { return this; } - public Embedder getEmbedder() { - return embedder; + public Map<String, Embedder> getEmbedders() { + return Collections.unmodifiableMap(embedders); } public String getDefaultFieldName() { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 0da9d907718..2e4bb701454 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -12,6 +12,10 @@ import com.yahoo.language.process.Embedder; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + /** * Embeds a string in a tensor space using the configured Embedder component * @@ -20,6 +24,7 @@ import com.yahoo.tensor.TensorType; public class EmbedExpression extends Expression { private final Embedder embedder; + private final String embedderId; /** The destination the embedding will be written to on the form [schema name].[field name] */ private String destination; @@ -27,9 +32,28 @@ public class EmbedExpression extends Expression { /** The target type we are embedding into. */ private TensorType targetType; - public EmbedExpression(Embedder embedder) { + public EmbedExpression(Map<String, Embedder> embedders, String embedderId) { super(DataType.STRING); - this.embedder = embedder; + this.embedderId = embedderId; + + boolean embedderIdProvided = embedderId != null && embedderId.length() > 0; + + if (embedders.size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } + else if (embedders.size() > 1 && ! embedderIdProvided) { + this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " + + "Valid embedders are " + validEmbedders(embedders)); + } + else if (embedders.size() == 1 && ! embedderIdProvided) { + this.embedder = embedders.entrySet().stream().findFirst().get().getValue(); + } + else if ( ! embedders.containsKey(embedderId)) { + this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + } else { + this.embedder = embedders.get(embedderId); + } } @Override @@ -71,7 +95,14 @@ public class EmbedExpression extends Expression { } @Override - public String toString() { return "embed"; } + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("embed"); + if (this.embedderId != null && this.embedderId.length() > 0) { + sb.append(" ").append(this.embedderId); + } + return sb.toString(); + } @Override public int hashCode() { return 1; } @@ -79,4 +110,11 @@ public class EmbedExpression extends Expression { @Override public boolean equals(Object o) { return o instanceof EmbedExpression; } + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); + } + } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java index a5b62c73997..e5bf4711ad1 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java @@ -15,6 +15,8 @@ import com.yahoo.vespa.indexinglanguage.parser.IndexingInput; import com.yahoo.vespa.indexinglanguage.parser.ParseException; import com.yahoo.vespa.objects.Selectable; +import java.util.Map; + /** * @author Simon Thoresen Hult */ @@ -191,11 +193,11 @@ public abstract class Expression extends Selectable { /** Creates an expression with simple lingustics for testing */ public static Expression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static Expression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static Expression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static Expression newInstance(ScriptParserContext context) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java index d8e9cc4d923..c8e45f0f61a 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java @@ -15,6 +15,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; +import java.util.Map; /** * @author Simon Thoresen Hult @@ -92,11 +93,11 @@ public final class ScriptExpression extends ExpressionList<StatementExpression> /** Creates an expression with simple lingustics for testing */ @SuppressWarnings("deprecation") public static ScriptExpression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static ScriptExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static ScriptExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static ScriptExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java index 40aa0f58413..38157531ba2 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java @@ -14,6 +14,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Map; /** * @author Simon Thoresen Hult @@ -99,11 +100,11 @@ public final class StatementExpression extends ExpressionList<Expression> { /** Creates an expression with simple lingustics for testing */ public static StatementExpression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static StatementExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static StatementExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static StatementExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index e6b21f7c07b..51bb9be1f8a 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -45,7 +45,7 @@ public class IndexingParser { private String defaultFieldName; private Linguistics linguistics; - private Embedder embedder; + private Map<String, Embedder> embedders; private AnnotatorConfig annotatorCfg; public IndexingParser(String str) { @@ -62,8 +62,8 @@ public class IndexingParser { return this; } - public IndexingParser setEmbedder(Embedder embedder) { - this.embedder = embedder; + public IndexingParser setEmbedders(Map<String, Embedder> embedders) { + this.embedders = embedders; return this; } @@ -367,10 +367,13 @@ Expression echoExp() : { } { return new EchoExpression(); } } -Expression embedExp() : { } +Expression embedExp() : { - ( <EMBED> ) - { return new EmbedExpression(embedder); } + String val = ""; +} +{ + ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] ) + { return new EmbedExpression(embedders, val); } } Expression exactExp() : { } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java index 87c54fd7abd..28da9a71aac 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java @@ -96,7 +96,7 @@ public class ScriptParserTestCase { } private static ScriptParserContext newContext(String input) { - return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse).setInputStream(new IndexingInput(input)); + return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse.asMap()).setInputStream(new IndexingInput(input)); } } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 27723c6649d..de31f6fcb1e 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -21,6 +21,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException; import org.junit.Test; import java.util.List; +import java.util.Map; import static org.junit.Assert.*; @@ -175,37 +176,66 @@ public class ScriptTestCase { @Test public void testEmbed() throws ParseException { - TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); - var expression = Expression.fromString("input myText | embed | attribute 'myTensor'", - new SimpleLinguistics(), - new MockEmbedder("myDocument.myTensor")); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myText", DataType.STRING)); - var tensorField = new Field("myTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - adapter.setValue("myText", new StringFieldValue("input text")); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + // Test parsing without knowledge of any embedders + String exp = "input myText | embed emb1 | attribute 'myTensor'"; + Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + + Map<String, Embedder> embedder = Map.of( + "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]") + ); + testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "[1,2,0,0]"); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, "[1,2,0,0]"); + testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, "[1,2,0,0]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]"), + "emb2", new MockEmbedder("myDocument.myTensor", "[3,4,5,0]") + ); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "[1,2,0,0]"); + testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, "[3,4,5,0]"); + + assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "[3,4,5,0]"), + "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1,emb2"); + assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "[3,4,5,0]"), + "Can't find embedder 'emb3'. Valid embedders are emb1,emb2"); + } - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(new StringFieldValue("input text")); - expression.execute(context); - assertTrue(adapter.values.containsKey("myTensor")); - assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), - ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); + private void testEmbedStatement(String exp, Map<String, Embedder> embedders, String expected) { + try { + var expression = Expression.fromString(exp, new SimpleLinguistics(), embedders); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var tensorField = new Field("myTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + adapter.setValue("myText", new StringFieldValue("input text")); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(new StringFieldValue("input text")); + expression.execute(context); + assertTrue(adapter.values.containsKey("myTensor")); + assertEquals(Tensor.from(tensorType, expected), + ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); + } catch (ParseException e) { + throw new IllegalArgumentException(e); + } } @SuppressWarnings("unchecked") @Test public void testArrayEmbed() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray", "[7,3,0,0]")); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", new SimpleLinguistics(), - new MockEmbedder("myDocument.myTensorArray")); + embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -235,9 +265,11 @@ public class ScriptTestCase { private static class MockEmbedder implements Embedder { private final String expectedDestination; + private final String tensorString; - public MockEmbedder(String expectedDestination) { + public MockEmbedder(String expectedDestination, String tensorString) { this.expectedDestination = expectedDestination; + this.tensorString = tensorString; } @Override @@ -248,9 +280,18 @@ public class ScriptTestCase { @Override public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedDestination, context.getDestination()); - return Tensor.from(tensorType, "[7,3,0,0]"); + return Tensor.from(tensorType, tensorString); } } + private void assertThrows(Runnable r, String msg) { + try { + r.run(); + fail(); + } catch (IllegalStateException e) { + assertEquals(e.getMessage(), msg); + } + } + } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java index f6aa7e477a8..89170027c73 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java @@ -19,7 +19,7 @@ public class DefaultFieldNameTestCase { public void requireThatDefaultFieldNameIsAppliedWhenArgumentIsMissing() throws ParseException { IndexingInput input = new IndexingInput("input"); InputExpression exp = (InputExpression)Expression.newInstance(new ScriptParserContext(new SimpleLinguistics(), - Embedder.throwsOnUse) + Embedder.throwsOnUse.asMap()) .setInputStream(input) .setDefaultFieldName("foo")); assertEquals("foo", exp.getFieldName()); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java index e333eea7001..7db026d43ee 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java @@ -85,9 +85,9 @@ public class ExpressionTestCase { private static void assertExpression(Class expectedClass, String str) throws ParseException { Linguistics linguistics = new SimpleLinguistics(); - Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse); + Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse.asMap()); assertEquals(expectedClass, foo.getClass()); - Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse); + Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse.asMap()); assertEquals(foo.hashCode(), bar.hashCode()); assertEquals(foo, bar); } diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 910056286ec..c3e489b8dd9 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -354,6 +354,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(java.lang.String)", "public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)", "public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)" ], @@ -368,10 +369,13 @@ "abstract" ], "methods": [ + "public java.util.Map asMap()", + "public java.util.Map asMap(java.lang.String)", "public abstract java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)", "public abstract com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)" ], "fields": [ + "public static final java.lang.String defaultEmbedderId", "public static final com.yahoo.language.process.Embedder throwsOnUse" ] }, diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java index dd9c3847314..238698e898a 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java @@ -3,10 +3,10 @@ package com.yahoo.language.process; import com.yahoo.language.Language; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import java.util.List; +import java.util.Map; /** * An embedder converts a text string to a tensor @@ -15,9 +15,22 @@ import java.util.List; */ public interface Embedder { + /** Name of embedder when none is explicity given */ + String defaultEmbedderId = "default"; + /** An instance of this which throws IllegalStateException if attempted used */ Embedder throwsOnUse = new FailingEmbedder(); + /** Returns this embedder instance as a map with the default embedder name */ + default Map<String, Embedder> asMap() { + return asMap(defaultEmbedderId); + } + + /** Returns this embedder instance as a map with the given name */ + default Map<String, Embedder> asMap(String name) { + return Map.of(name, this); + } + /** * Converts text into a list of token id's (a vector embedding) * @@ -82,14 +95,24 @@ public interface Embedder { class FailingEmbedder implements Embedder { + private final String message; + + public FailingEmbedder() { + this("No embedder has been configured"); + } + + public FailingEmbedder(String message) { + this.message = message; + } + @Override public List<Integer> embed(String text, Context context) { - throw new IllegalStateException("No embedder has been configured"); + throw new IllegalStateException(message); } @Override public Tensor embed(String text, Context context, TensorType tensorType) { - throw new IllegalStateException("No embedder has been configured"); + throw new IllegalStateException(message); } } |