summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-03-16 13:39:46 +0100
committerLester Solbakken <lesters@oath.com>2021-03-16 13:39:46 +0100
commitd7456a4c3504ad84afa9f461322bcdcc79e8b357 (patch)
treef162ff2b125d815fec82bc0df645a3cefff5a35d /searchlib
parent73702b1c05deaaf08bcfed78c15494d2e53684a9 (diff)
Revert "Revert "Lesters/cell cast java""
This reverts commit d2c61030d6c62b8c4889d3471d2ee5f17bb14a5f.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/abi-spec.json2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj17
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java17
3 files changed, 34 insertions, 2 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index d412f408350..9e958dd4d4c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -897,6 +897,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()",
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
"public final com.yahoo.tensor.TensorType tensorType(java.util.List)",
@@ -1046,6 +1047,7 @@
"public static final int XW_PLUS_B",
"public static final int ARGMAX",
"public static final int ARGMIN",
+ "public static final int CELL_CAST",
"public static final int AVG",
"public static final int COUNT",
"public static final int MAX",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 36b1f9627bb..d33e9ccff7f 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -141,6 +141,7 @@ TOKEN :
<XW_PLUS_B: "xw_plus_b"> |
<ARGMAX: "argmax"> |
<ARGMIN: "argmin"> |
+ <CELL_CAST: "cell_cast"> |
<AVG: "avg" > |
<COUNT: "count"> |
@@ -380,7 +381,8 @@ TensorFunctionNode tensorFunction() :
tensorExpression = tensorSoftmax() |
tensorExpression = tensorXwPlusB() |
tensorExpression = tensorArgmax() |
- tensorExpression = tensorArgmin()
+ tensorExpression = tensorArgmin() |
+ tensorExpression = tensorCellCast()
)
{ return tensorExpression; }
}
@@ -597,6 +599,16 @@ TensorFunctionNode tensorArgmin() :
{ return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimensions)); }
}
+TensorFunctionNode tensorCellCast() :
+{
+ ExpressionNode tensor;
+ String valueType;
+}
+{
+ <CELL_CAST> <LBRACE> tensor = expression() <COMMA> valueType = identifier() <RBRACE>
+ { return new TensorFunctionNode(new CellCast(TensorFunctionNode.wrap(tensor), TensorType.Value.fromId(valueType)));}
+}
+
LambdaFunctionNode lambdaFunction() :
{
List<String> variables;
@@ -667,7 +679,7 @@ String tensorFunctionName() :
( <MAP> { return token.image; } ) |
( <REDUCE> { return token.image; } ) |
( <JOIN> { return token.image; } ) |
- ( <MERGE> { return token.image; } ) |
+ ( <MERGE> { return token.image; } ) |
( <RENAME> { return token.image; } ) |
( <CONCAT> { return token.image; } ) |
( <TENSOR> { return token.image; } ) |
@@ -681,6 +693,7 @@ String tensorFunctionName() :
( <XW_PLUS_B> { return token.image; } ) |
( <ARGMAX> { return token.image; } ) |
( <ARGMIN> { return token.image; } ) |
+ ( <CELL_CAST> { return token.image; } )
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 123fa5ac43b..fae5a7a093c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -394,6 +394,23 @@ public class EvaluationTestCase {
}
@Test
+ public void testCellTypeCasting() {
+ EvaluationTester tester = new EvaluationTester();
+
+ tester.assertEvaluates("tensor<float>(x[3]):[1.0, 2.0, 3.0]",
+ "cell_cast(tensor0, float)",
+ "tensor<double>(x[3]):[1, 2, 3]");
+ tester.assertEvaluates("tensor<float>():{1}",
+ "cell_cast(tensor0{x:1}, float)",
+ "tensor<double>(x{}):{1:1, 2:2, 3:3}");
+ tester.assertEvaluates("tensor<float>(x[2]):[3,8]",
+ "cell_cast(tensor0 * tensor1, float)",
+ "tensor<float>(x[2]):[1,2]",
+ "tensor<double>(x[2]):[3,4]");
+ }
+
+
+ @Test
public void testMixedTensorType() throws ParseException {
String expected = "tensor(x[1],y{},z[2]):{{x:0,y:a,z:0}:4.0,{x:0,y:a,z:1}:5.0,{x:0,y:b,z:0}:7.0,{x:0,y:b,z:1}:8.0}";
String a = "tensor(x[1],y{}):{ {x:0,y:a}:1, {x:0,y:b}:2 }";