aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-03-15 23:17:33 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2024-03-15 23:17:33 +0100
commit8791830cb01bd3cc6370472cd7918cb0c7e24e34 (patch)
tree27be788de5987fec10b428e1b61bd3eab3c22a69
parent808d67962afdc58090790915fdebac29963b51e4 (diff)
Attempt of supporting mapping array to mapped 2d tensor for sparse models
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java45
-rw-r--r--indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java67
2 files changed, 109 insertions, 3 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 1cfe8532c92..5d5410c2ef0 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -17,7 +17,7 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
+
/**
* Embeds a string in a tensor space using the configured Embedder component
@@ -95,7 +95,12 @@ public class EmbedExpression extends Expression {
var input = (Array<StringFieldValue>)context.getValue();
var builder = Tensor.Builder.of(targetType);
if (targetType.rank() == 2)
- embedArrayValueToRank2Tensor(input, builder, context);
+ if (targetType.indexedSubtype().rank() == 1)
+ embedArrayValueToRank2Tensor(input, builder, context);
+ else if(targetType.mappedSubtype().rank() == 2)
+ embedArrayValueToRank2MappedTensor(input, builder, context);
+ else
+ throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported");
else
embedArrayValueToRank3Tensor(input, builder, context);
return builder.build();
@@ -141,6 +146,27 @@ public class EmbedExpression extends Expression {
}
}
+ private void embedArrayValueToRank2MappedTensor(Array<StringFieldValue> input,
+ Tensor.Builder builder,
+ ExecutionContext context) {
+ String outerMappedDimension = embedderArguments.get(0);
+ String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get();
+
+ var innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).build();
+ int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension);
+
+ for (int i = 0; i < input.size(); i++) {
+ Tensor tensor = embed(input.get(i).getString(), innerType, context);
+ for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
+ Tensor.Cell cell = cells.next();
+ builder.cell()
+ .label(outerMappedDimension, i)
+ .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex))
+ .value(cell.getValue());
+ }
+ }
+ }
+
private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
return embedder.embed(input,
new Embedder.Context(destination).setLanguage(context.getLanguage()).setEmbedderId(embedderId),
@@ -156,8 +182,19 @@ public class EmbedExpression extends Expression {
"Don't know what tensor type to embed into");
targetType = toTargetTensor(context.getInputType(this, outputField));
if ( ! validTarget(targetType))
- throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," +
+ throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, " +
"an array of dense 1d tensors, or a mixed 2d or 3d tensor");
+ if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) {
+ if (embedderArguments.size() != 1)
+ throw new VerificationException(this, "When the embedding target field is a 2d mapped tensor " +
+ "the name of the tensor dimension that corresponds to the input array elements must " +
+ "be given as a second argument to embed, e.g: ... | embed splade paragraph | ...");
+ if ( ! targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) {
+ throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " +
+ "is not a sparse dimension of the target type " + targetType);
+
+ }
+ }
if (targetType.rank() == 3) {
if (embedderArguments.size() != 1)
throw new VerificationException(this, "When the embedding target field is a 3d tensor " +
@@ -188,6 +225,8 @@ public class EmbedExpression extends Expression {
return true;
if (target.rank() == 2 && target.indexedSubtype().rank() == 1)
return true; // mixed 2d tensor
+ if(target.rank() == 2 && target.mappedSubtype().rank() == 2)
+ return true; // mapped 2d tensor
if (target.rank() == 3 && target.indexedSubtype().rank() == 1)
return true; // mixed 3d tensor
return false;
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 06ff6bc85b1..4e1eae2ed46 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -493,6 +493,73 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ /** Multiple paragraphs with sparse encoding (splade style) */
+ @Test
+ public void testArrayEmbedTo2dMappedTensor_wrongDimensionArgument() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMappedEmbedder("myDocument.my2DSparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})");
+ var expression = Expression.fromString("input myTextArray | embed emb1 doh | attribute 'my2DSparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ adapter.createField(new Field("my2DSparseTensor", new TensorDataType(tensorType)));
+
+ try {
+ expression.verify(new VerificationContext(adapter));
+ fail("Expected exception");
+ }
+ catch (VerificationException e) {
+ assertEquals("The dimension 'doh' given to embed is not a sparse dimension of the target type tensor(passage{},token{})",
+ e.getMessage());
+ }
+ }
+
+ /** Multiple paragraphs with sparse encoding (splade style) */
+ @Test
+ @SuppressWarnings("OptionalGetWithoutIsPresent")
+ public void testArrayEmbedTo2MappedTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMappedEmbedder("myDocument.my2DSparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{})");
+ var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'my2DSparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+ assertEquals("input myTextArray | embed emb1 passage | attribute my2DSparseTensor", expression.toString());
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ var tensorField = new Field("my2DSparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("abc"));
+ array.add(new StringFieldValue("cde"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter)));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("my2DSparseTensor"));
+ var sparse2DTensor = (TensorFieldValue)adapter.values.get("my2DSparseTensor");
+ assertEquals(Tensor.from(
+ tensorType,
+ "tensor(passage{},token{}):" +
+ "{{passage:0,token:97}:97.0, " +
+ "{passage:0,token:98}:98.0, " +
+ "{passage:0,token:99}:99.0, " +
+ "{passage:1,token:100}:100.0, " +
+ "{passage:1,token:101}:101.0, " +
+ "{passage:1,token:99}:99.0}"),
+ sparse2DTensor.getTensor().get());
+ }
+
+
private void assertThrows(Runnable r, String msg) {
try {
r.run();