diff options
author | Lester Solbakken <lesters@oath.com> | 2021-10-06 10:52:56 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-10-06 10:52:56 +0200 |
commit | bcd003d3253a5e51c19149dcc8fa44e8fd526adb (patch) | |
tree | 8c8505ae2a075996a4724da106337262398ad72e /searchlib | |
parent | 4de0026c1065403d028d7157abb571830603e6c9 (diff) |
Add non-primitive tensor expand function
Diffstat (limited to 'searchlib')
3 files changed, 31 insertions, 2 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 2468fd0c5c7..4ebca94734f 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -898,6 +898,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorExpand()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()", @@ -1053,6 +1054,7 @@ "public static final int ARGMAX", "public static final int ARGMIN", "public static final int CELL_CAST", + "public static final int EXPAND", "public static final int AVG", "public static final int COUNT", "public static final int MAX", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 7bfbfd6c005..88eb0feeb73 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -144,6 +144,7 @@ TOKEN : <ARGMAX: "argmax"> | <ARGMIN: "argmin"> | <CELL_CAST: "cell_cast"> | + <EXPAND: "expand"> | <AVG: "avg" > | <COUNT: "count"> | @@ -384,7 +385,8 @@ TensorFunctionNode tensorFunction() : tensorExpression = tensorXwPlusB() | tensorExpression = tensorArgmax() | tensorExpression = tensorArgmin() | - tensorExpression = tensorCellCast() + tensorExpression = tensorCellCast() | + tensorExpression = tensorExpand() ) { return tensorExpression; } } @@ -581,6 +583,16 @@ TensorFunctionNode tensorXwPlusB() : dimension)); } } +TensorFunctionNode tensorExpand() : +{ + ExpressionNode argument; + String dimension; +} +{ + <EXPAND> <LBRACE> argument = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Expand(TensorFunctionNode.wrap(argument), dimension)); } +} + TensorFunctionNode tensorArgmax() : { ExpressionNode tensor; @@ -696,6 +708,7 @@ String tensorFunctionName() : ( <ARGMAX> { return token.image; } ) | ( <ARGMIN> { return token.image; } ) | ( <CELL_CAST> { return token.image; } ) | + ( <EXPAND> { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 246dbcb2b1e..ed8a15ad989 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -498,7 +498,6 @@ public class EvaluationTestCase { "tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", "tensor(d0[4]):[3,2,-1,1]"); - } @Test @@ -725,6 +724,21 @@ public class EvaluationTestCase { tester.assertEvaluates("tensor(d0[1], d1[3]):[1, 2, 3]", "tensor0 * tensor(d0[1])(1)", "tensor(d1[3]):[1, 2, 3]"); + // Add using the "expand" non-primitive function + tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor(d1[3]):[1, 2, 3]"); + tester.assertEvaluates("tensor<float>(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor<float>(d1[3]):[1, 2, 3]"); + } + + @Test + public void test() throws ParseException { + RankingExpression expr = new RankingExpression("expand(tensor<float>(d1[3]):[1,2,3], d0)"); + System.out.println(expr); + Tensor t = expr.evaluate(new MapContext()).asTensor(); + System.out.println(t); } @Test |