diff options
Diffstat (limited to 'searchlib/src/main/javacc/RankingExpressionParser.jj')
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 83 |
1 files changed, 76 insertions, 7 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 6a2ce356722..0d290bf7688 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -24,7 +24,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.*; import com.yahoo.tensor.functions.*; import java.util.Collections; -import java.util.Map; import java.util.LinkedHashMap; import java.util.Arrays; import java.util.ArrayList; @@ -111,6 +110,7 @@ TOKEN : <TANH: "tanh"> | <ATAN2: "atan2"> | + <EQUAL: "equal"> | <FMOD: "fmod"> | <LDEXP: "ldexp"> | // MAX @@ -123,6 +123,11 @@ TOKEN : <JOIN: "join"> | <RENAME: "rename"> | <TENSOR: "tensor"> | + <L1_NORMALIZE: "l1_normalize"> | + <L2_NORMALIZE: "l2_normalize"> | + <MATMUL: "matmul"> | + <SOFTMAX: "softmax"> | + <XW_PLUS_B: "xw_plus_b"> | <AVG: "avg" > | <COUNT: "count"> | @@ -345,7 +350,12 @@ ExpressionNode tensorFunction() : tensorExpression = tensorReduceComposites() | tensorExpression = tensorJoin() | tensorExpression = tensorRename() | - tensorExpression = tensorGenerate() + tensorExpression = tensorGenerate() | + tensorExpression = tensorL1Normalize() | + tensorExpression = tensorL2Normalize() | + tensorExpression = tensorMatmul() | + tensorExpression = tensorSoftmax() | + tensorExpression = tensorXwPlusB() ) { return tensorExpression; } } @@ -363,7 +373,7 @@ ExpressionNode tensorMap() : ExpressionNode tensorReduce() : { ExpressionNode tensor; - ReduceFunction.Aggregator aggregator; + Reduce.Aggregator aggregator; List<String> dimensions = null; } { @@ -374,7 +384,7 @@ ExpressionNode tensorReduce() : ExpressionNode tensorReduceComposites() : { ExpressionNode tensor; - ReduceFunction.Aggregator aggregator; + Reduce.Aggregator aggregator; List<String> dimensions = null; } { @@ -417,6 +427,64 @@ ExpressionNode tensorGenerate() : { return null; } } +ExpressionNode tensorL1Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorL2Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorMatmul() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + dimension)); } +} + +ExpressionNode tensorSoftmax() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorXwPlusB() : +{ + ExpressionNode tensor1, tensor2, tensor3; + String dimension; +} +{ + <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA> + tensor2 = expression() <COMMA> + tensor3 = expression() <COMMA> + dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + TensorFunctionNode.wrapArgument(tensor3), + dimension)); } +} + LambdaFunctionNode lambdaFunction() : { List<String> variables; @@ -427,18 +495,18 @@ LambdaFunctionNode lambdaFunction() : { return new LambdaFunctionNode(variables, functionExpression); } } -ReduceFunction.Aggregator tensorReduceAggregator() : +Reduce.Aggregator tensorReduceAggregator() : { } { ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) - { return ReduceFunction.Aggregator.valueOf(token.image); } + { return Reduce.Aggregator.valueOf(token.image); } } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { - ReduceFunction.Aggregator aggregator; + Reduce.Aggregator aggregator; } { ( <F> { return token.image; } ) | @@ -481,6 +549,7 @@ Function unaryFunctionName() : { } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | + <EQUAL> { return Function.equal; } | <FMOD> { return Function.fmod; } | <LDEXP> { return Function.ldexp; } | <MAX> { return Function.max; } | |