summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-10-06 10:52:56 +0200
committerLester Solbakken <lesters@oath.com>2021-10-06 10:52:56 +0200
commitbcd003d3253a5e51c19149dcc8fa44e8fd526adb (patch)
tree8c8505ae2a075996a4724da106337262398ad72e /searchlib
parent4de0026c1065403d028d7157abb571830603e6c9 (diff)
Add non-primitive tensor expand function
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/abi-spec.json2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj15
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java16
3 files changed, 31 insertions, 2 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 2468fd0c5c7..4ebca94734f 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -898,6 +898,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorExpand()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()",
@@ -1053,6 +1054,7 @@
"public static final int ARGMAX",
"public static final int ARGMIN",
"public static final int CELL_CAST",
+ "public static final int EXPAND",
"public static final int AVG",
"public static final int COUNT",
"public static final int MAX",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 7bfbfd6c005..88eb0feeb73 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -144,6 +144,7 @@ TOKEN :
<ARGMAX: "argmax"> |
<ARGMIN: "argmin"> |
<CELL_CAST: "cell_cast"> |
+ <EXPAND: "expand"> |
<AVG: "avg" > |
<COUNT: "count"> |
@@ -384,7 +385,8 @@ TensorFunctionNode tensorFunction() :
tensorExpression = tensorXwPlusB() |
tensorExpression = tensorArgmax() |
tensorExpression = tensorArgmin() |
- tensorExpression = tensorCellCast()
+ tensorExpression = tensorCellCast() |
+ tensorExpression = tensorExpand()
)
{ return tensorExpression; }
}
@@ -581,6 +583,16 @@ TensorFunctionNode tensorXwPlusB() :
dimension)); }
}
+TensorFunctionNode tensorExpand() :
+{
+ ExpressionNode argument;
+ String dimension;
+}
+{
+ <EXPAND> <LBRACE> argument = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Expand(TensorFunctionNode.wrap(argument), dimension)); }
+}
+
TensorFunctionNode tensorArgmax() :
{
ExpressionNode tensor;
@@ -696,6 +708,7 @@ String tensorFunctionName() :
( <ARGMAX> { return token.image; } ) |
( <ARGMIN> { return token.image; } ) |
( <CELL_CAST> { return token.image; } ) |
+ ( <EXPAND> { return token.image; } ) |
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 246dbcb2b1e..ed8a15ad989 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -498,7 +498,6 @@ public class EvaluationTestCase {
"tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })",
"tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]",
"tensor(d0[4]):[3,2,-1,1]");
-
}
@Test
@@ -725,6 +724,21 @@ public class EvaluationTestCase {
tester.assertEvaluates("tensor(d0[1], d1[3]):[1, 2, 3]",
"tensor0 * tensor(d0[1])(1)",
"tensor(d1[3]):[1, 2, 3]");
+ // Add using the "expand" non-primitive function
+ tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]",
+ "expand(tensor0, d0)",
+ "tensor(d1[3]):[1, 2, 3]");
+ tester.assertEvaluates("tensor<float>(d0[1],d1[3]):[[1,2,3]]",
+ "expand(tensor0, d0)",
+ "tensor<float>(d1[3]):[1, 2, 3]");
+ }
+
+ @Test
+ public void test() throws ParseException {
+ RankingExpression expr = new RankingExpression("expand(tensor<float>(d1[3]):[1,2,3], d0)");
+ System.out.println(expr);
+ Tensor t = expr.evaluate(new MapContext()).asTensor();
+ System.out.println(t);
}
@Test