diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2016-11-26 22:45:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-26 22:45:20 +0100 |
commit | 2f55986b4de9420e5728c5abbaafb69fb2f10a34 (patch) | |
tree | 9a6a77f76d25620771dfe7ab5de49910c4321fc5 /searchlib/src/main/javacc/RankingExpressionParser.jj | |
parent | 2bc82ba9d9698214e703f19039387609d82b12f8 (diff) |
Revert "Revert "Bratseth/tensor functions 3""
Diffstat (limited to 'searchlib/src/main/javacc/RankingExpressionParser.jj')
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 375 |
1 files changed, 276 insertions, 99 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 78ad665c414..0fcfdb5d40c 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -21,10 +21,9 @@ import com.yahoo.searchlib.rankingexpression.rule.*; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.tensor.MapTensor; -import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.*; +import com.yahoo.tensor.functions.*; import java.util.Collections; -import java.util.Map; import java.util.LinkedHashMap; import java.util.Arrays; import java.util.ArrayList; @@ -60,51 +59,83 @@ TOKEN : <RSQUARE: "]"> | <LCURLY: "{"> | <RCURLY: "}"> | + <ADD: "+"> | <SUB: "-"> | <DIV: "/"> | <MUL: "*"> | <DOT: "."> | + <DOLLAR: "$"> | <COMMA: ","> | <COLON: ":"> | + <LE: "<="> | <LT: "<"> | <EQ: "=="> | + <NQ: "!="> | <AQ: "~="> | <GE: ">="> | <GT: ">"> | + <STRING: ("\"" (~["\""] | "\\\"")* "\"") | ("'" (~["'"] | "\\'")* "'")> | + <IF: "if"> | - <COSH: "cosh"> | - <SINH: "sinh"> | - <TANH: "tanh"> | - <COS: "cos"> | - <SIN: "sin"> | - <TAN: "tan"> | + <IN: "in"> | + <F: "f"> | + + <ABS: "abs"> | <ACOS: "acos"> | <ASIN: "asin"> | - <ATAN2: "atan2"> | <ATAN: "atan"> | - <EXP: "exp"> | - <LDEXP: "ldexp"> | - <LOG10: "log10"> | - <LOG: "log"> | - <POW: "pow"> | - <SQRT: "sqrt"> | <CEIL: "ceil"> | + <COS: "cos"> | + <COSH: "cosh"> | + <ELU: "elu"> | + <EXP: "exp"> | <FABS: "fabs"> | <FLOOR: "floor"> | - <FMOD: "fmod"> | - <MIN: "min"> | - <MAX: "max"> | <ISNAN: "isNan"> | - <IN: "in"> | - <SUM: "sum"> | - <MATCH: "match"> | + <LOG: "log"> | + <LOG10: "log10"> | <RELU: "relu"> | + <ROUND: "round"> | <SIGMOID: "sigmoid"> | + <SIGN: "sign"> | + <SIN: "sin"> | + <SINH: "sinh"> | + <SQUARE: "square"> | + <SQRT: "sqrt"> | + <TAN: "tan"> | + <TANH: "tanh"> | + + <ATAN2: "atan2"> | + <FMOD: "fmod"> | + <LDEXP: "ldexp"> | + // MAX + // MIN + <MOD: "mod"> | + <POW: "pow"> | + + <MAP: "map"> | + <REDUCE: "reduce"> | + <JOIN: "join"> | + <RENAME: "rename"> | + <TENSOR: "tensor"> | + <L1_NORMALIZE: "l1_normalize"> | + <L2_NORMALIZE: "l2_normalize"> | + <MATMUL: "matmul"> | + <SOFTMAX: "softmax"> | + <XW_PLUS_B: "xw_plus_b"> | + + <AVG: "avg" > | + <COUNT: "count"> | + <PROD: "prod"> | + <SUM: "sum"> | + <MAX: "max"> | + <MIN: "min"> | + <IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)> } @@ -175,6 +206,7 @@ TruthOperator comparator() : { } ( <LE> { return TruthOperator.SMALLEREQUAL; } | <LT> { return TruthOperator.SMALLER; } | <EQ> { return TruthOperator.EQUAL; } | + <NQ> { return TruthOperator.NOTEQUAL; } | <AQ> { return TruthOperator.APPROX_EQUAL; } | <GE> { return TruthOperator.LARGEREQUAL; } | <GT> { return TruthOperator.LARGER; } ) @@ -189,7 +221,6 @@ ExpressionNode value() : { ( [ LOOKAHEAD(2) <SUB> { neg = true; } ] ( ret = constantPrimitive() | - ret = constantTensor() | LOOKAHEAD(2) ret = ifExpression() | LOOKAHEAD(2) ret = function() | ret = feature() | @@ -279,7 +310,6 @@ ExpressionNode arg() : } { ( ret = constantPrimitive() | - ret = constantTensor() | LOOKAHEAD(2) ret = feature() | name = identifier() { ret = new NameNode(name); } ) { return ret; } @@ -290,11 +320,11 @@ ExpressionNode function() : ExpressionNode function; } { - ( function = scalarFunction() | function = tensorFunction() ) + ( function = scalarOrTensorFunction() | function = tensorFunction() ) { return function; } } -FunctionNode scalarFunction() : +FunctionNode scalarOrTensorFunction() : { Function function; ExpressionNode arg1, arg2; @@ -312,61 +342,223 @@ FunctionNode scalarFunction() : ExpressionNode tensorFunction() : { + ExpressionNode tensorExpression; +} +{ + ( + tensorExpression = tensorMap() | + tensorExpression = tensorReduce() | + tensorExpression = tensorReduceComposites() | + tensorExpression = tensorJoin() | + tensorExpression = tensorRename() | + tensorExpression = tensorGenerate() | + tensorExpression = tensorL1Normalize() | + tensorExpression = tensorL2Normalize() | + tensorExpression = tensorMatmul() | + tensorExpression = tensorSoftmax() | + tensorExpression = tensorXwPlusB() + ) + { return tensorExpression; } +} + +ExpressionNode tensorMap() : +{ + ExpressionNode tensor; + LambdaFunctionNode doubleMapper; +} +{ + <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> + { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), + doubleMapper.asDoubleUnaryOperator())); } +} + +ExpressionNode tensorReduce() : +{ + ExpressionNode tensor; + Reduce.Aggregator aggregator; + List<String> dimensions = null; +} +{ + <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } +} + +ExpressionNode tensorReduceComposites() : +{ + ExpressionNode tensor; + Reduce.Aggregator aggregator; + List<String> dimensions = null; +} +{ + aggregator = tensorReduceAggregator() + <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } +} + +ExpressionNode tensorJoin() : +{ ExpressionNode tensor1, tensor2; - String dimension = null; - TensorAddress address = null; + LambdaFunctionNode doubleJoiner; } { - ( - <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE> - { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); } - ) | - ( - <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE> - { return new TensorMatchNode(tensor1, tensor2); } - ) + <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> + { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + doubleJoiner.asDoubleBinaryOperator())); } +} + +ExpressionNode tensorRename() : +{ + ExpressionNode tensor; + List<String> fromDimensions, toDimensions; +} +{ + <RENAME> <LBRACE> tensor = expression() <COMMA> + fromDimensions = bracedIdentifierList() <COMMA> + toDimensions = bracedIdentifierList() + <RBRACE> + { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } +} + +// TODO: Notice that null is parsed below +ExpressionNode tensorGenerate() : +{ + TensorType type; + LambdaFunctionNode generator; +} +{ + <TENSOR> <LBRACE> <RBRACE> <LBRACE> + { return new TensorFunctionNode(new Generate(null, null)); } +} + +ExpressionNode tensorL1Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorL2Normalize() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorMatmul() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + dimension)); } +} + +ExpressionNode tensorSoftmax() : +{ + ExpressionNode tensor; + String dimension; +} +{ + <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } +} + +ExpressionNode tensorXwPlusB() : +{ + ExpressionNode tensor1, tensor2, tensor3; + String dimension; +} +{ + <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA> + tensor2 = expression() <COMMA> + tensor3 = expression() <COMMA> + dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), + TensorFunctionNode.wrapArgument(tensor2), + TensorFunctionNode.wrapArgument(tensor3), + dimension)); } +} + +LambdaFunctionNode lambdaFunction() : +{ + List<String> variables; + ExpressionNode functionExpression; +} +{ + ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) + { return new LambdaFunctionNode(variables, functionExpression); } +} + +Reduce.Aggregator tensorReduceAggregator() : +{ +} +{ + ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) + { return Reduce.Aggregator.valueOf(token.image); } } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { + Reduce.Aggregator aggregator; } { - ( <SUM> | <MATCH> ) - { return token.image; } + ( <F> { return token.image; } ) | + ( <MAP> { return token.image; } ) | + ( <REDUCE> { return token.image; } ) | + ( <JOIN> { return token.image; } ) | + ( <RENAME> { return token.image; } ) | + ( <TENSOR> { return token.image; } ) | + ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } Function unaryFunctionName() : { } { - <COS> { return Function.cos; } | - <SIN> { return Function.sin; } | - <TAN> { return Function.tan; } | - <COSH> { return Function.cosh; } | - <SINH> { return Function.sinh; } | - <TANH> { return Function.tanh; } | + <ABS> { return Function.abs; } | <ACOS> { return Function.acos; } | <ASIN> { return Function.asin; } | <ATAN> { return Function.atan; } | - <EXP> { return Function.exp; } | - <LOG10> { return Function.log10; } | - <LOG> { return Function.log; } | - <SQRT> { return Function.sqrt; } | <CEIL> { return Function.ceil; } | + <COS> { return Function.cos; } | + <COSH> { return Function.cosh; } | + <ELU> { return Function.elu; } | + <EXP> { return Function.exp; } | <FABS> { return Function.fabs; } | <FLOOR> { return Function.floor; } | <ISNAN> { return Function.isNan; } | + <LOG> { return Function.log; } | + <LOG10> { return Function.log10; } | <RELU> { return Function.relu; } | - <SIGMOID> { return Function.sigmoid; } + <ROUND> { return Function.round; } | + <SIGMOID> { return Function.sigmoid; } | + <SIGN> { return Function.sign; } | + <SIN> { return Function.sin; } | + <SINH> { return Function.sinh; } | + <SQUARE> { return Function.square; } | + <SQRT> { return Function.sqrt; } | + <TAN> { return Function.tan; } | + <TANH> { return Function.tanh; } } Function binaryFunctionName() : { } { <ATAN2> { return Function.atan2; } | - <LDEXP> { return Function.ldexp; } | - <POW> { return Function.pow; } | <FMOD> { return Function.fmod; } | + <LDEXP> { return Function.ldexp; } | + <MAX> { return Function.max; } | <MIN> { return Function.min; } | - <MAX> { return Function.max; } + <MOD> { return Function.mod; } | + <POW> { return Function.pow; } } List<ExpressionNode> expressionList() : @@ -405,79 +597,64 @@ String identifier() : <IDENTIFIER> { return token.image; } } -// An identifier or integer -String tag() : -{ - String name; -} -{ - name = identifier() { return name; } | - <INTEGER> { return token.image; } -} - -ConstantNode constantPrimitive() : +List<String> identifierList() : { - String sign = ""; + List<String> list = new ArrayList<String>(); + String element; } { - ( <SUB> { sign = "-";} ) ? - ( <INTEGER> | <FLOAT> | <STRING> ) - { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); } + ( element = identifier() { list.add(element); } )? + ( <COMMA> element = identifier() { list.add(element); } ) * + { return list; } } -Value primitiveValue() : +List<String> bracedIdentifierList() : { - String sign = ""; + List<String> list = new ArrayList<String>(); + String element; } { - ( <SUB> { sign = "-";} ) ? - ( <INTEGER> | <FLOAT> | <STRING> ) - { return Value.parse(sign + token.image); } + ( element = identifier() { return Collections.singletonList(element); } ) + | + ( <LBRACE> list = identifierList() <RBRACE> { return list; } ) } -ConstantNode constantTensor() : +// An identifier or integer +String tag() : { - Value constantValue; + String name; } { - <LCURLY> constantValue = tensorContent() <RCURLY> - { return new ConstantNode(constantValue); } + name = identifier() { return name; } | + <INTEGER> { return token.image; } } -TensorValue tensorContent() : +List<String> tagCommaLeadingList() : { - Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>(); - TensorAddress address; - Double value; + List<String> list = new ArrayList<String>(); + String element; } { - ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ? - ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) * - { return new TensorValue(new MapTensor(cells)); } + ( <COMMA> element = tag() { list.add(element); } ) * + { return list; } } -TensorAddress tensorAddress() : +ConstantNode constantPrimitive() : { - List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>(); - String dimension; - String label; + String sign = ""; } { - <LCURLY> - ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ? - ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) * - <RCURLY> - { return TensorAddress.fromUnsorted(elements); } + ( <SUB> { sign = "-";} ) ? + ( <INTEGER> | <FLOAT> | <STRING> ) + { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); } } -String label() : +Value primitiveValue() : { - String label; - + String sign = ""; } { - ( label = tag() | - ( "-" { label = "-"; } ) ) - { return label; } + ( <SUB> { sign = "-";} ) ? + ( <INTEGER> | <FLOAT> | <STRING> ) + { return Value.parse(sign + token.image); } } - |