From bcd003d3253a5e51c19149dcc8fa44e8fd526adb Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 6 Oct 2021 10:52:56 +0200 Subject: Add non-primitive tensor expand function --- searchlib/src/main/javacc/RankingExpressionParser.jj | 15 ++++++++++++++- .../rankingexpression/evaluation/EvaluationTestCase.java | 16 +++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) (limited to 'searchlib/src') diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 7bfbfd6c005..88eb0feeb73 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -144,6 +144,7 @@ TOKEN : | | | + | | | @@ -384,7 +385,8 @@ TensorFunctionNode tensorFunction() : tensorExpression = tensorXwPlusB() | tensorExpression = tensorArgmax() | tensorExpression = tensorArgmin() | - tensorExpression = tensorCellCast() + tensorExpression = tensorCellCast() | + tensorExpression = tensorExpand() ) { return tensorExpression; } } @@ -581,6 +583,16 @@ TensorFunctionNode tensorXwPlusB() : dimension)); } } +TensorFunctionNode tensorExpand() : +{ + ExpressionNode argument; + String dimension; +} +{ + argument = expression() dimension = identifier() + { return new TensorFunctionNode(new Expand(TensorFunctionNode.wrap(argument), dimension)); } +} + TensorFunctionNode tensorArgmax() : { ExpressionNode tensor; @@ -696,6 +708,7 @@ String tensorFunctionName() : ( { return token.image; } ) | ( { return token.image; } ) | ( { return token.image; } ) | + ( { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 246dbcb2b1e..ed8a15ad989 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -498,7 +498,6 @@ public class EvaluationTestCase { "tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", "tensor(d0[4]):[3,2,-1,1]"); - } @Test @@ -725,6 +724,21 @@ public class EvaluationTestCase { tester.assertEvaluates("tensor(d0[1], d1[3]):[1, 2, 3]", "tensor0 * tensor(d0[1])(1)", "tensor(d1[3]):[1, 2, 3]"); + // Add using the "expand" non-primitive function + tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor(d1[3]):[1, 2, 3]"); + tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor(d1[3]):[1, 2, 3]"); + } + + @Test + public void test() throws ParseException { + RankingExpression expr = new RankingExpression("expand(tensor(d1[3]):[1,2,3], d0)"); + System.out.println(expr); + Tensor t = expr.evaluate(new MapContext()).asTensor(); + System.out.println(t); } @Test -- cgit v1.2.3