summaryrefslogtreecommitdiffstats
path: root/indexinglanguage
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2023-12-17 08:48:45 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2023-12-17 08:48:45 +0100
commit745a8db7a8eaea7aa53736a26d64e97543900343 (patch)
tree3dc56fe5f2b0d0a300cd24470912f98c8842985f /indexinglanguage
parent56ff2f5e971a26d81cfe5cdbac65d856118820e4 (diff)
Allow mapped 1d tensor for embed expressions
Diffstat (limited to 'indexinglanguage')
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java10
1 files changed, 5 insertions, 5 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 29399a38fa9..1a9caaa5ca1 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -119,7 +119,7 @@ public class EmbedExpression extends Expression {
"Don't know what tensor type to embed into");
targetType = toTargetTensor(context.getInputType(this, outputField));
if ( ! validTarget(targetType))
- throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " +
+ throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," +
"an array of dense 1d tensors, or a mixed 2d tensor");
context.setValueType(createdOutputType());
}
@@ -134,14 +134,14 @@ public class EmbedExpression extends Expression {
if ( ! ( dataType instanceof TensorDataType))
throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
return ((TensorDataType)dataType).getTensorType();
-
}
private boolean validTarget(TensorType target) {
- if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1)
- return true;
- if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1)
+ if (target.dimensions().size() == 1) //indexed or mapped 1d tensor
return true;
+ if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1
+ && target.mappedSubtype().rank() == 1)
+ return true; //mixed mapped-indexed 2d tensor
return false;
}