summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/javacc
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 15:14:34 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 15:14:34 +0100
commitf23a86c355c0b9a66a611bb2ca96edeff50bfc7b (patch)
tree9f6338e9bae9feab0369d1868609ae98415b7892 /searchlib/src/main/javacc
parent063a290e2bc16502e7cf691d29f3105c07cb768c (diff)
Add tensor argmax and argmin
Diffstat (limited to 'searchlib/src/main/javacc')
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj28
1 files changed, 27 insertions, 1 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index fab80304f6d..67da9c59432 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -131,6 +131,8 @@ TOKEN :
<MATMUL: "matmul"> |
<SOFTMAX: "softmax"> |
<XW_PLUS_B: "xw_plus_b"> |
+ <ARGMAX: "argmax"> |
+ <ARGMIN: "argmin"> |
<AVG: "avg" > |
<COUNT: "count"> |
@@ -362,7 +364,9 @@ ExpressionNode tensorFunction() :
tensorExpression = tensorL2Normalize() |
tensorExpression = tensorMatmul() |
tensorExpression = tensorSoftmax() |
- tensorExpression = tensorXwPlusB()
+ tensorExpression = tensorXwPlusB() |
+ tensorExpression = tensorArgmax() |
+ tensorExpression = tensorArgmin()
)
{ return tensorExpression; }
}
@@ -521,6 +525,26 @@ ExpressionNode tensorXwPlusB() :
dimension)); }
}
+ExpressionNode tensorArgmax() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorArgmin() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
LambdaFunctionNode lambdaFunction() :
{
List<String> variables;
@@ -582,6 +606,8 @@ String tensorFunctionName() :
( <MATMUL> { return token.image; } ) |
( <SOFTMAX> { return token.image; } ) |
( <XW_PLUS_B> { return token.image; } ) |
+ ( <ARGMAX> { return token.image; } ) |
+ ( <ARGMIN> { return token.image; } ) |
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}