summary | refs | log | tree | commit | diff | stats
path: root/indexinglanguage
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@vespa.ai> 2024-02-02 12:28:53 +0100
committer: Jon Bratseth <bratseth@vespa.ai> 2024-02-02 12:28:53 +0100
commit: 1a25431ab58c752c7fc26dd8223bf1ba1079b24a (patch)
tree: 954d7e2f3e43bb0636a6af7a93195a84e41e609b /indexinglanguage
parent: 2191193c6e107eb68611ddb106e5f572bea32903 (diff)
Support embedding into rank 3 tensors
Diffstat (limited to 'indexinglanguage')
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java | 73
-rw-r--r-- indexinglanguage/src/main/javacc/IndexingParser.jj | 12
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java | 250
3 files changed, 270 insertions, 65 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 7c5e8912e49..5daf74a9723 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -27,6 +27,7 @@ public class EmbedExpression extends Expression {
private final Embedder embedder;
private final String embedderId;
+ private final List<String> embedderArguments;
/** The destination the embedding will be written to on the form [schema name].[field name] */
private String destination;
@@ -34,22 +35,23 @@ public class EmbedExpression extends Expression {
/** The target type we are embedding into. */
private TensorType targetType;
- public EmbedExpression(Map<String, Embedder> embedders, String embedderId) {
+ public EmbedExpression(Map<String, Embedder> embedders, String embedderId, List<String> embedderArguments) {
super(null);
this.embedderId = embedderId;
+ this.embedderArguments = List.copyOf(embedderArguments);
- boolean embedderIdProvided = embedderId != null && embedderId.length() > 0;
+ boolean embedderIdProvided = embedderId != null && !embedderId.isEmpty();
if (embedders.size() == 0) {
throw new IllegalStateException("No embedders provided"); // should never happen
}
+ else if (embedders.size() == 1 && ! embedderIdProvided) {
+ this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
+ }
else if (embedders.size() > 1 && ! embedderIdProvided) {
this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " +
"Valid embedders are " + validEmbedders(embedders));
}
- else if (embedders.size() == 1 && ! embedderIdProvided) {
- this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
- }
else if ( ! embedders.containsKey(embedderId)) {
this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " +
"Valid embedders are " + validEmbedders(embedders));
@@ -91,17 +93,51 @@ public class EmbedExpression extends Expression {
private Tensor embedArrayValue(ExecutionContext context) {
var input = (Array<StringFieldValue>)context.getValue();
var builder = Tensor.Builder.of(targetType);
+ if (targetType.rank() == 2)
+ embedArrayValueToRank2Tensor(input, builder, context);
+ else
+ embedArrayValueToRank3Tensor(input, builder, context);
+ return builder.build();
+ }
+
+ private void embedArrayValueToRank2Tensor(Array<StringFieldValue> input,
+ Tensor.Builder builder,
+ ExecutionContext context) {
+ String mappedDimension = targetType.mappedSubtype().dimensions().get(0).name();
+ String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name();
for (int i = 0; i < input.size(); i++) {
Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context);
for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
Tensor.Cell cell = cells.next();
builder.cell()
- .label(targetType.mappedSubtype().dimensions().get(0).name(), i)
- .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().numericLabel(0))
+ .label(mappedDimension, i)
+ .label(indexedDimension, cell.getKey().numericLabel(0))
+ .value(cell.getValue());
+ }
+ }
+ }
+
+ private void embedArrayValueToRank3Tensor(Array<StringFieldValue> input,
+ Tensor.Builder builder,
+ ExecutionContext context) {
+ String outerMappedDimension = embedderArguments.get(0);
+ String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get();
+ String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name();
+ long indexedDimensionSize = targetType.indexedSubtype().dimensions().get(0).size().get();
+ var innerType = new TensorType.Builder().mapped(innerMappedDimension).indexed(indexedDimension,indexedDimensionSize).build();
+ int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension);
+ int indexedDimensionIndex = innerType.indexOfDimensionAsInt(indexedDimension);
+ for (int i = 0; i < input.size(); i++) {
+ Tensor tensor = embed(input.get(i).getString(), innerType, context);
+ for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
+ Tensor.Cell cell = cells.next();
+ builder.cell()
+ .label(outerMappedDimension, i)
+ .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex))
+ .label(indexedDimension, cell.getKey().numericLabel(indexedDimensionIndex))
.value(cell.getValue());
}
}
- return builder.build();
}
private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
@@ -120,7 +156,17 @@ public class EmbedExpression extends Expression {
targetType = toTargetTensor(context.getInputType(this, outputField));
if ( ! validTarget(targetType))
throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," +
- "an array of dense 1d tensors, or a mixed 2d tensor");
+ "an array of dense 1d tensors, or a mixed 2d or 3d tensor");
+ if (targetType.rank() == 3) {
+ if (embedderArguments.size() != 1)
+ throw new VerificationException(this, "When the embedding target field is a 3d tensor " +
+ "the name of the tensor dimension that corresponds to the input array elements must " +
+ "be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...");
+ if ( ! targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0)))
+ throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " +
+ "is not a sparse dimension of the target type " + targetType);
+ }
+
context.setValueType(createdOutputType());
}
@@ -137,11 +183,12 @@ public class EmbedExpression extends Expression {
}
private boolean validTarget(TensorType target) {
- if (target.dimensions().size() == 1) //indexed or mapped 1d tensor
+ if (target.rank() == 1) // indexed or mapped 1d tensor
return true;
- if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1
- && target.mappedSubtype().rank() == 1)
- return true; //mixed mapped-indexed 2d tensor
+ if (target.rank() == 2 && target.indexedSubtype().rank() == 1)
+ return true; // mixed 2d tensor
+ if (target.rank() == 3 && target.indexedSubtype().rank() == 1)
+ return true; // mixed 3d tensor
return false;
}
diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj
index 42bbd26cee6..a3b4039408a 100644
--- a/indexinglanguage/src/main/javacc/IndexingParser.jj
+++ b/indexinglanguage/src/main/javacc/IndexingParser.jj
@@ -37,7 +37,6 @@ import com.yahoo.language.Linguistics;
/**
* @author Simon Thoresen Hult
- * @version $Id$
*/
public class IndexingParser {
@@ -386,11 +385,16 @@ Expression echoExp() : { }
Expression embedExp() :
{
- String val = "";
+ String embedderId = "";
+ String embedderArgument;
+ List<String> embedderArguments = new ArrayList<String>();
}
{
- ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] )
- { return new EmbedExpression(embedders, val); }
+ (
+ <EMBED> [ LOOKAHEAD(2) embedderId = identifier() ]
+ ( LOOKAHEAD(2) embedderArgument = identifier() { embedderArguments.add(embedderArgument); } )*
+ )
+ { return new EmbedExpression(embedders, embedderId, embedderArguments); }
}
Expression exactExp() : { }
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 6206c2efe7a..7fe55b738df 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -181,7 +181,7 @@ public class ScriptTestCase {
Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
Map<String, Embedder> embedder = Map.of(
- "emb1", new MockEmbedder("myDocument.myTensor")
+ "emb1", new MockIndexedEmbedder("myDocument.myTensor")
);
testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder,
"input text", "[105, 110, 112, 117]");
@@ -193,8 +193,8 @@ public class ScriptTestCase {
null, null);
Map<String, Embedder> embedders = Map.of(
- "emb1", new MockEmbedder("myDocument.myTensor"),
- "emb2", new MockEmbedder("myDocument.myTensor", 1)
+ "emb1", new MockIndexedEmbedder("myDocument.myTensor"),
+ "emb2", new MockIndexedEmbedder("myDocument.myTensor", 1)
);
testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders,
"my input", "[109.0, 121.0, 32.0, 105.0]");
@@ -243,7 +243,7 @@ public class ScriptTestCase {
@SuppressWarnings("unchecked")
@Test
public void testArrayEmbed() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray"));
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.myTensorArray"));
TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'",
@@ -277,7 +277,7 @@ public class ScriptTestCase {
@Test
public void testArrayEmbedWithConcatenation() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'",
@@ -314,9 +314,10 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ /** Multiple paragraphs */
@Test
- public void testArrayEmbedToMixedTensor() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+ public void testArrayEmbedTo2dMixedTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'",
@@ -348,17 +349,125 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("sec"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter)));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ // The two "passages" are [first, sec]; each letter becomes one token, and the middle (d=1) cell of each token's vector holds that letter's char value
+ assertEquals(Tensor.from(tensorType,
+ """
+ {
+ {passage:0, token:0, d:0}: 101,
+ {passage:0, token:0, d:1}: 102,
+ {passage:0, token:0, d:2}: 103,
+ {passage:0, token:1, d:0}: 104,
+ {passage:0, token:1, d:1}: 105,
+ {passage:0, token:1, d:2}: 106,
+ {passage:0, token:2, d:0}: 113,
+ {passage:0, token:2, d:1}: 114,
+ {passage:0, token:2, d:2}: 115,
+ {passage:0, token:3, d:0}: 114,
+ {passage:0, token:3, d:1}: 115,
+ {passage:0, token:3, d:2}: 116,
+ {passage:0, token:4, d:0}: 115,
+ {passage:0, token:4, d:1}: 116,
+ {passage:0, token:4, d:2}: 117,
+ {passage:1, token:0, d:0}: 114,
+ {passage:1, token:0, d:1}: 115,
+ {passage:1, token:0, d:2}: 116,
+ {passage:1, token:1, d:0}: 100,
+ {passage:1, token:1, d:1}: 101,
+ {passage:1, token:1, d:2}: 102,
+ {passage:1, token:2, d:0}: 98,
+ {passage:1, token:2, d:1}: 99,
+ {passage:1, token:2, d:2}: 100
+ }
+ """),
+ sparseTensor.getTensor().get());
+ }
+
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style): verification fails when the dimension argument to embed is missing */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType)));
+
+ try {
+ expression.verify(new VerificationContext(adapter));
+ fail("Expected exception");
+ }
+ catch (VerificationException e) {
+ assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...",
+ e.getMessage());
+ }
+ }
+
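+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style): verification fails when the dimension argument to embed is not a mapped dimension of the target type */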
+ @Test
+ public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType)));
+
+ try {
+ expression.verify(new VerificationContext(adapter));
+ fail("Expected exception");
+ }
+ catch (VerificationException e) {
+ assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})",
+ e.getMessage());
+ }
+ }
+
@SuppressWarnings("OptionalGetWithoutIsPresent")
@Test
public void testEmbedToSparseTensor() throws ParseException {
-
- Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true);
+ Embedder mappedEmbedder = new MockMappedEmbedder("myDocument.mySparseTensor", 0);
Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder);
TensorType tensorType = TensorType.fromSpec("tensor(t{})");
var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'",
- new SimpleLinguistics(),
- embedders);
+ new SimpleLinguistics(),
+ embedders);
SimpleTestAdapter adapter = new SimpleTestAdapter();
adapter.createField(new Field("text", DataType.STRING));
@@ -383,30 +492,23 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
- // An embedder which returns the char value of each letter in the input. */
- private static class MockEmbedder implements Embedder {
-
- private final String expectedDestination;
- private final int addition;
-
- private final boolean mappedTensor;
-
-
- public MockEmbedder(String expectedDestination) {
- this(expectedDestination, 0, false);
- }
- public MockEmbedder(String expectedDestination, boolean mapped) {
- this(expectedDestination, 0,mapped);
+ private void assertThrows(Runnable r, String msg) {
+ try {
+ r.run();
+ fail();
+ } catch (IllegalStateException e) {
+ assertEquals(e.getMessage(), msg);
}
+ }
- public MockEmbedder(String expectedDestination,int addition) {
- this(expectedDestination, addition,false);
- }
+ private static abstract class MockEmbedder implements Embedder {
- public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) {
+ final String expectedDestination;
+ final int addition;
+
+ public MockEmbedder(String expectedDestination, int addition) {
this.expectedDestination = expectedDestination;
this.addition = addition;
- this.mappedTensor = mappedTensor;
}
@Override
@@ -414,32 +516,84 @@ public class ScriptTestCase {
return null;
}
+ void verifyDestination(Embedder.Context context) {
+ assertEquals(expectedDestination, context.getDestination());
+ }
+
+ }
+
+ /** An embedder which returns the char value of each letter in the input as a 1d indexed tensor. */
+ private static class MockIndexedEmbedder extends MockEmbedder {
+
+ public MockIndexedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockIndexedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
@Override
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
- assertEquals(expectedDestination, context.getDestination());
+ verifyDestination(context);
var b = Tensor.Builder.of(tensorType);
- if (mappedTensor) {
- for(int i = 0; i < text.length(); i++) {
- var value = text.charAt(i) + addition;
- b.cell().
- label(tensorType.dimensions().get(0).name(), text.charAt(i))
- .value(value);
- }
- } else {
- for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
- b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
+ b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ return b.build();
+ }
- }
+ }
+
+ /** An embedder which returns the char value of each letter in the input as a 1d mapped tensor. */
+ private static class MockMappedEmbedder extends MockEmbedder {
+
+ public MockMappedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockMappedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ verifyDestination(context);
+ var b = Tensor.Builder.of(tensorType);
+ for (int i = 0; i < text.length(); i++)
+ b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition);
return b.build();
}
+
}
- private void assertThrows(Runnable r, String msg) {
- try {
- r.run();
- fail();
- } catch (IllegalStateException e) {
- assertEquals(e.getMessage(), msg);
+ /**
+ * An embedder which returns the char value of each letter in the input as a 2d mixed tensor where each input
+ * char becomes an indexed dimension containing input-1, input, input+1.
+ */
+ private static class MockMixedEmbedder extends MockEmbedder {
+
+ public MockMixedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockMixedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ verifyDestination(context);
+ var b = Tensor.Builder.of(tensorType);
+ String mappedDimension = tensorType.mappedSubtype().dimensions().get(0).name();
+ String indexedDimension = tensorType.indexedSubtype().dimensions().get(0).name();
+ for (int i = 0; i < text.length(); i++) {
+ for (int j = 0; j < 3; j++) {
+ b.cell().label(mappedDimension, i)
+ .label(indexedDimension, j)
+ .value(text.charAt(i) + addition + j - 1);
+ }
+ }
+ return b.build();
}
}