summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--searchlib/abi-spec.json16
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java10
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java16
4 files changed, 33 insertions, 11 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index eb0ef1bdb08..5889f4844af 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -894,9 +894,9 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
- "public final com.yahoo.tensor.TensorType tensorType()",
+ "public final com.yahoo.tensor.TensorType tensorType(java.util.List)",
"public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()",
- "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)",
+ "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder, java.util.List)",
"public final java.lang.String tensorFunctionName()",
"public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()",
"public final com.yahoo.searchlib.rankingexpression.rule.Function binaryFunctionName()",
@@ -910,11 +910,11 @@
"public final java.util.List tagCommaLeadingList()",
"public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()",
- "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType, java.util.List)",
"public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)",
- "public final com.yahoo.tensor.functions.DynamicTensor mixedTensorValueBody(com.yahoo.tensor.TensorType)",
- "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)",
- "public final void mixedBlock(com.yahoo.tensor.TensorType, java.util.Map)",
+ "public final com.yahoo.tensor.functions.DynamicTensor mixedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)",
+ "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)",
+ "public final void mixedBlock(com.yahoo.tensor.TensorType, java.util.List, java.util.Map)",
"public final java.util.List indexedTensorCells()",
"public final void indexedTensorCellSubspaceList(java.util.List)",
"public final void indexedTensorCellSubspace(java.util.List)",
@@ -1621,8 +1621,8 @@
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
"public static java.util.Map wrapScalars(java.util.Map)",
- "public static void wrapScalarBlock(com.yahoo.tensor.TensorType, java.lang.String, java.util.List, java.util.Map)",
- "public static java.util.List wrapScalars(java.util.List)",
+ "public static void wrapScalarBlock(com.yahoo.tensor.TensorType, java.util.List, java.lang.String, java.util.List, java.util.Map)",
+ "public static java.util.List wrapScalars(com.yahoo.tensor.TensorType, java.util.List, java.util.List)",
"public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
],
"fields": []
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index a248fa6dd45..6200515462b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -103,6 +103,9 @@ public class TensorFunctionNode extends CompositeNode {
List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder);
denseDimensionOrder.retainAll(denseSubtype.dimensionNames());
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder);
+ if (indexes.size() != nodes.size())
+ throw new IllegalArgumentException("At '" + mappedDimensionLabel + "': Need " + indexes.size() +
+ " values to fill a dense subspace of " + type + " but got " + nodes.size());
for (ExpressionNode node : nodes) {
indexes.next();
@@ -125,12 +128,15 @@ public class TensorFunctionNode extends CompositeNode {
List<String> dimensionOrder,
List<ExpressionNode> nodes) {
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder);
- List<ScalarFunction<Reference>> wrapped = new ArrayList<>();
+ if (indexes.size() != nodes.size())
+ throw new IllegalArgumentException("Need " + indexes.size() + " values to fill " + type + " but got " + nodes.size());
+
+ List<ScalarFunction<Reference>> wrapped = new ArrayList<>(nodes.size());
while (indexes.hasNext()) {
indexes.next();
wrapped.add(wrapScalar(nodes.get((int)indexes.toSourceValueIndex())));
}
- return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
+ return wrapped;
}
public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) {
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 22d2abd4aef..1bfb13cff6f 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -880,7 +880,7 @@ DynamicTensor indexedTensorValueBody(TensorType type, List dimensionOrder) :
}
{
cells = indexedTensorCells()
- { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells, type, dimensionOrder)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(type, dimensionOrder, cells)); }
}
void mixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index fa65ce0408b..21127607107 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -14,6 +14,7 @@ import com.yahoo.tensor.TensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
/**
* Tests expression evaluation
@@ -431,6 +432,21 @@ public class EvaluationTestCase {
" key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}",
"tensor(key{},x[3],y[2]):{key1:[[one,a_quarter],[one_half,one_half],[a_quarter,one]]," +
" key2:[[1,4],[2,5],[3,6]]}");
+
+ try {
+ new RankingExpression("tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter]}");
+ fail("Expected exception");
+ }
+ catch (Exception e) {
+ assertEquals("At 'b': Need 2 values to fill a dense subspace of tensor(x{},y[2]) but got 1", e.getMessage());
+ }
+ try {
+ new RankingExpression("tensor(x[2],y[3]):[[1,2,3,4],[4,5,6]]");
+ fail("Expected exception");
+ }
+ catch (Exception e) {
+ assertEquals("Need 6 values to fill tensor(x[2],y[3]) but got 7", e.getMessage());
+ }
}
@Test