aboutsummaryrefslogtreecommitdiffstats
path: root/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
diff options
context:
space:
mode:
Diffstat (limited to 'indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java')
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java 45
1 files changed, 42 insertions, 3 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 1cfe8532c92..5d5410c2ef0 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -17,7 +17,7 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.Objects;
+
/**
* Embeds a string in a tensor space using the configured Embedder component
@@ -95,7 +95,12 @@ public class EmbedExpression extends Expression {
var input = (Array<StringFieldValue>)context.getValue();
var builder = Tensor.Builder.of(targetType);
if (targetType.rank() == 2)
- embedArrayValueToRank2Tensor(input, builder, context);
+ if (targetType.indexedSubtype().rank() == 1)
+ embedArrayValueToRank2Tensor(input, builder, context);
+ else if(targetType.mappedSubtype().rank() == 2)
+ embedArrayValueToRank2MappedTensor(input, builder, context);
+ else
+ throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported");
else
embedArrayValueToRank3Tensor(input, builder, context);
return builder.build();
@@ -141,6 +146,27 @@ public class EmbedExpression extends Expression {
}
}
+ private void embedArrayValueToRank2MappedTensor(Array<StringFieldValue> input, // embeds each string in input and writes every resulting cell into builder under two labels
+ Tensor.Builder builder, // receives the cells; this method adds cells but does not call build()
+ ExecutionContext context) { // passed through unchanged to embed(...)
+ String outerMappedDimension = embedderArguments.get(0); // the dimension that is labeled with the array element index below
+ String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get(); // first mapped dimension that is not the outer one; get() throws NoSuchElementException if there is none
+
+ var innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).build(); // per-element target: only the inner mapped dimension, with the target's value type
+ int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension); // position of the inner dimension in innerType, used to read its label from each cell key
+
+ for (int i = 0; i < input.size(); i++) { // one embed call per array element
+ Tensor tensor = embed(input.get(i).getString(), innerType, context);
+ for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) { // copy every cell of this element's embedding
+ Tensor.Cell cell = cells.next();
+ builder.cell()
+ .label(outerMappedDimension, i) // outer label is the element's position in the input array
+ .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex)) // inner label is carried over from the embedded cell
+ .value(cell.getValue());
+ }
+ }
+ }
+
private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
return embedder.embed(input,
new Embedder.Context(destination).setLanguage(context.getLanguage()).setEmbedderId(embedderId),
@@ -156,8 +182,19 @@ public class EmbedExpression extends Expression {
"Don't know what tensor type to embed into");
targetType = toTargetTensor(context.getInputType(this, outputField));
if ( ! validTarget(targetType))
- throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," +
+ throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, " +
"an array of dense 1d tensors, or a mixed 2d or 3d tensor");
+ if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) {
+ if (embedderArguments.size() != 1)
+ throw new VerificationException(this, "When the embedding target field is a 2d mapped tensor " +
+ "the name of the tensor dimension that corresponds to the input array elements must " +
+ "be given as a second argument to embed, e.g: ... | embed splade paragraph | ...");
+ if ( ! targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) {
+ throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " +
+ "is not a sparse dimension of the target type " + targetType);
+
+ }
+ }
if (targetType.rank() == 3) {
if (embedderArguments.size() != 1)
throw new VerificationException(this, "When the embedding target field is a 3d tensor " +
@@ -188,6 +225,8 @@ public class EmbedExpression extends Expression {
return true;
if (target.rank() == 2 && target.indexedSubtype().rank() == 1)
return true; // mixed 2d tensor
+ if(target.rank() == 2 && target.mappedSubtype().rank() == 2)
+ return true; // mapped 2d tensor
if (target.rank() == 3 && target.indexedSubtype().rank() == 1)
return true; // mixed 3d tensor
return false;