diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-22 15:14:34 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-22 15:14:34 +0100 |
commit | f23a86c355c0b9a66a611bb2ca96edeff50bfc7b (patch) | |
tree | 9f6338e9bae9feab0369d1868609ae98415b7892 /searchlib/src/main/javacc | |
parent | 063a290e2bc16502e7cf691d29f3105c07cb768c (diff) |
Add tensor argmax and argmin
Diffstat (limited to 'searchlib/src/main/javacc')
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index fab80304f6d..67da9c59432 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -131,6 +131,8 @@ TOKEN : <MATMUL: "matmul"> | <SOFTMAX: "softmax"> | <XW_PLUS_B: "xw_plus_b"> | + <ARGMAX: "argmax"> | + <ARGMIN: "argmin"> | <AVG: "avg" > | <COUNT: "count"> | @@ -362,7 +364,9 @@ ExpressionNode tensorFunction() : tensorExpression = tensorL2Normalize() | tensorExpression = tensorMatmul() | tensorExpression = tensorSoftmax() | - tensorExpression = tensorXwPlusB() + tensorExpression = tensorXwPlusB() | + tensorExpression = tensorArgmax() | + tensorExpression = tensorArgmin() ) { return tensorExpression; } } @@ -521,6 +525,26 @@ ExpressionNode tensorXwPlusB() : dimension)); } } +ExpressionNode tensorArgmax() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorArgmin() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + LambdaFunctionNode lambdaFunction() : { List<String> variables; @@ -582,6 +606,8 @@ String tensorFunctionName() : ( <MATMUL> { return token.image; } ) | ( <SOFTMAX> { return token.image; } ) | ( <XW_PLUS_B> { return token.image; } ) | + ( <ARGMAX> { return token.image; } ) | + ( <ARGMIN> { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } |