about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java 8
-rw-r--r-- config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java 8
-rw-r--r-- config-model/src/main/javacc/IntermediateParser.jj 6
-rw-r--r-- config-model/src/main/javacc/SDParser.jj 6
-rw-r--r-- container-search/abi-spec.json 7
-rw-r--r-- container-search/src/main/java/com/yahoo/search/Query.java 33
-rw-r--r-- container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java 46
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java 14
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java 14
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java 50
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java 10
-rw-r--r-- container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java 113
-rw-r--r-- docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java 25
-rw-r--r-- docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java 10
-rw-r--r-- docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java 3
-rw-r--r-- docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java 9
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java 2
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java 13
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java 44
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java 8
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java 7
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java 7
-rw-r--r-- indexinglanguage/src/main/javacc/IndexingParser.jj 15
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java 2
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java 89
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java 2
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java 4
-rw-r--r-- linguistics/abi-spec.json 4
-rw-r--r-- linguistics/src/main/java/com/yahoo/language/process/Embedder.java 29
29 files changed, 432 insertions, 156 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java
index 49ae00d0663..256c628a1cb 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java
@@ -415,14 +415,14 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer,
return wasConfiguredToDoAttributing;
}
- /** Parse an indexing expression which will use the simple linguistics implementatino suitable for testing */
+ /** Parse an indexing expression which will use the simple linguistics implementation suitable for testing */
public void parseIndexingScript(String script) {
- parseIndexingScript(script, new SimpleLinguistics(), Embedder.throwsOnUse);
+ parseIndexingScript(script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
- public void parseIndexingScript(String script, Linguistics linguistics, Embedder embedder) {
+ public void parseIndexingScript(String script, Linguistics linguistics, Map<String, Embedder> embedders) {
try {
- ScriptParserContext config = new ScriptParserContext(linguistics, embedder);
+ ScriptParserContext config = new ScriptParserContext(linguistics, embedders);
config.setInputStream(new IndexingInput(script));
setIndexingScript(ScriptExpression.newInstance(config));
} catch (ParseException e) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java
index a5f5f961ab5..cdd3cc386a4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java
@@ -13,6 +13,8 @@ import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression;
import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig;
import com.yahoo.yolean.Exceptions;
+import java.util.Map;
+
/**
* @author Einar M R Rosenvinge
*/
@@ -32,13 +34,13 @@ public class IndexingOperation implements FieldOperation {
/** Creates an indexing operation which will use the simple linguistics implementation suitable for testing */
public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine) throws ParseException {
- return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine,
- Linguistics linguistics, Embedder embedder)
+ Linguistics linguistics, Map<String, Embedder> embedders)
throws ParseException {
- ScriptParserContext config = new ScriptParserContext(linguistics, embedder);
+ ScriptParserContext config = new ScriptParserContext(linguistics, embedders);
config.setAnnotatorConfig(new AnnotatorConfig());
config.setInputStream(input);
ScriptExpression exp;
diff --git a/config-model/src/main/javacc/IntermediateParser.jj b/config-model/src/main/javacc/IntermediateParser.jj
index ba955f071b2..8a4798d6f74 100644
--- a/config-model/src/main/javacc/IntermediateParser.jj
+++ b/config-model/src/main/javacc/IntermediateParser.jj
@@ -81,7 +81,7 @@ public class IntermediateParser {
*/
@SuppressWarnings("deprecation")
private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException {
- return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
/**
@@ -90,13 +90,13 @@ public class IntermediateParser {
* @param multiline Whether or not to allow multi-line expressions.
* @param linguistics What to use for tokenizing.
*/
- private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Embedder embedder) throws ParseException {
+ private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
SimpleCharStream input = (SimpleCharStream)token_source.input_stream;
if (token.next != null) {
input.backup(token.next.image.length());
}
try {
- return IndexingOperation.fromStream(input, multiline, linguistics, embedder);
+ return IndexingOperation.fromStream(input, multiline, linguistics, embedders);
} finally {
token.next = null;
jj_ntk = -1;
diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj
index ab0cdefc355..aeffe6e5c39 100644
--- a/config-model/src/main/javacc/SDParser.jj
+++ b/config-model/src/main/javacc/SDParser.jj
@@ -112,7 +112,7 @@ public class SDParser {
*/
@SuppressWarnings("deprecation")
private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException {
- return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
/**
@@ -121,13 +121,13 @@ public class SDParser {
* @param multiline Whether or not to allow multi-line expressions.
* @param linguistics What to use for tokenizing.
*/
- private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Embedder embedder) throws ParseException {
+ private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
SimpleCharStream input = (SimpleCharStream)token_source.input_stream;
if (token.next != null) {
input.backup(token.next.image.length());
}
try {
- return IndexingOperation.fromStream(input, multiline, linguistics, embedder);
+ return IndexingOperation.fromStream(input, multiline, linguistics, embedders);
} finally {
token.next = null;
jj_ntk = -1;
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 5b990ca8758..b7aa1a8d0ef 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -1868,7 +1868,9 @@
"public com.yahoo.search.Query$Builder setQueryProfile(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)",
"public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()",
"public com.yahoo.search.Query$Builder setEmbedder(com.yahoo.language.process.Embedder)",
+ "public com.yahoo.search.Query$Builder setEmbedders(java.util.Map)",
"public com.yahoo.language.process.Embedder getEmbedder()",
+ "public java.util.Map getEmbedders()",
"public com.yahoo.search.Query$Builder setZoneInfo(ai.vespa.cloud.ZoneInfo)",
"public ai.vespa.cloud.ZoneInfo getZoneInfo()",
"public com.yahoo.search.Query build()"
@@ -4354,7 +4356,7 @@
"public"
],
"methods": [
- "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)",
+ "public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.component.provider.ComponentRegistry, com.yahoo.search.searchchain.ExecutionFactory, ai.vespa.cloud.ZoneInfo)",
"public void <init>(com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)",
"public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.language.process.Embedder, com.yahoo.search.searchchain.ExecutionFactory)",
"public void <init>(com.yahoo.statistics.Statistics, com.yahoo.jdisc.Metric, com.yahoo.container.handler.threadpool.ContainerThreadPool, com.yahoo.container.logging.AccessLog, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.container.core.ContainerHttpConfig, com.yahoo.search.searchchain.ExecutionFactory)",
@@ -5993,6 +5995,7 @@
"methods": [
"public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile)",
"public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, com.yahoo.language.process.Embedder)",
+ "public void <init>(com.yahoo.search.query.profile.compiled.CompiledQueryProfile, java.util.Map)",
"public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()",
"public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)",
"public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)",
@@ -6368,6 +6371,7 @@
],
"methods": [
"public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder, java.util.Map)",
+ "public void <init>(java.lang.String, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map, java.util.Map)",
"public java.lang.String destination()",
"public static com.yahoo.search.query.profile.types.ConversionContext empty()"
],
@@ -6642,6 +6646,7 @@
],
"methods": [
"public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, com.yahoo.language.process.Embedder)",
+ "public void <init>(com.yahoo.search.Query, com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry, java.util.Map)",
"public void setParentQuery(com.yahoo.search.Query)",
"public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)",
"public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)",
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java
index 83fa18d847f..553c73dac17 100644
--- a/container-search/src/main/java/com/yahoo/search/Query.java
+++ b/container-search/src/main/java/com/yahoo/search/Query.java
@@ -337,7 +337,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
public Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile) {
super(new QueryPropertyAliases(propertyAliases));
this.httpRequest = request;
- init(requestMap, queryProfile, Embedder.throwsOnUse, ZoneInfo.defaultInfo());
+ init(requestMap, queryProfile, Embedder.throwsOnUse.asMap(), ZoneInfo.defaultInfo());
}
// TODO: Deprecate most constructors above here
@@ -346,31 +346,31 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
this(builder.getRequest(),
builder.getRequestMap(),
builder.getQueryProfile(),
- builder.getEmbedder(),
+ builder.getEmbedders(),
builder.getZoneInfo());
}
- private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Embedder embedder,
+ private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders,
ZoneInfo zoneInfo) {
super(new QueryPropertyAliases(propertyAliases));
this.httpRequest = request;
- init(requestMap, queryProfile, embedder, zoneInfo);
+ init(requestMap, queryProfile, embedders, zoneInfo);
}
private void init(Map<String, String> requestMap,
CompiledQueryProfile queryProfile,
- Embedder embedder,
+ Map<String, Embedder> embedders,
ZoneInfo zoneInfo) {
startTime = httpRequest.getJDiscRequest().creationTime(TimeUnit.MILLISECONDS);
if (queryProfile != null) {
// Move all request parameters to the query profile just to validate that the parameter settings are legal
- Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedder);
+ Properties queryProfileProperties = new QueryProfileProperties(queryProfile, embedders);
properties().chain(queryProfileProperties);
// TODO: Just checking legality rather than actually setting would be faster
setPropertiesFromRequestMap(requestMap, properties(), true); // Adds errors to the query for illegal set attempts
// Create the full chain
- properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedder)).
+ properties().chain(new QueryProperties(this, queryProfile.getRegistry(), embedders)).
chain(new ModelObjectMap()).
chain(new RequestContextProperties(requestMap, zoneInfo)).
chain(queryProfileProperties).
@@ -389,7 +389,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
}
else { // bypass these complications if there is no query profile to get values from and validate against
properties().
- chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedder)).
+ chain(new QueryProperties(this, CompiledQueryProfileRegistry.empty, embedders)).
chain(new PropertyMap()).
chain(new DefaultProperties());
setPropertiesFromRequestMap(requestMap, properties(), false);
@@ -1131,7 +1131,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
private HttpRequest request = null;
private Map<String, String> requestMap = null;
private CompiledQueryProfile queryProfile = null;
- private Embedder embedder = Embedder.throwsOnUse;
+ private Map<String, Embedder> embedders = Embedder.throwsOnUse.asMap();
private ZoneInfo zoneInfo = ZoneInfo.defaultInfo();
public Builder setRequest(String query) {
@@ -1171,11 +1171,22 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
public CompiledQueryProfile getQueryProfile() { return queryProfile; }
public Builder setEmbedder(Embedder embedder) {
- this.embedder = embedder;
+ return setEmbedders(Map.of(Embedder.defaultEmbedderId, embedder));
+ }
+
+ public Builder setEmbedders(Map<String, Embedder> embedders) {
+ this.embedders = embedders;
return this;
}
- public Embedder getEmbedder() { return embedder; }
+ public Embedder getEmbedder() {
+ if (embedders.size() != 1) {
+ throw new IllegalArgumentException("Attempt to get single embedder but multiple exists.");
+ }
+ return embedders.entrySet().stream().findFirst().get().getValue();
+ }
+
+ public Map<String, Embedder> getEmbedders() { return embedders; }
public Builder setZoneInfo(ZoneInfo zoneInfo) {
this.zoneInfo = zoneInfo;
diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
index b65953935f0..af6374ba245 100644
--- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
+++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
@@ -3,6 +3,7 @@ package com.yahoo.search.handler;
import com.google.inject.Inject;
import com.yahoo.collections.Tuple2;
+import com.yahoo.component.ComponentId;
import com.yahoo.component.ComponentSpecification;
import com.yahoo.component.Vtag;
import com.yahoo.component.chain.Chain;
@@ -24,6 +25,7 @@ import com.yahoo.jdisc.Metric;
import com.yahoo.jdisc.Request;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
+import com.yahoo.language.provider.DefaultEmbedderProvider;
import com.yahoo.net.HostName;
import com.yahoo.net.UriTools;
import com.yahoo.prelude.query.parser.ParseException;
@@ -57,6 +59,7 @@ import ai.vespa.cloud.ZoneInfo;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
@@ -66,6 +69,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
* Handles search request.
@@ -102,7 +106,7 @@ public class SearchHandler extends LoggingRequestHandler {
private final Optional<String> hostResponseHeaderKey;
private final String selfHostname = HostName.getLocalhost();
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final ExecutionFactory executionFactory;
private final AtomicLong numRequestsLeftToTrace;
@@ -117,10 +121,10 @@ public class SearchHandler extends LoggingRequestHandler {
ContainerThreadPool threadpool,
CompiledQueryProfileRegistry queryProfileRegistry,
ContainerHttpConfig config,
- Embedder embedder,
+ ComponentRegistry<Embedder> embedders,
ExecutionFactory executionFactory,
ZoneInfo zoneInfo) {
- this(metric, threadpool.executor(), queryProfileRegistry, embedder, executionFactory,
+ this(metric, threadpool.executor(), queryProfileRegistry, embedders, executionFactory,
config.numQueriesToTraceOnDebugAfterConstruction(),
config.hostResponseHeaderKey().equals("") ? Optional.empty() : Optional.of(config.hostResponseHeaderKey()),
zoneInfo);
@@ -221,7 +225,7 @@ public class SearchHandler extends LoggingRequestHandler {
CompiledQueryProfileRegistry queryProfileRegistry,
ExecutionFactory executionFactory,
Optional<String> hostResponseHeaderKey) {
- this(metric, executor, queryProfileRegistry, Embedder.throwsOnUse,
+ this(metric, executor, queryProfileRegistry, toRegistry(Embedder.throwsOnUse),
executionFactory, 0, hostResponseHeaderKey,
ZoneInfo.defaultInfo());
}
@@ -234,10 +238,24 @@ public class SearchHandler extends LoggingRequestHandler {
long numQueriesToTraceOnDebugAfterStartup,
Optional<String> hostResponseHeaderKey,
ZoneInfo zoneInfo) {
+ this(metric, executor, queryProfileRegistry, toRegistry(embedder),
+ executionFactory, numQueriesToTraceOnDebugAfterStartup, hostResponseHeaderKey,
+ ZoneInfo.defaultInfo());
+ }
+
+ private SearchHandler(Metric metric,
+ Executor executor,
+ CompiledQueryProfileRegistry queryProfileRegistry,
+ ComponentRegistry<Embedder> embedders,
+ ExecutionFactory executionFactory,
+ long numQueriesToTraceOnDebugAfterStartup,
+ Optional<String> hostResponseHeaderKey,
+ ZoneInfo zoneInfo) {
super(executor, metric, true);
+
log.log(Level.FINE, () -> "SearchHandler.init " + System.identityHashCode(this));
this.queryProfileRegistry = queryProfileRegistry;
- this.embedder = embedder;
+ this.embedders = toMap(embedders);
this.executionFactory = executionFactory;
this.maxThreads = examineExecutor(executor);
@@ -340,7 +358,7 @@ public class SearchHandler extends LoggingRequestHandler {
Query query = new Query.Builder().setRequest(request)
.setRequestMap(requestMap)
.setQueryProfile(queryProfile)
- .setEmbedder(embedder)
+ .setEmbedders(embedders)
.setZoneInfo(zoneInfo)
.build();
@@ -691,6 +709,22 @@ public class SearchHandler extends LoggingRequestHandler {
.build();
}
+ private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
+ var map = embedders.allComponentsById().entrySet().stream()
+ .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
+ if (map.size() > 1) {
+ map.remove(DefaultEmbedderProvider.class.getName());
+ // Ideally, this should be handled by dependency injection, however for now this workaround is necessary.
+ }
+ return Collections.unmodifiableMap(map);
+ }
+
+ private static ComponentRegistry<Embedder> toRegistry(Embedder embedder) {
+ ComponentRegistry<Embedder> emb = new ComponentRegistry<>();
+ emb.register(new ComponentId(Embedder.defaultEmbedderId), embedder);
+ return emb;
+ }
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
index f58395fd5bb..5b3758f103d 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
@@ -30,7 +30,7 @@ import java.util.Map;
public class QueryProfileProperties extends Properties {
private final CompiledQueryProfile profile;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
// Note: The priority order is: values has precedence over references
@@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties {
private List<Pair<CompoundName, CompiledQueryProfile>> references = null;
public QueryProfileProperties(CompiledQueryProfile profile) {
- this(profile, Embedder.throwsOnUse);
+ this(profile, Embedder.throwsOnUse.asMap());
}
- /** Creates an instance from a profile, throws an exception if the given profile is null */
public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) {
+ this(profile, Map.of(Embedder.defaultEmbedderId, embedder));
+ }
+
+ /** Creates an instance from a profile, throws an exception if the given profile is null */
+ public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) {
Validator.ensureNotNull("The profile wrapped by this cannot be null", profile);
this.profile = profile;
- this.embedder = embedder;
+ this.embedders = embedders;
}
/** Returns the query profile backing this, or null if none */
@@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties {
if (fieldDescription != null) {
if (i == name.size() - 1) { // at the end of the path, check the assignment type
- var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context);
+ var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context);
var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext);
if (convertedValue == null
&& fieldDescription.getType() instanceof QueryProfileFieldType
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
index 8dfb67a9d5f..5c449b9645a 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
@@ -14,14 +14,20 @@ public class ConversionContext {
private final String destination;
private final CompiledQueryProfileRegistry registry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final Language language;
public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder,
Map<String, String> context) {
+ this(destination, registry, Map.of(Embedder.defaultEmbedderId, embedder), context);
+ }
+
+ public ConversionContext(String destination, CompiledQueryProfileRegistry registry,
+ Map<String, Embedder> embedders,
+ Map<String, String> context) {
this.destination = destination;
this.registry = registry;
- this.embedder = embedder;
+ this.embedders = embedders;
this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language"))
: Language.UNKNOWN;
}
@@ -33,14 +39,14 @@ public class ConversionContext {
CompiledQueryProfileRegistry registry() {return registry;}
/** Returns the configured embedder, never null */
- Embedder embedder() { return embedder; }
+ Map<String, Embedder> embedders() { return embedders; }
/** Returns the language, which is never null but may be UNKNOWN */
Language language() { return language; }
/** Returns an empty context */
public static ConversionContext empty() {
- return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of());
+ return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of());
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index d6676db3774..6f1cfccc16b 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
/**
* A tensor field type in a query profile
*
@@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType;
*/
public class TensorFieldType extends FieldType {
+ private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])");
+
private final TensorType type;
/** Creates a tensor field type with information about the kind of tensor this will hold */
@@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType {
private Tensor encode(String s, ConversionContext context) {
if ( ! s.endsWith(")"))
throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
- String text = s.substring("embed(".length(), s.length() - 1);
- return context.embedder().embed(text, toEmbedderContext(context), type);
+ String argument = s.substring("embed(".length(), s.length() - 1);
+ Embedder embedder;
+
+ // Check if arguments specifies an embedder with the format embed(embedder, "text to encode")
+ Matcher matcher = embedderArgumentRegexp.matcher(argument);
+ if (matcher.matches()) {
+ String embedderId = matcher.group(1);
+ argument = matcher.group(2);
+ if (!context.embedders().containsKey(embedderId)) {
+ throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ }
+ embedder = context.embedders().get(embedderId);
+ } else if (context.embedders().size() == 0) {
+ throw new IllegalStateException("No embedders provided"); // should never happen
+ } else if (context.embedders().size() > 1) {
+ throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ } else {
+ embedder = context.embedders().entrySet().stream().findFirst().get().getValue();
+ }
+
+ return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type);
+ }
+
+ private static String removeQuotes(String s) {
+ if (s.startsWith("'") && s.endsWith("'")) {
+ return s.substring(1, s.length() - 1);
+ }
+ if (s.startsWith("\"") && s.endsWith("\"")) {
+ return s.substring(1, s.length() - 1);
+ }
+ return s;
+ }
+
+ private static String validEmbedders(Map<String, Embedder> embedders) {
+ List<String> embedderIds = new ArrayList<>();
+ embedders.forEach((key, value) -> embedderIds.add(key));
+ embedderIds.sort(null);
+ return String.join(",", embedderIds);
}
private Embedder.Context toEmbedderContext(ConversionContext context) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index 243915662d2..98b65c6edd9 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -34,12 +34,16 @@ public class QueryProperties extends Properties {
private Query query;
private final CompiledQueryProfileRegistry profileRegistry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) {
+ this(query, profileRegistry, Map.of(Embedder.defaultEmbedderId, embedder));
+ }
+
+ public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) {
this.query = query;
this.profileRegistry = profileRegistry;
- this.embedder = embedder;
+ this.embedders = embedders;
}
public void setParentQuery(Query query) {
@@ -394,7 +398,7 @@ public class QueryProperties extends Properties {
if (type == null) return value; // no type info -> keep as string
FieldDescription field = type.getField(key);
if (field == null) return value; // ditto
- return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context));
+ return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context));
}
private void throwIllegalParameter(String key,String namespace) {
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
index 2e88c9fd0a4..a1556aac189 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
@@ -25,9 +25,11 @@ import org.junit.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -441,42 +443,42 @@ public class QueryProfileTypeTestCase {
@Test
public void testUnembeddedTensorRankFeatureInRequest() {
- QueryProfile profile = new QueryProfile("test");
- profile.setType(testtype);
- registry.register(profile);
-
- CompiledQueryProfileRegistry cRegistry = registry.compile();
- String textToEmbed = "text to embed into a tensor";
- String destinationFeature = "query(myTensor4)";
- Tensor expectedTensor = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]");
- Query query1 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) +
- "=" + urlEncode("embed(" + textToEmbed + ")"),
- com.yahoo.jdisc.http.HttpRequest.Method.GET))
- .setQueryProfile(cRegistry.getComponent("test"))
- .setEmbedder(new MockEmbedder(textToEmbed,
- Language.UNKNOWN,
- destinationFeature,
- expectedTensor))
- .build();
- assertEquals(0, query1.errors().size());
- assertEquals(expectedTensor, query1.properties().get("ranking.features.query(myTensor4)"));
- assertEquals(expectedTensor, query1.getRanking().getFeatures().getTensor("query(myTensor4)").get());
-
- // Explicit language
- Query query2 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) +
- "=" + urlEncode("embed(" + textToEmbed + ")") +
- "&language=en",
- com.yahoo.jdisc.http.HttpRequest.Method.GET))
- .setQueryProfile(cRegistry.getComponent("test"))
- .setEmbedder(new MockEmbedder(textToEmbed,
- Language.ENGLISH,
- destinationFeature,
- expectedTensor))
- .build();
- assertEquals(0, query2.errors().size());
- assertEquals(expectedTensor, query2.properties().get("ranking.features.query(myTensor4)"));
- assertEquals(expectedTensor, query2.getRanking().getFeatures().getTensor("query(myTensor4)").get());
-
+ String text = "text to embed into a tensor";
+ Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]");
+ Tensor embedding2 = Tensor.from("tensor<float>(x[5]):[1,2,3,4,0]]");
+
+ Map<String, Embedder> embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1)
+ );
+ assertEmbedQuery("embed(" + text + ")", embedding1, embedders);
+ assertEmbedQuery("embed('" + text + "')", embedding1, embedders);
+ assertEmbedQuery("embed(\"" + text + "\")", embedding1, embedders);
+ assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders);
+ assertEmbedQuery("embed(emb1, \"" + text + "\")", embedding1, embedders);
+ assertEmbedQueryFails("embed(emb2, \"" + text + "\")", embedding1, embedders,
+ "Can't find embedder 'emb2'. Valid embedders are emb1");
+
+ embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1),
+ "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+ );
+ assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders);
+ assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders);
+ assertEmbedQueryFails("embed(emb3, \"" + text + "\")", embedding1, embedders,
+ "Can't find embedder 'emb3'. Valid embedders are emb1,emb2");
+
+ // And with specified language
+ embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1)
+ );
+ assertEmbedQuery("embed(" + text + ")", embedding1, embedders, Language.ENGLISH.languageCode());
+
+ embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.ENGLISH, embedding1),
+ "emb2", new MockEmbedder(text, Language.UNKNOWN, embedding2)
+ );
+ assertEmbedQuery("embed(emb1, '" + text + "')", embedding1, embedders, Language.ENGLISH.languageCode());
+ assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode());
}
private String urlEncode(String s) {
@@ -729,20 +731,52 @@ public class QueryProfileTypeTestCase {
}
}
+ private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders) {
+ assertEmbedQuery(embed, expected, embedders, null);
+ }
+
+ private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) {
+ QueryProfile profile = new QueryProfile("test");
+ profile.setType(testtype);
+ registry.register(profile);
+ CompiledQueryProfileRegistry cRegistry = registry.compile();
+
+ String languageParam = language == null ? "" : "&language=" + language;
+ String destination = "query(myTensor4)";
+
+ Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest(
+ "?" + urlEncode("ranking.features." + destination) +
+ "=" + urlEncode(embed) +
+ languageParam,
+ com.yahoo.jdisc.http.HttpRequest.Method.GET))
+ .setQueryProfile(cRegistry.getComponent("test"))
+ .setEmbedders(embedders)
+ .build();
+ assertEquals(0, query.errors().size());
+ assertEquals(expected, query.properties().get("ranking.features." + destination));
+ assertEquals(expected, query.getRanking().getFeatures().getTensor(destination).get());
+ }
+
+ private void assertEmbedQueryFails(String embed, Tensor expected, Map<String, Embedder> embedders, String errMsg) {
+ Throwable t = assertThrows(IllegalArgumentException.class, () -> assertEmbedQuery(embed, expected, embedders));
+ while (t != null) {
+ if (t.getMessage().equals(errMsg)) return;
+ t = t.getCause();
+ }
+ fail("Error '" + errMsg + "' not thrown");
+ }
+
private static final class MockEmbedder implements Embedder {
private final String expectedText;
private final Language expectedLanguage;
- private final String expectedDestination;
private final Tensor tensorToReturn;
public MockEmbedder(String expectedText,
Language expectedLanguage,
- String expectedDestination,
Tensor tensorToReturn) {
this.expectedText = expectedText;
this.expectedLanguage = expectedLanguage;
- this.expectedDestination = expectedDestination;
this.tensorToReturn = tensorToReturn;
}
@@ -756,7 +790,6 @@ public class QueryProfileTypeTestCase {
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
assertEquals(expectedText, text);
assertEquals(expectedLanguage, context.getLanguage());
- assertEquals(expectedDestination, context.getDestination());
assertEquals(tensorToReturn.type(), tensorType);
return tensorToReturn;
}
diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java
index 7b553383daf..87c78445b13 100644
--- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java
+++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java
@@ -7,6 +7,7 @@ import com.google.inject.Inject;
import com.yahoo.component.chain.dependencies.After;
import com.yahoo.component.chain.dependencies.Before;
import com.yahoo.component.chain.dependencies.Provides;
+import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.docproc.DocumentProcessor;
import com.yahoo.docproc.Processing;
import com.yahoo.document.Document;
@@ -15,18 +16,20 @@ import com.yahoo.document.DocumentPut;
import com.yahoo.document.DocumentRemove;
import com.yahoo.document.DocumentType;
import com.yahoo.document.DocumentTypeManager;
-import com.yahoo.document.DocumentTypeManagerConfigurer;
import com.yahoo.document.DocumentUpdate;
-import com.yahoo.document.config.DocumentmanagerConfig;
import com.yahoo.language.Linguistics;
-import java.util.logging.Level;
-
import com.yahoo.language.process.Embedder;
+import com.yahoo.language.provider.DefaultEmbedderProvider;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
import com.yahoo.vespa.indexinglanguage.AdapterFactory;
import com.yahoo.vespa.indexinglanguage.SimpleAdapterFactory;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
+import java.util.Map;
+import java.util.logging.Level;
+import java.util.stream.Collectors;
+
+
/**
* @author Simon Thoresen Hult
*/
@@ -55,9 +58,9 @@ public class IndexingProcessor extends DocumentProcessor {
public IndexingProcessor(DocumentTypeManager documentTypeManager,
IlscriptsConfig ilscriptsConfig,
Linguistics linguistics,
- Embedder embedder) {
+ ComponentRegistry<Embedder> embedders) {
docTypeMgr = documentTypeManager;
- scriptMgr = new ScriptManager(docTypeMgr, ilscriptsConfig, linguistics, embedder);
+ scriptMgr = new ScriptManager(docTypeMgr, ilscriptsConfig, linguistics, toMap(embedders));
adapterFactory = new SimpleAdapterFactory(new ExpressionSelector());
}
@@ -128,4 +131,14 @@ public class IndexingProcessor extends DocumentProcessor {
out.add(prev);
}
+ private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
+ var map = embedders.allComponentsById().entrySet().stream()
+ .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
+ if (map.size() > 1) {
+ map.remove(DefaultEmbedderProvider.class.getName());
+ // Ideally, this should be handled by dependency injection, however for now this workaround is necessary.
+ }
+ return map;
+ }
+
}
diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java
index 63c6d6c4bb5..de3a429e357 100644
--- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java
+++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java
@@ -28,12 +28,12 @@ public class ScriptManager {
private final Map<String, Map<String, DocumentScript>> documentFieldScripts;
private final DocumentTypeManager docTypeMgr;
- public ScriptManager(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, Embedder embedder) {
+ public ScriptManager(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics,
+ Map<String, Embedder> embedders) {
this.docTypeMgr = docTypeMgr;
- documentFieldScripts = createScriptsMap(docTypeMgr, config, linguistics, embedder);
+ documentFieldScripts = createScriptsMap(docTypeMgr, config, linguistics, embedders);
}
-
private Map<String, DocumentScript> getScripts(DocumentType inputType) {
Map<String, DocumentScript> scripts = documentFieldScripts.get(inputType.getName());
if (scripts != null) {
@@ -75,9 +75,9 @@ public class ScriptManager {
private static Map<String, Map<String, DocumentScript>> createScriptsMap(DocumentTypeManager docTypeMgr,
IlscriptsConfig config,
Linguistics linguistics,
- Embedder embedder) {
+ Map<String, Embedder> embedders) {
Map<String, Map<String, DocumentScript>> documentFieldScripts = new HashMap<>(config.ilscript().size());
- ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedder);
+ ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders);
parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences());
parserContext.getAnnotatorConfig().setMaxTokenLength(config.fieldmatchmaxlength());
diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java
index 13f9ea1a8c8..76f4578ac87 100644
--- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java
+++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.docprocs.indexing;
+import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.docproc.Processing;
import com.yahoo.document.Document;
@@ -128,6 +129,6 @@ public class IndexingProcessorTestCase {
return new IndexingProcessor(new DocumentTypeManager(ConfigGetter.getConfig(DocumentmanagerConfig.class, configId)),
ConfigGetter.getConfig(IlscriptsConfig.class, configId),
new SimpleLinguistics(),
- Embedder.throwsOnUse);
+ new ComponentRegistry<Embedder>());
}
}
diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java
index a35dd0da4f3..4a7e643fb0a 100644
--- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java
+++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java
@@ -9,6 +9,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException;
import org.junit.Test;
import java.util.Iterator;
+import java.util.Map;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
@@ -28,7 +29,7 @@ public class ScriptManagerTestCase {
IlscriptsConfig.Builder config = new IlscriptsConfig.Builder();
config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newssummary")
.content("input title | index title"));
- ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse);
+ ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap());
assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newsarticle")));
assertNull(scriptMgr.getScript(new DocumentType("unknown")));
}
@@ -42,7 +43,7 @@ public class ScriptManagerTestCase {
IlscriptsConfig.Builder config = new IlscriptsConfig.Builder();
config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newsarticle")
.content("input title | index title"));
- ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse);
+ ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap());
assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newssummary")));
assertNull(scriptMgr.getScript(new DocumentType("unknown")));
}
@@ -50,14 +51,14 @@ public class ScriptManagerTestCase {
@Test
public void requireThatEmptyConfigurationDoesNotThrow() {
var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg");
- ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse);
+ ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap());
assertNull(scriptMgr.getScript(new DocumentType("unknown")));
}
@Test
public void requireThatUnknownDocumentTypeReturnsNull() {
var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg");
- ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse);
+ ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap());
for (Iterator<DocumentType> it = typeMgr.documentTypeIterator(); it.hasNext(); ) {
assertNull(scriptMgr.getScript(it.next()));
}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java
index 11756ae0907..2b4e0db699b 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java
@@ -62,7 +62,7 @@ public final class ScriptParser {
parser.setAnnotatorConfig(context.getAnnotatorConfig());
parser.setDefaultFieldName(context.getDefaultFieldName());
parser.setLinguistics(context.getLinguistcs());
- parser.setEmbedder(context.getEmbedder());
+ parser.setEmbedders(context.getEmbedders());
try {
return method.call(parser);
} catch (ParseException e) {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java
index 91c24a10e27..9edbed68871 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java
@@ -6,6 +6,9 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig;
import com.yahoo.vespa.indexinglanguage.parser.CharStream;
+import java.util.Collections;
+import java.util.Map;
+
/**
* @author Simon Thoresen Hult
*/
@@ -13,13 +16,13 @@ public class ScriptParserContext {
private AnnotatorConfig annotatorConfig = new AnnotatorConfig();
private Linguistics linguistics;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private String defaultFieldName = null;
private CharStream inputStream = null;
- public ScriptParserContext(Linguistics linguistics, Embedder embedder) {
+ public ScriptParserContext(Linguistics linguistics, Map<String, Embedder> embedders) {
this.linguistics = linguistics;
- this.embedder = embedder;
+ this.embedders = embedders;
}
public AnnotatorConfig getAnnotatorConfig() {
@@ -40,8 +43,8 @@ public class ScriptParserContext {
return this;
}
- public Embedder getEmbedder() {
- return embedder;
+ public Map<String, Embedder> getEmbedders() {
+ return Collections.unmodifiableMap(embedders);
}
public String getDefaultFieldName() {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 0da9d907718..2e4bb701454 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -12,6 +12,10 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
/**
* Embeds a string in a tensor space using the configured Embedder component
*
@@ -20,6 +24,7 @@ import com.yahoo.tensor.TensorType;
public class EmbedExpression extends Expression {
private final Embedder embedder;
+ private final String embedderId;
/** The destination the embedding will be written to on the form [schema name].[field name] */
private String destination;
@@ -27,9 +32,28 @@ public class EmbedExpression extends Expression {
/** The target type we are embedding into. */
private TensorType targetType;
- public EmbedExpression(Embedder embedder) {
+ public EmbedExpression(Map<String, Embedder> embedders, String embedderId) {
super(DataType.STRING);
- this.embedder = embedder;
+ this.embedderId = embedderId;
+
+ boolean embedderIdProvided = embedderId != null && embedderId.length() > 0;
+
+ if (embedders.size() == 0) {
+ throw new IllegalStateException("No embedders provided"); // should never happen
+ }
+ else if (embedders.size() > 1 && ! embedderIdProvided) {
+ this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " +
+ "Valid embedders are " + validEmbedders(embedders));
+ }
+ else if (embedders.size() == 1 && ! embedderIdProvided) {
+ this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
+ }
+ else if ( ! embedders.containsKey(embedderId)) {
+ this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " +
+ "Valid embedders are " + validEmbedders(embedders));
+ } else {
+ this.embedder = embedders.get(embedderId);
+ }
}
@Override
@@ -71,7 +95,14 @@ public class EmbedExpression extends Expression {
}
@Override
- public String toString() { return "embed"; }
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("embed");
+ if (this.embedderId != null && this.embedderId.length() > 0) {
+ sb.append(" ").append(this.embedderId);
+ }
+ return sb.toString();
+ }
@Override
public int hashCode() { return 1; }
@@ -79,4 +110,11 @@ public class EmbedExpression extends Expression {
@Override
public boolean equals(Object o) { return o instanceof EmbedExpression; }
+ private static String validEmbedders(Map<String, Embedder> embedders) {
+ List<String> embedderIds = new ArrayList<>();
+ embedders.forEach((key, value) -> embedderIds.add(key));
+ embedderIds.sort(null);
+ return String.join(",", embedderIds);
+ }
+
}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java
index a5b62c73997..e5bf4711ad1 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java
@@ -15,6 +15,8 @@ import com.yahoo.vespa.indexinglanguage.parser.IndexingInput;
import com.yahoo.vespa.indexinglanguage.parser.ParseException;
import com.yahoo.vespa.objects.Selectable;
+import java.util.Map;
+
/**
* @author Simon Thoresen Hult
*/
@@ -191,11 +193,11 @@ public abstract class Expression extends Selectable {
/** Creates an expression with simple lingustics for testing */
public static Expression fromString(String expression) throws ParseException {
- return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
- public static Expression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException {
- return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression)));
+ public static Expression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
+ return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression)));
}
public static Expression newInstance(ScriptParserContext context) throws ParseException {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java
index d8e9cc4d923..c8e45f0f61a 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java
@@ -15,6 +15,7 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
+import java.util.Map;
/**
* @author Simon Thoresen Hult
@@ -92,11 +93,11 @@ public final class ScriptExpression extends ExpressionList<StatementExpression>
/** Creates an expression with simple lingustics for testing */
@SuppressWarnings("deprecation")
public static ScriptExpression fromString(String expression) throws ParseException {
- return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
- public static ScriptExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException {
- return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression)));
+ public static ScriptExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
+ return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression)));
}
public static ScriptExpression newInstance(ScriptParserContext config) throws ParseException {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
index 40aa0f58413..38157531ba2 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
@@ -14,6 +14,7 @@ import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
/**
* @author Simon Thoresen Hult
@@ -99,11 +100,11 @@ public final class StatementExpression extends ExpressionList<Expression> {
/** Creates an expression with simple lingustics for testing */
public static StatementExpression fromString(String expression) throws ParseException {
- return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
- public static StatementExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException {
- return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression)));
+ public static StatementExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
+ return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression)));
}
public static StatementExpression newInstance(ScriptParserContext config) throws ParseException {
diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj
index e6b21f7c07b..51bb9be1f8a 100644
--- a/indexinglanguage/src/main/javacc/IndexingParser.jj
+++ b/indexinglanguage/src/main/javacc/IndexingParser.jj
@@ -45,7 +45,7 @@ public class IndexingParser {
private String defaultFieldName;
private Linguistics linguistics;
- private Embedder embedder;
+ private Map<String, Embedder> embedders;
private AnnotatorConfig annotatorCfg;
public IndexingParser(String str) {
@@ -62,8 +62,8 @@ public class IndexingParser {
return this;
}
- public IndexingParser setEmbedder(Embedder embedder) {
- this.embedder = embedder;
+ public IndexingParser setEmbedders(Map<String, Embedder> embedders) {
+ this.embedders = embedders;
return this;
}
@@ -367,10 +367,13 @@ Expression echoExp() : { }
{ return new EchoExpression(); }
}
-Expression embedExp() : { }
+Expression embedExp() :
{
- ( <EMBED> )
- { return new EmbedExpression(embedder); }
+ String val = "";
+}
+{
+ ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] )
+ { return new EmbedExpression(embedders, val); }
}
Expression exactExp() : { }
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java
index 87c54fd7abd..28da9a71aac 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java
@@ -96,7 +96,7 @@ public class ScriptParserTestCase {
}
private static ScriptParserContext newContext(String input) {
- return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse).setInputStream(new IndexingInput(input));
+ return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse.asMap()).setInputStream(new IndexingInput(input));
}
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 27723c6649d..de31f6fcb1e 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -21,6 +21,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException;
import org.junit.Test;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.*;
@@ -175,37 +176,66 @@ public class ScriptTestCase {
@Test
public void testEmbed() throws ParseException {
- TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
- var expression = Expression.fromString("input myText | embed | attribute 'myTensor'",
- new SimpleLinguistics(),
- new MockEmbedder("myDocument.myTensor"));
-
- SimpleTestAdapter adapter = new SimpleTestAdapter();
- adapter.createField(new Field("myText", DataType.STRING));
- var tensorField = new Field("myTensor", new TensorDataType(tensorType));
- adapter.createField(tensorField);
- adapter.setValue("myText", new StringFieldValue("input text"));
- expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
-
- // Necessary to resolve output type
- VerificationContext verificationContext = new VerificationContext(adapter);
- assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass());
+ // Test parsing without knowledge of any embedders
+ String exp = "input myText | embed emb1 | attribute 'myTensor'";
+ Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
+
+ Map<String, Embedder> embedder = Map.of(
+ "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]")
+ );
+ testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "[1,2,0,0]");
+ testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, "[1,2,0,0]");
+ testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, "[1,2,0,0]");
+
+ Map<String, Embedder> embedders = Map.of(
+ "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]"),
+ "emb2", new MockEmbedder("myDocument.myTensor", "[3,4,5,0]")
+ );
+ testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "[1,2,0,0]");
+ testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, "[3,4,5,0]");
+
+ assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "[3,4,5,0]"),
+ "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1,emb2");
+ assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "[3,4,5,0]"),
+ "Can't find embedder 'emb3'. Valid embedders are emb1,emb2");
+ }
- ExecutionContext context = new ExecutionContext(adapter);
- context.setValue(new StringFieldValue("input text"));
- expression.execute(context);
- assertTrue(adapter.values.containsKey("myTensor"));
- assertEquals(Tensor.from(tensorType, "[7,3,0,0]"),
- ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get());
+ private void testEmbedStatement(String exp, Map<String, Embedder> embedders, String expected) {
+ try {
+ var expression = Expression.fromString(exp, new SimpleLinguistics(), embedders);
+ TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myText", DataType.STRING));
+ var tensorField = new Field("myTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+ adapter.setValue("myText", new StringFieldValue("input text"));
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass());
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(new StringFieldValue("input text"));
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("myTensor"));
+ assertEquals(Tensor.from(tensorType, expected),
+ ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get());
+ } catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
}
@SuppressWarnings("unchecked")
@Test
public void testArrayEmbed() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray", "[7,3,0,0]"));
+
TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'",
new SimpleLinguistics(),
- new MockEmbedder("myDocument.myTensorArray"));
+ embedders);
SimpleTestAdapter adapter = new SimpleTestAdapter();
adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
@@ -235,9 +265,11 @@ public class ScriptTestCase {
private static class MockEmbedder implements Embedder {
private final String expectedDestination;
+ private final String tensorString;
- public MockEmbedder(String expectedDestination) {
+ public MockEmbedder(String expectedDestination, String tensorString) {
this.expectedDestination = expectedDestination;
+ this.tensorString = tensorString;
}
@Override
@@ -248,9 +280,18 @@ public class ScriptTestCase {
@Override
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
assertEquals(expectedDestination, context.getDestination());
- return Tensor.from(tensorType, "[7,3,0,0]");
+ return Tensor.from(tensorType, tensorString);
}
}
+ private void assertThrows(Runnable r, String msg) { // runs r and requires an IllegalStateException whose message equals msg
+ try {
+ r.run();
+ fail("Expected IllegalStateException with message: " + msg); // reached only if r completed without throwing
+ } catch (IllegalStateException e) {
+ assertEquals(msg, e.getMessage()); // assertEquals takes (expected, actual); this order keeps failure reports the right way round
+ }
+ }
+
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java
index f6aa7e477a8..89170027c73 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java
@@ -19,7 +19,7 @@ public class DefaultFieldNameTestCase {
public void requireThatDefaultFieldNameIsAppliedWhenArgumentIsMissing() throws ParseException {
IndexingInput input = new IndexingInput("input");
InputExpression exp = (InputExpression)Expression.newInstance(new ScriptParserContext(new SimpleLinguistics(),
- Embedder.throwsOnUse)
+ Embedder.throwsOnUse.asMap())
.setInputStream(input)
.setDefaultFieldName("foo"));
assertEquals("foo", exp.getFieldName());
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java
index e333eea7001..7db026d43ee 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java
@@ -85,9 +85,9 @@ public class ExpressionTestCase {
private static void assertExpression(Class expectedClass, String str) throws ParseException {
Linguistics linguistics = new SimpleLinguistics();
- Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse);
+ Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse.asMap()); // fromString now takes a map of embedders; asMap() wraps the failing embedder under the default id
assertEquals(expectedClass, foo.getClass());
- Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse);
+ Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse.asMap()); // re-parse the printed form so the checks below compare it with the original
assertEquals(foo.hashCode(), bar.hashCode());
assertEquals(foo, bar);
}
diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json
index 910056286ec..c3e489b8dd9 100644
--- a/linguistics/abi-spec.json
+++ b/linguistics/abi-spec.json
@@ -354,6 +354,7 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(java.lang.String)",
"public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
"public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)"
],
@@ -368,10 +369,13 @@
"abstract"
],
"methods": [
+ "public java.util.Map asMap()",
+ "public java.util.Map asMap(java.lang.String)",
"public abstract java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
"public abstract com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)"
],
"fields": [
+ "public static final java.lang.String defaultEmbedderId",
"public static final com.yahoo.language.process.Embedder throwsOnUse"
]
},
diff --git a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
index dd9c3847314..238698e898a 100644
--- a/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
+++ b/linguistics/src/main/java/com/yahoo/language/process/Embedder.java
@@ -3,10 +3,10 @@ package com.yahoo.language.process;
import com.yahoo.language.Language;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.List;
+import java.util.Map;
/**
* An embedder converts a text string to a tensor
@@ -15,9 +15,22 @@ import java.util.List;
*/
public interface Embedder {
+ /** Name of embedder when none is explicitly given */
+ String defaultEmbedderId = "default";
+
/** An instance of this which throws IllegalStateException if attempted used */
Embedder throwsOnUse = new FailingEmbedder();
+ /** Returns this embedder instance as a single-entry map keyed by {@link #defaultEmbedderId} */
+ default Map<String, Embedder> asMap() {
+ return asMap(defaultEmbedderId);
+ }
+
+ /** Returns this embedder instance as an unmodifiable single-entry map keyed by the given name, which must not be null */
+ default Map<String, Embedder> asMap(String name) {
+ return Map.of(name, this);
+ }
+
/**
* Converts text into a list of token id's (a vector embedding)
*
@@ -82,14 +95,24 @@ public interface Embedder {
class FailingEmbedder implements Embedder {
+ private final String message; // text of the IllegalStateException thrown by both embed methods
+
+ public FailingEmbedder() { // uses the same message this class threw before the message became configurable
+ this("No embedder has been configured");
+ }
+
+ public FailingEmbedder(String message) { // message: what the IllegalStateException from embed(...) will say
+ this.message = message;
+ }
+
@Override
public List<Integer> embed(String text, Context context) {
- throw new IllegalStateException("No embedder has been configured");
+ throw new IllegalStateException(message);
}
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
- throw new IllegalStateException("No embedder has been configured");
+ throw new IllegalStateException(message);
}
}