summary refs log tree commit diff stats
path: root/indexinglanguage
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2022-03-21 14:16:15 +0100
committer Lester Solbakken <lesters@oath.com> 2022-03-21 14:16:15 +0100
commit c5e464f1a6da3a74113d775805187a547074a2da (patch)
tree dab30afcde250b686d85472f9e2b46d28c9e2184 /indexinglanguage
parent 24555fae4aac0dadde821cac0b7cf85321027bce (diff)
Add embedder selection argument to indexing language
Diffstat (limited to 'indexinglanguage')
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java | 2
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java | 13
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java | 44
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java | 8
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java | 7
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java | 7
-rw-r--r-- indexinglanguage/src/main/javacc/IndexingParser.jj | 15
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java | 2
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java | 89
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java | 2
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java | 4
11 files changed, 141 insertions, 52 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java
index 11756ae0907..2b4e0db699b 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java
@@ -62,7 +62,7 @@ public final class ScriptParser {
parser.setAnnotatorConfig(context.getAnnotatorConfig());
parser.setDefaultFieldName(context.getDefaultFieldName());
parser.setLinguistics(context.getLinguistcs());
- parser.setEmbedder(context.getEmbedder());
+ parser.setEmbedders(context.getEmbedders());
try {
return method.call(parser);
} catch (ParseException e) {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java
index 91c24a10e27..9edbed68871 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java
@@ -6,6 +6,9 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig;
import com.yahoo.vespa.indexinglanguage.parser.CharStream;
+import java.util.Collections;
+import java.util.Map;
+
/**
* @author Simon Thoresen Hult
*/
@@ -13,13 +16,13 @@ public class ScriptParserContext {
private AnnotatorConfig annotatorConfig = new AnnotatorConfig();
private Linguistics linguistics;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private String defaultFieldName = null;
private CharStream inputStream = null;
- public ScriptParserContext(Linguistics linguistics, Embedder embedder) {
+ public ScriptParserContext(Linguistics linguistics, Map<String, Embedder> embedders) {
this.linguistics = linguistics;
- this.embedder = embedder;
+ this.embedders = embedders;
}
public AnnotatorConfig getAnnotatorConfig() {
@@ -40,8 +43,8 @@ public class ScriptParserContext {
return this;
}
- public Embedder getEmbedder() {
- return embedder;
+ public Map<String, Embedder> getEmbedders() {
+ return Collections.unmodifiableMap(embedders);
}
public String getDefaultFieldName() {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 0da9d907718..2e4bb701454 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -12,6 +12,10 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
/**
* Embeds a string in a tensor space using the configured Embedder component
*
@@ -20,6 +24,7 @@ import com.yahoo.tensor.TensorType;
public class EmbedExpression extends Expression {
private final Embedder embedder;
+ private final String embedderId;
/** The destination the embedding will be written to on the form [schema name].[field name] */
private String destination;
@@ -27,9 +32,28 @@ public class EmbedExpression extends Expression {
/** The target type we are embedding into. */
private TensorType targetType;
- public EmbedExpression(Embedder embedder) {
+ public EmbedExpression(Map<String, Embedder> embedders, String embedderId) {
super(DataType.STRING);
- this.embedder = embedder;
+ this.embedderId = embedderId;
+
+ boolean embedderIdProvided = embedderId != null && embedderId.length() > 0;
+
+ if (embedders.size() == 0) {
+ throw new IllegalStateException("No embedders provided"); // should never happen
+ }
+ else if (embedders.size() > 1 && ! embedderIdProvided) {
+ this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " +
+ "Valid embedders are " + validEmbedders(embedders));
+ }
+ else if (embedders.size() == 1 && ! embedderIdProvided) {
+ this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
+ }
+ else if ( ! embedders.containsKey(embedderId)) {
+ this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " +
+ "Valid embedders are " + validEmbedders(embedders));
+ } else {
+ this.embedder = embedders.get(embedderId);
+ }
}
@Override
@@ -71,7 +95,14 @@ public class EmbedExpression extends Expression {
}
@Override
- public String toString() { return "embed"; }
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("embed");
+ if (this.embedderId != null && this.embedderId.length() > 0) {
+ sb.append(" ").append(this.embedderId);
+ }
+ return sb.toString();
+ }
@Override
public int hashCode() { return 1; }
@@ -79,4 +110,11 @@ public class EmbedExpression extends Expression {
@Override
public boolean equals(Object o) { return o instanceof EmbedExpression; }
+ private static String validEmbedders(Map<String, Embedder> embedders) {
+ List<String> embedderIds = new ArrayList<>();
+ embedders.forEach((key, value) -> embedderIds.add(key));
+ embedderIds.sort(null);
+ return String.join(",", embedderIds);
+ }
+
}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java
index a5b62c73997..e5bf4711ad1 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java
@@ -15,6 +15,8 @@ import com.yahoo.vespa.indexinglanguage.parser.IndexingInput;
import com.yahoo.vespa.indexinglanguage.parser.ParseException;
import com.yahoo.vespa.objects.Selectable;
+import java.util.Map;
+
/**
* @author Simon Thoresen Hult
*/
@@ -191,11 +193,11 @@ public abstract class Expression extends Selectable {
/** Creates an expression with simple lingustics for testing */
public static Expression fromString(String expression) throws ParseException {
- return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
- public static Expression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException {
- return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression)));
+ public static Expression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
+ return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression)));
}
public static Expression newInstance(ScriptParserContext context) throws ParseException {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java
index d8e9cc4d923..c8e45f0f61a 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java
@@ -15,6 +15,7 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
+import java.util.Map;
/**
* @author Simon Thoresen Hult
@@ -92,11 +93,11 @@ public final class ScriptExpression extends ExpressionList<StatementExpression>
/** Creates an expression with simple lingustics for testing */
@SuppressWarnings("deprecation")
public static ScriptExpression fromString(String expression) throws ParseException {
- return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
- public static ScriptExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException {
- return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression)));
+ public static ScriptExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
+ return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression)));
}
public static ScriptExpression newInstance(ScriptParserContext config) throws ParseException {
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
index 40aa0f58413..38157531ba2 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java
@@ -14,6 +14,7 @@ import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
+import java.util.Map;
/**
* @author Simon Thoresen Hult
@@ -99,11 +100,11 @@ public final class StatementExpression extends ExpressionList<Expression> {
/** Creates an expression with simple lingustics for testing */
public static StatementExpression fromString(String expression) throws ParseException {
- return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse);
+ return fromString(expression, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
}
- public static StatementExpression fromString(String expression, Linguistics linguistics, Embedder embedder) throws ParseException {
- return newInstance(new ScriptParserContext(linguistics, embedder).setInputStream(new IndexingInput(expression)));
+ public static StatementExpression fromString(String expression, Linguistics linguistics, Map<String, Embedder> embedders) throws ParseException {
+ return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression)));
}
public static StatementExpression newInstance(ScriptParserContext config) throws ParseException {
diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj
index e6b21f7c07b..51bb9be1f8a 100644
--- a/indexinglanguage/src/main/javacc/IndexingParser.jj
+++ b/indexinglanguage/src/main/javacc/IndexingParser.jj
@@ -45,7 +45,7 @@ public class IndexingParser {
private String defaultFieldName;
private Linguistics linguistics;
- private Embedder embedder;
+ private Map<String, Embedder> embedders;
private AnnotatorConfig annotatorCfg;
public IndexingParser(String str) {
@@ -62,8 +62,8 @@ public class IndexingParser {
return this;
}
- public IndexingParser setEmbedder(Embedder embedder) {
- this.embedder = embedder;
+ public IndexingParser setEmbedders(Map<String, Embedder> embedders) {
+ this.embedders = embedders;
return this;
}
@@ -367,10 +367,13 @@ Expression echoExp() : { }
{ return new EchoExpression(); }
}
-Expression embedExp() : { }
+Expression embedExp() :
{
- ( <EMBED> )
- { return new EmbedExpression(embedder); }
+ String val = "";
+}
+{
+ ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] )
+ { return new EmbedExpression(embedders, val); }
}
Expression exactExp() : { }
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java
index 87c54fd7abd..28da9a71aac 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java
@@ -96,7 +96,7 @@ public class ScriptParserTestCase {
}
private static ScriptParserContext newContext(String input) {
- return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse).setInputStream(new IndexingInput(input));
+ return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse.asMap()).setInputStream(new IndexingInput(input));
}
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 27723c6649d..de31f6fcb1e 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -21,6 +21,7 @@ import com.yahoo.vespa.indexinglanguage.parser.ParseException;
import org.junit.Test;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.*;
@@ -175,37 +176,66 @@ public class ScriptTestCase {
@Test
public void testEmbed() throws ParseException {
- TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
- var expression = Expression.fromString("input myText | embed | attribute 'myTensor'",
- new SimpleLinguistics(),
- new MockEmbedder("myDocument.myTensor"));
-
- SimpleTestAdapter adapter = new SimpleTestAdapter();
- adapter.createField(new Field("myText", DataType.STRING));
- var tensorField = new Field("myTensor", new TensorDataType(tensorType));
- adapter.createField(tensorField);
- adapter.setValue("myText", new StringFieldValue("input text"));
- expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
-
- // Necessary to resolve output type
- VerificationContext verificationContext = new VerificationContext(adapter);
- assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass());
+ // Test parsing without knowledge of any embedders
+ String exp = "input myText | embed emb1 | attribute 'myTensor'";
+ Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
+
+ Map<String, Embedder> embedder = Map.of(
+ "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]")
+ );
+ testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder, "[1,2,0,0]");
+ testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedder, "[1,2,0,0]");
+ testEmbedStatement("input myText | embed 'emb1' | attribute 'myTensor'", embedder, "[1,2,0,0]");
+
+ Map<String, Embedder> embedders = Map.of(
+ "emb1", new MockEmbedder("myDocument.myTensor", "[1,2,0,0]"),
+ "emb2", new MockEmbedder("myDocument.myTensor", "[3,4,5,0]")
+ );
+ testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders, "[1,2,0,0]");
+ testEmbedStatement("input myText | embed emb2 | attribute 'myTensor'", embedders, "[3,4,5,0]");
+
+ assertThrows(() -> testEmbedStatement("input myText | embed | attribute 'myTensor'", embedders, "[3,4,5,0]"),
+ "Multiple embedders are provided but no embedder id is given. Valid embedders are emb1,emb2");
+ assertThrows(() -> testEmbedStatement("input myText | embed emb3 | attribute 'myTensor'", embedders, "[3,4,5,0]"),
+ "Can't find embedder 'emb3'. Valid embedders are emb1,emb2");
+ }
- ExecutionContext context = new ExecutionContext(adapter);
- context.setValue(new StringFieldValue("input text"));
- expression.execute(context);
- assertTrue(adapter.values.containsKey("myTensor"));
- assertEquals(Tensor.from(tensorType, "[7,3,0,0]"),
- ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get());
+ private void testEmbedStatement(String exp, Map<String, Embedder> embedders, String expected) {
+ try {
+ var expression = Expression.fromString(exp, new SimpleLinguistics(), embedders);
+ TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myText", DataType.STRING));
+ var tensorField = new Field("myTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+ adapter.setValue("myText", new StringFieldValue("input text"));
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ // Necessary to resolve output type
+ VerificationContext verificationContext = new VerificationContext(adapter);
+ assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass());
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(new StringFieldValue("input text"));
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("myTensor"));
+ assertEquals(Tensor.from(tensorType, expected),
+ ((TensorFieldValue)adapter.values.get("myTensor")).getTensor().get());
+ } catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
}
@SuppressWarnings("unchecked")
@Test
public void testArrayEmbed() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray", "[7,3,0,0]"));
+
TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'",
new SimpleLinguistics(),
- new MockEmbedder("myDocument.myTensorArray"));
+ embedders);
SimpleTestAdapter adapter = new SimpleTestAdapter();
adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
@@ -235,9 +265,11 @@ public class ScriptTestCase {
private static class MockEmbedder implements Embedder {
private final String expectedDestination;
+ private final String tensorString;
- public MockEmbedder(String expectedDestination) {
+ public MockEmbedder(String expectedDestination, String tensorString) {
this.expectedDestination = expectedDestination;
+ this.tensorString = tensorString;
}
@Override
@@ -248,9 +280,18 @@ public class ScriptTestCase {
@Override
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
assertEquals(expectedDestination, context.getDestination());
- return Tensor.from(tensorType, "[7,3,0,0]");
+ return Tensor.from(tensorType, tensorString);
}
}
+ private void assertThrows(Runnable r, String msg) {
+ try {
+ r.run();
+ fail();
+ } catch (IllegalStateException e) {
+ assertEquals(e.getMessage(), msg);
+ }
+ }
+
}
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java
index f6aa7e477a8..89170027c73 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java
@@ -19,7 +19,7 @@ public class DefaultFieldNameTestCase {
public void requireThatDefaultFieldNameIsAppliedWhenArgumentIsMissing() throws ParseException {
IndexingInput input = new IndexingInput("input");
InputExpression exp = (InputExpression)Expression.newInstance(new ScriptParserContext(new SimpleLinguistics(),
- Embedder.throwsOnUse)
+ Embedder.throwsOnUse.asMap())
.setInputStream(input)
.setDefaultFieldName("foo"));
assertEquals("foo", exp.getFieldName());
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java
index e333eea7001..7db026d43ee 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java
@@ -85,9 +85,9 @@ public class ExpressionTestCase {
private static void assertExpression(Class expectedClass, String str) throws ParseException {
Linguistics linguistics = new SimpleLinguistics();
- Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse);
+ Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse.asMap());
assertEquals(expectedClass, foo.getClass());
- Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse);
+ Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse.asMap());
assertEquals(foo.hashCode(), bar.hashCode());
assertEquals(foo, bar);
}