diff options
author | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:15 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:15 +0100 |
commit | c5e464f1a6da3a74113d775805187a547074a2da (patch) | |
tree | dab30afcde250b686d85472f9e2b46d28c9e2184 | |
parent | 24555fae4aac0dadde821cac0b7cf85321027bce (diff) |
Add embedder selection argument to indexing language
19 files changed, 187 insertions, 81 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java index 621e7ce8571..ac6524dac92 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java @@ -410,14 +410,14 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, return wasConfiguredToDoAttributing; } - /** Parse an indexing expression which will use the simple linguistics implementatino suitable for testing */ + /** Parse an indexing expression which will use the simple linguistics implementation suitable for testing */ public void parseIndexingScript(String script) { - parseIndexingScript(script, new SimpleLinguistics(), Embedder.throwsOnUse); + parseIndexingScript(script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public void parseIndexingScript(String script, Linguistics linguistics, Embedder embedder) { + public void parseIndexingScript(String script, Linguistics linguistics, Map<String, Embedder> embedders) { try { - ScriptParserContext config = new ScriptParserContext(linguistics, embedder); + ScriptParserContext config = new ScriptParserContext(linguistics, embedders); config.setInputStream(new IndexingInput(script)); setIndexingScript(ScriptExpression.newInstance(config)); } catch (ParseException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java index a5f5f961ab5..cdd3cc386a4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/fieldoperation/IndexingOperation.java @@ -13,6 +13,8 @@ import com.yahoo.vespa.indexinglanguage.expressions.StatementExpression; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.yolean.Exceptions; +import java.util.Map; + /** * @author Einar M R Rosenvinge */ @@ -32,13 +34,13 @@ public class IndexingOperation implements FieldOperation { /** Creates an indexing operation which will use the simple linguistics implementation suitable for testing */ public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine) throws ParseException { - return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine, - Linguistics linguistics, Embedder embedder) + Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { - ScriptParserContext config = new ScriptParserContext(linguistics, embedder); + ScriptParserContext config = new ScriptParserContext(linguistics, embedders); config.setAnnotatorConfig(new AnnotatorConfig()); config.setInputStream(input); ScriptExpression exp; diff --git a/config-model/src/main/javacc/IntermediateParser.jj b/config-model/src/main/javacc/IntermediateParser.jj index ba955f071b2..8a4798d6f74 100644 --- a/config-model/src/main/javacc/IntermediateParser.jj +++ b/config-model/src/main/javacc/IntermediateParser.jj @@ -81,7 +81,7 @@ public class IntermediateParser { */ @SuppressWarnings("deprecation") private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException { - return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse); + return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } /** @@ -90,13 +90,13 @@ public class IntermediateParser { * @param multiline Whether or not to allow multi-line expressions. * @param linguistics What to use for tokenizing. */ - private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Embedder embedder) throws ParseException { + private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { SimpleCharStream input = (SimpleCharStream)token_source.input_stream; if (token.next != null) { input.backup(token.next.image.length()); } try { - return IndexingOperation.fromStream(input, multiline, linguistics, embedder); + return IndexingOperation.fromStream(input, multiline, linguistics, embedders); } finally { token.next = null; jj_ntk = -1; diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj index 018531616fb..7253522cfd6 100644 --- a/config-model/src/main/javacc/SDParser.jj +++ b/config-model/src/main/javacc/SDParser.jj @@ -112,7 +112,7 @@ public class SDParser { */ @SuppressWarnings("deprecation") private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException { - return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse); + return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } /** @@ -121,13 +121,13 @@ public class SDParser { * @param multiline Whether or not to allow multi-line expressions. * @param linguistics What to use for tokenizing. */ - private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Embedder embedder) throws ParseException { + private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { SimpleCharStream input = (SimpleCharStream)token_source.input_stream; if (token.next != null) { input.backup(token.next.image.length()); } try { - return IndexingOperation.fromStream(input, multiline, linguistics, embedder); + return IndexingOperation.fromStream(input, multiline, linguistics, embedders); } finally { token.next = null; jj_ntk = -1; diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java index 7b553383daf..87c78445b13 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java @@ -7,6 +7,7 @@ import com.google.inject.Inject; import com.yahoo.component.chain.dependencies.After; import com.yahoo.component.chain.dependencies.Before; import com.yahoo.component.chain.dependencies.Provides; +import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.docproc.DocumentProcessor; import com.yahoo.docproc.Processing; import com.yahoo.document.Document; @@ -15,18 +16,20 @@ import com.yahoo.document.DocumentPut; import com.yahoo.document.DocumentRemove; import com.yahoo.document.DocumentType; import com.yahoo.document.DocumentTypeManager; -import com.yahoo.document.DocumentTypeManagerConfigurer; import com.yahoo.document.DocumentUpdate; -import com.yahoo.document.config.DocumentmanagerConfig; import com.yahoo.language.Linguistics; -import java.util.logging.Level; - import com.yahoo.language.process.Embedder; +import com.yahoo.language.provider.DefaultEmbedderProvider; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.indexinglanguage.AdapterFactory; import com.yahoo.vespa.indexinglanguage.SimpleAdapterFactory; import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import java.util.Map; +import java.util.logging.Level; +import java.util.stream.Collectors; + + /** * @author Simon Thoresen Hult */ @@ -55,9 +58,9 @@ public class IndexingProcessor extends DocumentProcessor { public IndexingProcessor(DocumentTypeManager documentTypeManager, IlscriptsConfig ilscriptsConfig, Linguistics linguistics, - Embedder embedder) { + ComponentRegistry<Embedder> embedders) { docTypeMgr = documentTypeManager; - scriptMgr = new ScriptManager(docTypeMgr, ilscriptsConfig, linguistics, embedder); + scriptMgr = new ScriptManager(docTypeMgr, ilscriptsConfig, linguistics, toMap(embedders)); adapterFactory = new SimpleAdapterFactory(new ExpressionSelector()); } @@ -128,4 +131,14 @@ public class IndexingProcessor extends DocumentProcessor { out.add(prev); } + private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) { + var map = embedders.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + if (map.size() > 1) { + map.remove(DefaultEmbedderProvider.class.getName()); + // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. + } + return map; + } + } diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java index 63c6d6c4bb5..de3a429e357 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java @@ -28,12 +28,12 @@ public class ScriptManager { private final Map<String, Map<String, DocumentScript>> documentFieldScripts; private final DocumentTypeManager docTypeMgr; - public ScriptManager(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, Embedder embedder) { + public ScriptManager(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, + Map<String, Embedder> embedders) { this.docTypeMgr = docTypeMgr; - documentFieldScripts = createScriptsMap(docTypeMgr, config, linguistics, embedder); + documentFieldScripts = createScriptsMap(docTypeMgr, config, linguistics, embedders); } - private Map<String, DocumentScript> getScripts(DocumentType inputType) { Map<String, DocumentScript> scripts = documentFieldScripts.get(inputType.getName()); if (scripts != null) { @@ -75,9 +75,9 @@ public class ScriptManager { private static Map<String, Map<String, DocumentScript>> createScriptsMap(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, - Embedder embedder) { + Map<String, Embedder> embedders) { Map<String, Map<String, DocumentScript>> documentFieldScripts = new HashMap<>(config.ilscript().size()); - ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedder); + ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders); parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences()); parserContext.getAnnotatorConfig().setMaxTokenLength(config.fieldmatchmaxlength()); diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java index 13f9ea1a8c8..76f4578ac87 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.docprocs.indexing; +import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.docproc.Processing; import com.yahoo.document.Document; @@ -128,6 +129,6 @@ public class IndexingProcessorTestCase { return new IndexingProcessor(new DocumentTypeManager(ConfigGetter.getConfig(DocumentmanagerConfig.class, configId)), ConfigGetter.getConfig(IlscriptsConfig.class, configId), new SimpleLinguistics(), - Embedder.throwsOnUse); + new ComponentRegistry<Embedder>()); } } diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java index a35dd0da4f3..4a7e643fb0a 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException; import org.junit.Test; import java.util.Iterator; +import java.util.Map; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; @@ -28,7 +29,7 @@ public class ScriptManagerTestCase { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newssummary") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newsarticle"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -42,7 +43,7 @@ public class ScriptManagerTestCase { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newsarticle") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newssummary"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -50,14 +51,14 @@ public class ScriptManagerTestCase { @Test public void requireThatEmptyConfigurationDoesNotThrow() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @Test public void requireThatUnknownDocumentTypeReturnsNull() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); for (Iterator<DocumentType> it = typeMgr.documentTypeIterator(); it.hasNext(); ) { assertNull(scriptMgr.getScript(it.next())); } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java index 11756ae0907..2b4e0db699b 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java @@ -62,7 +62,7 @@ public final class ScriptParser { parser.setAnnotatorConfig(context.getAnnotatorConfig()); parser.setDefaultFieldName(context.getDefaultFieldName()); parser.setLinguistics(context.getLinguistcs()); - parser.setEmbedder(context.getEmbedder()); + parser.setEmbedders(context.getEmbedders()); try { return method.call(parser); } catch (ParseException e) { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java index 91c24a10e27..9edbed68871 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java @@ -6,6 +6,9 @@ import com.yahoo.language.process.Embedder; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.vespa.indexinglanguage.parser.CharStream; +import java.util.Collections; +import java.util.Map; + /** * @author Simon Thoresen Hult */ @@ -13,13 +16,13 @@ public class ScriptParserContext { private AnnotatorConfig annotatorConfig = new AnnotatorConfig(); private Linguistics linguistics; - private final Embedder embedder; + private final Map<String, Embedder> embedders; private String defaultFieldName = null; private CharStream inputStream = null; - public ScriptParserContext(Linguistics linguistics, Embedder embedder) { + public ScriptParserContext(Linguistics linguistics, Map<String, Embedder> embedders) { this.linguistics = linguistics; - this.embedder = embedder; + this.embedders = embedders; } public AnnotatorConfig getAnnotatorConfig() { @@ -40,8 +43,8 @@ public class ScriptParserContext { return this; } - public Embedder getEmbedder() { - return embedder; + public Map<String, Embedder> getEmbedders() { + return Collections.unmodifiableMap(embedders); } public String getDefaultFieldName() { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 0da9d907718..2e4bb701454 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -12,6 +12,10 @@ import com.yahoo.language.process.Embedder; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + /** * Embeds a string in a tensor space using the configured Embedder component * @@ -20,6 +24,7 @@ import com.yahoo.tensor.TensorType; public class EmbedExpression extends Expression { private final Embedder embedder; + private final String embedderId; /** The destination the embedding will be written to on the form [schema name].[field name] */ private String destination; @@ -27,9 +32,28 @@ public class EmbedExpression extends Expression { /** The target type we are embedding into. */ private TensorType targetType; - public EmbedExpression(Embedder embedder) { + public EmbedExpression(Map<String, Embedder> embedders, String embedderId) { super(DataType.STRING); - this.embedder = embedder; + this.embedderId = embedderId; + + boolean embedderIdProvided = embedderId != null && embedderId.length() > 0; + + if (embedders.size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } + else if (embedders.size() > 1 && ! embedderIdProvided) { + this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " + + "Valid embedders are " + validEmbedders(embedders)); + } + else if (embedders.size() == 1 && ! embedderIdProvided) { + this.embedder = embedders.entrySet().stream().findFirst().get().getValue(); + } + else if ( ! embedders.containsKey(embedderId)) { + this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + } else { + this.embedder = embedders.get(embedderId); + } } @Override @@ -71,7 +95,14 @@ public class EmbedExpression extends Expression { } @Override - public String toString() { return "embed"; } + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("embed"); + if (this.embedderId != null && this.embedderId.length() > 0) { + sb.append(" ").append(this.embedderId); + } + return sb.toString(); + } @Override public int hashCode() { return 1; } @@ -79,4 +110,11 @@ public class EmbedExpression extends Expression { @Override public boolean equals(Object o) { return o instanceof EmbedExpression; } + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); + } + } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java index a5b62c73997..e5bf4711ad1 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java @@ -15,6 +15,8 @@ import com.yahoo.vespa.indexinglanguage.parser.IndexingInput; import com.yahoo.vespa.indexinglanguage.parser.ParseException; import com.yahoo.vespa.objects.Selectable; +import java.util.Map; + /** * @author Simon Thoresen Hult */ @@ -191,11 +193,11 @@ public abstract class Expression extends Selectable { /** Creates an expression with simple lingustics for testing */ public static Expression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static Expression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static Expression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static Expression newInstance(ScriptParserContext context) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java index d8e9cc4d923..c8e45f0f61a 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java @@ -15,6 +15,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; +import java.util.Map; /** * @author Simon Thoresen Hult @@ -92,11 +93,11 @@ public final class ScriptExpression extends ExpressionList<StatementExpression> /** Creates an expression with simple lingustics for testing */ @SuppressWarnings("deprecation") public static ScriptExpression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static ScriptExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static ScriptExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static ScriptExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java index 40aa0f58413..38157531ba2 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java @@ -14,6 +14,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Map; /** * @author Simon Thoresen Hult @@ -99,11 +100,11 @@ public final class StatementExpression extends ExpressionList<Expression> { /** Creates an expression with simple lingustics for testing */ public static StatementExpression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static StatementExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static StatementExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static StatementExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index e6b21f7c07b..51bb9be1f8a 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -45,7 +45,7 @@ public class IndexingParser { private String defaultFieldName; private Linguistics linguistics; - private Embedder embedder; + private Map<String, Embedder> embedders; private AnnotatorConfig annotatorCfg; public IndexingParser(String str) { @@ -62,8 +62,8 @@ public class IndexingParser { return this; } - public IndexingParser setEmbedder(Embedder embedder) { - this.embedder = embedder; + public IndexingParser setEmbedders(Map<String, Embedder> embedders) { + this.embedders = embedders; return this; } @@ -367,10 +367,13 @@ Expression echoExp() : { } { return new EchoExpression(); } } -Expression embedExp() : { } +Expression embedExp() : { - ( <EMBED> ) - { return new EmbedExpression(embedder); } + String val = ""; +} +{ + ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] ) + { return new EmbedExpression(embedders, val); } } Expression exactExp() : { } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java index 87c54fd7abd..28da9a71aac 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java @@ -96,7 +96,7 @@ public class ScriptParserTestCase { } private static ScriptParserContext newContext(String input) { - return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse).setInputStream(new IndexingInput(input)); + return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse.asMap()).setInputStream(new IndexingInput(input)); } } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 27723c6649d..de31f6fcb1e 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -21,6 +21,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException; import org.junit.Test; import java.util.List; +import java.util.Map; import static org.junit.Assert.*; @@ -175,37 +176,66 @@ public class ScriptTestCase { @Test public void testEmbed() throws ParseException { - TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); - var expression = Expression.fromString("input myText | embed | attribute 'myTensor'", - new SimpleLinguistics(), - new MockEmbedder("myDocument.myTensor")); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myText", DataType.STRING)); - var tensorField = new Field("myTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - adapter.setValue("myText", new StringFieldValue("input text")); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + // Test parsing without knowledge of any embedders + String exp = "input myText | embed emb1 | attribute 'myTensor'"; + Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + + Map<String, Embedder> embedder = Map.of( + "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]") + ); + testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "[1,2,0,0]"); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, "[1,2,0,0]"); + testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, "[1,2,0,0]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]"), + "emb2", new MockEmbedder("myDocument.myTensor", "[3,4,5,0]") + ); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "[1,2,0,0]"); + testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, "[3,4,5,0]"); + + assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "[3,4,5,0]"), + "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1,emb2"); + assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "[3,4,5,0]"), + "Can't find embedder 'emb3'. Valid embedders are emb1,emb2"); + } - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(new StringFieldValue("input text")); - expression.execute(context); - assertTrue(adapter.values.containsKey("myTensor")); - assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), - ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); + private void testEmbedStatement(String exp, Map<String, Embedder> embedders, String expected) { + try { + var expression = Expression.fromString(exp, new SimpleLinguistics(), embedders); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var tensorField = new Field("myTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + adapter.setValue("myText", new StringFieldValue("input text")); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(new StringFieldValue("input text")); + expression.execute(context); + assertTrue(adapter.values.containsKey("myTensor")); + assertEquals(Tensor.from(tensorType, expected), + ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); + } catch (ParseException e) { + throw new IllegalArgumentException(e); + } } @SuppressWarnings("unchecked") @Test public void testArrayEmbed() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray", "[7,3,0,0]")); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", new SimpleLinguistics(), - new MockEmbedder("myDocument.myTensorArray")); + embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -235,9 +265,11 @@ public class ScriptTestCase { private static class MockEmbedder implements Embedder { private final String expectedDestination; + private final String tensorString; - public MockEmbedder(String expectedDestination) { + public MockEmbedder(String expectedDestination, String tensorString) { this.expectedDestination = expectedDestination; + this.tensorString = tensorString; } @Override @@ -248,9 +280,18 @@ public class ScriptTestCase { @Override public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedDestination, context.getDestination()); - return Tensor.from(tensorType, "[7,3,0,0]"); + return Tensor.from(tensorType, tensorString); } } + private void assertThrows(Runnable r, String msg) { + try { + r.run(); + fail(); + } catch (IllegalStateException e) { + assertEquals(e.getMessage(), msg); + } + } + } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java index f6aa7e477a8..89170027c73 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java @@ -19,7 +19,7 @@ public class DefaultFieldNameTestCase { public void requireThatDefaultFieldNameIsAppliedWhenArgumentIsMissing() throws ParseException { IndexingInput input = new IndexingInput("input"); InputExpression exp = (InputExpression)Expression.newInstance(new ScriptParserContext(new SimpleLinguistics(), - Embedder.throwsOnUse) + Embedder.throwsOnUse.asMap()) .setInputStream(input) .setDefaultFieldName("foo")); assertEquals("foo", exp.getFieldName()); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java index e333eea7001..7db026d43ee 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java @@ -85,9 +85,9 @@ public class ExpressionTestCase { private static void assertExpression(Class expectedClass, String str) throws ParseException { Linguistics linguistics = new SimpleLinguistics(); - Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse); + Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse.asMap()); assertEquals(expectedClass, foo.getClass()); - Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse); + Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse.asMap()); assertEquals(foo.hashCode(), bar.hashCode()); assertEquals(foo, bar); } |