diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-17 08:48:45 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2023-12-17 08:48:45 +0100 |
commit | 745a8db7a8eaea7aa53736a26d64e97543900343 (patch) | |
tree | 3dc56fe5f2b0d0a300cd24470912f98c8842985f /indexinglanguage | |
parent | 56ff2f5e971a26d81cfe5cdbac65d856118820e4 (diff) |
Allow mapped 1d tensor for embed expressions
Diffstat (limited to 'indexinglanguage')
-rw-r--r-- | indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 29399a38fa9..1a9caaa5ca1 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -119,7 +119,7 @@ public class EmbedExpression extends Expression { "Don't know what tensor type to embed into"); targetType = toTargetTensor(context.getInputType(this, outputField)); if ( ! validTarget(targetType)) - throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " + + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," + "an array of dense 1d tensors, or a mixed 2d tensor"); context.setValueType(createdOutputType()); } @@ -134,14 +134,14 @@ public class EmbedExpression extends Expression { if ( ! ( dataType instanceof TensorDataType)) throw new IllegalArgumentException("Expected a tensor data type but got " + dataType); return ((TensorDataType)dataType).getTensorType(); - } private boolean validTarget(TensorType target) { - if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1) - return true; - if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1) + if (target.dimensions().size() == 1) //indexed or mapped 1d tensor return true; + if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 + && target.mappedSubtype().rank() == 1) + return true; //mixed mapped-indexed 2d tensor return false; } |