diff options
author | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:15 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-03-21 14:16:15 +0100 |
commit | c5e464f1a6da3a74113d775805187a547074a2da (patch) | |
tree | dab30afcde250b686d85472f9e2b46d28c9e2184 /indexinglanguage | |
parent | 24555fae4aac0dadde821cac0b7cf85321027bce (diff) |
Add embedder selection argument to indexing language
Diffstat (limited to 'indexinglanguage')
11 files changed, 141 insertions, 52 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java index 11756ae0907..2b4e0db699b 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java @@ -62,7 +62,7 @@ public final class ScriptParser { parser.setAnnotatorConfig(context.getAnnotatorConfig()); parser.setDefaultFieldName(context.getDefaultFieldName()); parser.setLinguistics(context.getLinguistcs()); - parser.setEmbedder(context.getEmbedder()); + parser.setEmbedders(context.getEmbedders()); try { return method.call(parser); } catch (ParseException e) { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java index 91c24a10e27..9edbed68871 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java @@ -6,6 +6,9 @@ import com.yahoo.language.process.Embedder; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.vespa.indexinglanguage.parser.CharStream; +import java.util.Collections; +import java.util.Map; + /** * @author Simon Thoresen Hult */ @@ -13,13 +16,13 @@ public class ScriptParserContext { private AnnotatorConfig annotatorConfig = new AnnotatorConfig(); private Linguistics linguistics; - private final Embedder embedder; + private final Map<String, Embedder> embedders; private String defaultFieldName = null; private CharStream inputStream = null; - public ScriptParserContext(Linguistics linguistics, Embedder embedder) { + public ScriptParserContext(Linguistics linguistics, Map<String, Embedder> embedders) { this.linguistics = linguistics; - this.embedder = embedder; + this.embedders = embedders; } public AnnotatorConfig getAnnotatorConfig() { @@ -40,8 +43,8 @@ public class ScriptParserContext { return this; } - public Embedder getEmbedder() { - return embedder; + public Map<String, Embedder> getEmbedders() { + return Collections.unmodifiableMap(embedders); } public String getDefaultFieldName() { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 0da9d907718..2e4bb701454 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -12,6 +12,10 @@ import com.yahoo.language.process.Embedder; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + /** * Embeds a string in a tensor space using the configured Embedder component * @@ -20,6 +24,7 @@ import com.yahoo.tensor.TensorType; public class EmbedExpression extends Expression { private final Embedder embedder; + private final String embedderId; /** The destination the embedding will be written to on the form [schema name].[field name] */ private String destination; @@ -27,9 +32,28 @@ public class EmbedExpression extends Expression { /** The target type we are embedding into. */ private TensorType targetType; - public EmbedExpression(Embedder embedder) { + public EmbedExpression(Map<String, Embedder> embedders, String embedderId) { super(DataType.STRING); - this.embedder = embedder; + this.embedderId = embedderId; + + boolean embedderIdProvided = embedderId != null && embedderId.length() > 0; + + if (embedders.size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } + else if (embedders.size() > 1 && ! embedderIdProvided) { + this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " + + "Valid embedders are " + validEmbedders(embedders)); + } + else if (embedders.size() == 1 && ! embedderIdProvided) { + this.embedder = embedders.entrySet().stream().findFirst().get().getValue(); + } + else if ( ! embedders.containsKey(embedderId)) { + this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + } else { + this.embedder = embedders.get(embedderId); + } } @Override @@ -71,7 +95,14 @@ public class EmbedExpression extends Expression { } @Override - public String toString() { return "embed"; } + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("embed"); + if (this.embedderId != null && this.embedderId.length() > 0) { + sb.append(" ").append(this.embedderId); + } + return sb.toString(); + } @Override public int hashCode() { return 1; } @@ -79,4 +110,11 @@ public class EmbedExpression extends Expression { @Override public boolean equals(Object o) { return o instanceof EmbedExpression; } + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); + } + } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java index a5b62c73997..e5bf4711ad1 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java @@ -15,6 +15,8 @@ import com.yahoo.vespa.indexinglanguage.parser.IndexingInput; import com.yahoo.vespa.indexinglanguage.parser.ParseException; import com.yahoo.vespa.objects.Selectable; +import java.util.Map; + /** * @author Simon Thoresen Hult */ @@ -191,11 +193,11 @@ public abstract class Expression extends Selectable { /** Creates an expression with simple lingustics for testing */ public static Expression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static Expression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static Expression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static Expression newInstance(ScriptParserContext context) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java index d8e9cc4d923..c8e45f0f61a 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java @@ -15,6 +15,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Iterator; +import java.util.Map; /** * @author Simon Thoresen Hult @@ -92,11 +93,11 @@ public final class ScriptExpression extends ExpressionList<StatementExpression> /** Creates an expression with simple lingustics for testing */ @SuppressWarnings("deprecation") public static ScriptExpression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static ScriptExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static ScriptExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static ScriptExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java index 40aa0f58413..38157531ba2 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java @@ -14,6 +14,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Map; /** * @author Simon Thoresen Hult @@ -99,11 +100,11 @@ public final class StatementExpression extends ExpressionList<Expression> { /** Creates an expression with simple lingustics for testing */ public static StatementExpression fromString(String expression) throws ParseException { - return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse); + return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); } - public static StatementExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression))); + public static StatementExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); } public static StatementExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index e6b21f7c07b..51bb9be1f8a 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -45,7 +45,7 @@ public class IndexingParser { private String defaultFieldName; private Linguistics linguistics; - private Embedder embedder; + private Map<String, Embedder> embedders; private AnnotatorConfig annotatorCfg; public IndexingParser(String str) { @@ -62,8 +62,8 @@ public class IndexingParser { return this; } - public IndexingParser setEmbedder(Embedder embedder) { - this.embedder = embedder; + public IndexingParser setEmbedders(Map<String, Embedder> embedders) { + this.embedders = embedders; return this; } @@ -367,10 +367,13 @@ Expression echoExp() : { } { return new EchoExpression(); } } -Expression embedExp() : { } +Expression embedExp() : { - ( <EMBED> ) - { return new EmbedExpression(embedder); } + String val = ""; +} +{ + ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] ) + { return new EmbedExpression(embedders, val); } } Expression exactExp() : { } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java index 87c54fd7abd..28da9a71aac 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java @@ -96,7 +96,7 @@ public class ScriptParserTestCase { } private static ScriptParserContext newContext(String input) { - return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse).setInputStream(new IndexingInput(input)); + return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse.asMap()).setInputStream(new IndexingInput(input)); } } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 27723c6649d..de31f6fcb1e 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -21,6 +21,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException; import org.junit.Test; import java.util.List; +import java.util.Map; import static org.junit.Assert.*; @@ -175,37 +176,66 @@ public class ScriptTestCase { @Test public void testEmbed() throws ParseException { - TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); - var expression = Expression.fromString("input myText | embed | attribute 'myTensor'", - new SimpleLinguistics(), - new MockEmbedder("myDocument.myTensor")); - - SimpleTestAdapter adapter = new SimpleTestAdapter(); - adapter.createField(new Field("myText", DataType.STRING)); - var tensorField = new Field("myTensor", new TensorDataType(tensorType)); - adapter.createField(tensorField); - adapter.setValue("myText", new StringFieldValue("input text")); - expression.setStatementOutput(new DocumentType("myDocument"), tensorField); - - // Necessary to resolve output type - VerificationContext verificationContext = new VerificationContext(adapter); - assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + // Test parsing without knowledge of any embedders + String exp = "input myText | embed emb1 | attribute 'myTensor'"; + Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + + Map<String, Embedder> embedder = Map.of( + "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]") + ); + testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "[1,2,0,0]"); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, "[1,2,0,0]"); + testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, "[1,2,0,0]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]"), + "emb2", new MockEmbedder("myDocument.myTensor", "[3,4,5,0]") + ); + testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "[1,2,0,0]"); + testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, "[3,4,5,0]"); + + assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "[3,4,5,0]"), + "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1,emb2"); + assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "[3,4,5,0]"), + "Can't find embedder 'emb3'. Valid embedders are emb1,emb2"); + } - ExecutionContext context = new ExecutionContext(adapter); - context.setValue(new StringFieldValue("input text")); - expression.execute(context); - assertTrue(adapter.values.containsKey("myTensor")); - assertEquals(Tensor.from(tensorType, "[7,3,0,0]"), - ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); + private void testEmbedStatement(String exp, Map<String, Embedder> embedders, String expected) { + try { + var expression = Expression.fromString(exp, new SimpleLinguistics(), embedders); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var tensorField = new Field("myTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + adapter.setValue("myText", new StringFieldValue("input text")); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(new StringFieldValue("input text")); + expression.execute(context); + assertTrue(adapter.values.containsKey("myTensor")); + assertEquals(Tensor.from(tensorType, expected), + ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get()); + } catch (ParseException e) { + throw new IllegalArgumentException(e); + } } @SuppressWarnings("unchecked") @Test public void testArrayEmbed() throws ParseException { + Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray", "[7,3,0,0]")); + TensorType tensorType = TensorType.fromSpec("tensor(d[4])"); var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'", new SimpleLinguistics(), - new MockEmbedder("myDocument.myTensorArray")); + embedders); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING))); @@ -235,9 +265,11 @@ public class ScriptTestCase { private static class MockEmbedder implements Embedder { private final String expectedDestination; + private final String tensorString; - public MockEmbedder(String expectedDestination) { + public MockEmbedder(String expectedDestination, String tensorString) { this.expectedDestination = expectedDestination; + this.tensorString = tensorString; } @Override @@ -248,9 +280,18 @@ public class ScriptTestCase { @Override public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedDestination, context.getDestination()); - return Tensor.from(tensorType, "[7,3,0,0]"); + return Tensor.from(tensorType, tensorString); } } + private void assertThrows(Runnable r, String msg) { + try { + r.run(); + fail(); + } catch (IllegalStateException e) { + assertEquals(e.getMessage(), msg); + } + } + } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java index f6aa7e477a8..89170027c73 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java @@ -19,7 +19,7 @@ public class DefaultFieldNameTestCase { public void requireThatDefaultFieldNameIsAppliedWhenArgumentIsMissing() throws ParseException { IndexingInput input = new IndexingInput("input"); InputExpression exp = (InputExpression)Expression.newInstance(new ScriptParserContext(new SimpleLinguistics(), - Embedder.throwsOnUse) + Embedder.throwsOnUse.asMap()) .setInputStream(input) .setDefaultFieldName("foo")); assertEquals("foo", exp.getFieldName()); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java index e333eea7001..7db026d43ee 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java @@ -85,9 +85,9 @@ public class ExpressionTestCase { private static void assertExpression(Class expectedClass, String str) throws ParseException { Linguistics linguistics = new SimpleLinguistics(); - Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse); + Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse.asMap()); assertEquals(expectedClass, foo.getClass()); - Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse); + Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse.asMap()); assertEquals(foo.hashCode(), bar.hashCode()); assertEquals(foo, bar); } |