diff options
author | Lester Solbakken <lesters@oath.com> | 2021-10-06 10:52:56 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-10-06 10:52:56 +0200 |
commit | bcd003d3253a5e51c19149dcc8fa44e8fd526adb (patch) | |
tree | 8c8505ae2a075996a4724da106337262398ad72e /searchlib/src | |
parent | 4de0026c1065403d028d7157abb571830603e6c9 (diff) |
Add non-primitive tensor expand function
Diffstat (limited to 'searchlib/src')
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 15 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java | 16 |
2 files changed, 29 insertions, 2 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 7bfbfd6c005..88eb0feeb73 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -144,6 +144,7 @@ TOKEN : <ARGMAX: "argmax"> | <ARGMIN: "argmin"> | <CELL_CAST: "cell_cast"> | + <EXPAND: "expand"> | <AVG: "avg" > | <COUNT: "count"> | @@ -384,7 +385,8 @@ TensorFunctionNode tensorFunction() : tensorExpression = tensorXwPlusB() | tensorExpression = tensorArgmax() | tensorExpression = tensorArgmin() | - tensorExpression = tensorCellCast() + tensorExpression = tensorCellCast() | + tensorExpression = tensorExpand() ) { return tensorExpression; } } @@ -581,6 +583,16 @@ TensorFunctionNode tensorXwPlusB() : dimension)); } } +TensorFunctionNode tensorExpand() : +{ + ExpressionNode argument; + String dimension; +} +{ + <EXPAND> <LBRACE> argument = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Expand(TensorFunctionNode.wrap(argument), dimension)); } +} + TensorFunctionNode tensorArgmax() : { ExpressionNode tensor; @@ -696,6 +708,7 @@ String tensorFunctionName() : ( <ARGMAX> { return token.image; } ) | ( <ARGMIN> { return token.image; } ) | ( <CELL_CAST> { return token.image; } ) | + ( <EXPAND> { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 246dbcb2b1e..ed8a15ad989 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -498,7 +498,6 @@ public class EvaluationTestCase { "tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", "tensor(d0[4]):[3,2,-1,1]"); - } @Test @@ -725,6 +724,21 @@ public class EvaluationTestCase { tester.assertEvaluates("tensor(d0[1], d1[3]):[1, 2, 3]", "tensor0 * tensor(d0[1])(1)", "tensor(d1[3]):[1, 2, 3]"); + // Add using the "expand" non-primitive function + tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor(d1[3]):[1, 2, 3]"); + tester.assertEvaluates("tensor<float>(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor<float>(d1[3]):[1, 2, 3]"); + } + + @Test + public void test() throws ParseException { + RankingExpression expr = new RankingExpression("expand(tensor<float>(d1[3]):[1,2,3], d0)"); + System.out.println(expr); + Tensor t = expr.evaluate(new MapContext()).asTensor(); + System.out.println(t); } @Test |