diff options
4 files changed, 33 insertions, 11 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index eb0ef1bdb08..5889f4844af 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -894,9 +894,9 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", - "public final com.yahoo.tensor.TensorType tensorType()", + "public final com.yahoo.tensor.TensorType tensorType(java.util.List)", "public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()", - "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)", + "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder, java.util.List)", "public final java.lang.String tensorFunctionName()", "public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()", "public final com.yahoo.searchlib.rankingexpression.rule.Function binaryFunctionName()", @@ -910,11 +910,11 @@ "public final java.util.List tagCommaLeadingList()", "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()", - "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType, java.util.List)", "public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)", - "public final com.yahoo.tensor.functions.DynamicTensor mixedTensorValueBody(com.yahoo.tensor.TensorType)", - "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)", - "public final void mixedBlock(com.yahoo.tensor.TensorType, java.util.Map)", + "public final com.yahoo.tensor.functions.DynamicTensor mixedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)", + "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)", + "public final void mixedBlock(com.yahoo.tensor.TensorType, java.util.List, java.util.Map)", "public final java.util.List indexedTensorCells()", "public final void indexedTensorCellSubspaceList(java.util.List)", "public final void indexedTensorCellSubspace(java.util.List)", @@ -1621,8 +1621,8 @@ "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", "public static java.util.Map wrapScalars(java.util.Map)", - "public static void wrapScalarBlock(com.yahoo.tensor.TensorType, java.lang.String, java.util.List, java.util.Map)", - "public static java.util.List wrapScalars(java.util.List)", + "public static void wrapScalarBlock(com.yahoo.tensor.TensorType, java.util.List, java.lang.String, java.util.List, java.util.Map)", + "public static java.util.List wrapScalars(com.yahoo.tensor.TensorType, java.util.List, java.util.List)", "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" ], "fields": [] diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index a248fa6dd45..6200515462b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -103,6 +103,9 @@ public class TensorFunctionNode extends CompositeNode { List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder); denseDimensionOrder.retainAll(denseSubtype.dimensionNames()); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder); + if (indexes.size() != nodes.size()) + throw new IllegalArgumentException("At '" + mappedDimensionLabel + "': Need " + indexes.size() + + " values to fill a dense subspace of " + type + " but got " + nodes.size()); for (ExpressionNode node : nodes) { indexes.next(); @@ -125,12 +128,15 @@ public class TensorFunctionNode extends CompositeNode { List<String> dimensionOrder, List<ExpressionNode> nodes) { IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder); - List<ScalarFunction<Reference>> wrapped = new ArrayList<>(); + if (indexes.size() != nodes.size()) + throw new IllegalArgumentException("Need " + indexes.size() + " values to fill " + type + " but got " + nodes.size()); + + List<ScalarFunction<Reference>> wrapped = new ArrayList<>(nodes.size()); while (indexes.hasNext()) { indexes.next(); wrapped.add(wrapScalar(nodes.get((int)indexes.toSourceValueIndex()))); } - return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); + return wrapped; } public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) { diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 22d2abd4aef..1bfb13cff6f 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -880,7 +880,7 @@ DynamicTensor indexedTensorValueBody(TensorType type, List dimensionOrder) : } { cells = indexedTensorCells() - { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells, type, dimensionOrder)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(type, dimensionOrder, cells)); } } void mixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index fa65ce0408b..21127607107 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -14,6 +14,7 @@ import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Tests expression evaluation @@ -431,6 +432,21 @@ public class EvaluationTestCase { " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}", "tensor(key{},x[3],y[2]):{key1:[[one,a_quarter],[one_half,one_half],[a_quarter,one]]," + " key2:[[1,4],[2,5],[3,6]]}"); + + try { + new RankingExpression("tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter]}"); + fail("Expected exception"); + } + catch (Exception e) { + assertEquals("At 'b': Need 2 values to fill a dense subspace of tensor(x{},y[2]) but got 1", e.getMessage()); + } + try { + new RankingExpression("tensor(x[2],y[3]):[[1,2,3,4],[4,5,6]]"); + fail("Expected exception"); + } + catch (Exception e) { + assertEquals("Need 6 values to fill tensor(x[2],y[3]) but got 7", e.getMessage()); + } } @Test |