summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/javacc/RankingExpressionParser.jj
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2016-11-26 22:45:20 +0100
committerGitHub <noreply@github.com>2016-11-26 22:45:20 +0100
commit2f55986b4de9420e5728c5abbaafb69fb2f10a34 (patch)
tree9a6a77f76d25620771dfe7ab5de49910c4321fc5 /searchlib/src/main/javacc/RankingExpressionParser.jj
parent2bc82ba9d9698214e703f19039387609d82b12f8 (diff)
Revert "Revert "Bratseth/tensor functions 3""
Diffstat (limited to 'searchlib/src/main/javacc/RankingExpressionParser.jj')
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj375
1 files changed, 276 insertions, 99 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 78ad665c414..0fcfdb5d40c 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -21,10 +21,9 @@ import com.yahoo.searchlib.rankingexpression.rule.*;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.MapTensor;
-import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.*;
+import com.yahoo.tensor.functions.*;
import java.util.Collections;
-import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Arrays;
import java.util.ArrayList;
@@ -60,51 +59,83 @@ TOKEN :
<RSQUARE: "]"> |
<LCURLY: "{"> |
<RCURLY: "}"> |
+
<ADD: "+"> |
<SUB: "-"> |
<DIV: "/"> |
<MUL: "*"> |
<DOT: "."> |
+
<DOLLAR: "$"> |
<COMMA: ","> |
<COLON: ":"> |
+
<LE: "<="> |
<LT: "<"> |
<EQ: "=="> |
+ <NQ: "!="> |
<AQ: "~="> |
<GE: ">="> |
<GT: ">"> |
+
<STRING: ("\"" (~["\""] | "\\\"")* "\"") |
("'" (~["'"] | "\\'")* "'")> |
+
<IF: "if"> |
- <COSH: "cosh"> |
- <SINH: "sinh"> |
- <TANH: "tanh"> |
- <COS: "cos"> |
- <SIN: "sin"> |
- <TAN: "tan"> |
+ <IN: "in"> |
+ <F: "f"> |
+
+ <ABS: "abs"> |
<ACOS: "acos"> |
<ASIN: "asin"> |
- <ATAN2: "atan2"> |
<ATAN: "atan"> |
- <EXP: "exp"> |
- <LDEXP: "ldexp"> |
- <LOG10: "log10"> |
- <LOG: "log"> |
- <POW: "pow"> |
- <SQRT: "sqrt"> |
<CEIL: "ceil"> |
+ <COS: "cos"> |
+ <COSH: "cosh"> |
+ <ELU: "elu"> |
+ <EXP: "exp"> |
<FABS: "fabs"> |
<FLOOR: "floor"> |
- <FMOD: "fmod"> |
- <MIN: "min"> |
- <MAX: "max"> |
<ISNAN: "isNan"> |
- <IN: "in"> |
- <SUM: "sum"> |
- <MATCH: "match"> |
+ <LOG: "log"> |
+ <LOG10: "log10"> |
<RELU: "relu"> |
+ <ROUND: "round"> |
<SIGMOID: "sigmoid"> |
+ <SIGN: "sign"> |
+ <SIN: "sin"> |
+ <SINH: "sinh"> |
+ <SQUARE: "square"> |
+ <SQRT: "sqrt"> |
+ <TAN: "tan"> |
+ <TANH: "tanh"> |
+
+ <ATAN2: "atan2"> |
+ <FMOD: "fmod"> |
+ <LDEXP: "ldexp"> |
+ // MAX
+ // MIN
+ <MOD: "mod"> |
+ <POW: "pow"> |
+
+ <MAP: "map"> |
+ <REDUCE: "reduce"> |
+ <JOIN: "join"> |
+ <RENAME: "rename"> |
+ <TENSOR: "tensor"> |
+ <L1_NORMALIZE: "l1_normalize"> |
+ <L2_NORMALIZE: "l2_normalize"> |
+ <MATMUL: "matmul"> |
+ <SOFTMAX: "softmax"> |
+ <XW_PLUS_B: "xw_plus_b"> |
+
+ <AVG: "avg" > |
+ <COUNT: "count"> |
+ <PROD: "prod"> |
+ <SUM: "sum"> |
+ <MAX: "max"> |
+ <MIN: "min"> |
+
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
}
@@ -175,6 +206,7 @@ TruthOperator comparator() : { }
( <LE> { return TruthOperator.SMALLEREQUAL; } |
<LT> { return TruthOperator.SMALLER; } |
<EQ> { return TruthOperator.EQUAL; } |
+ <NQ> { return TruthOperator.NOTEQUAL; } |
<AQ> { return TruthOperator.APPROX_EQUAL; } |
<GE> { return TruthOperator.LARGEREQUAL; } |
<GT> { return TruthOperator.LARGER; } )
@@ -189,7 +221,6 @@ ExpressionNode value() :
{
( [ LOOKAHEAD(2) <SUB> { neg = true; } ]
( ret = constantPrimitive() |
- ret = constantTensor() |
LOOKAHEAD(2) ret = ifExpression() |
LOOKAHEAD(2) ret = function() |
ret = feature() |
@@ -279,7 +310,6 @@ ExpressionNode arg() :
}
{
( ret = constantPrimitive() |
- ret = constantTensor() |
LOOKAHEAD(2) ret = feature() |
name = identifier() { ret = new NameNode(name); } )
{ return ret; }
@@ -290,11 +320,11 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( function = scalarFunction() | function = tensorFunction() )
+ ( function = scalarOrTensorFunction() | function = tensorFunction() )
{ return function; }
}
-FunctionNode scalarFunction() :
+FunctionNode scalarOrTensorFunction() :
{
Function function;
ExpressionNode arg1, arg2;
@@ -312,61 +342,223 @@ FunctionNode scalarFunction() :
+// Parses any one of the tensor function forms by delegating to the production handling that form.
ExpressionNode tensorFunction() :
{
+    // The node built by whichever alternative below matched
+    ExpressionNode node;
+}
+{
+    ( node = tensorMap()
+    | node = tensorReduce()
+    | node = tensorReduceComposites()
+    | node = tensorJoin()
+    | node = tensorRename()
+    | node = tensorGenerate()
+    | node = tensorL1Normalize()
+    | node = tensorL2Normalize()
+    | node = tensorMatmul()
+    | node = tensorSoftmax()
+    | node = tensorXwPlusB() )
+    { return node; }
+}
+
+// Parses the MAP keyword followed by a braced pair: a tensor expression and a lambda function.
+// The lambda is handed to Map as a double unary operator.
+ExpressionNode tensorMap() :
+{
+    ExpressionNode argument;
+    LambdaFunctionNode mapper;
+}
+{
+    <MAP>
+    <LBRACE>
+        argument = expression()
+        <COMMA>
+        mapper = lambdaFunction()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new Map(TensorFunctionNode.wrapArgument(argument), mapper.asDoubleUnaryOperator()));
+    }
+}
+
+// Parses the REDUCE keyword followed by a braced tensor expression, an aggregator keyword
+// and zero or more comma-led dimension tags.
+ExpressionNode tensorReduce() :
+{
+    ExpressionNode argument;
+    Reduce.Aggregator reducer;
+    List<String> reducedDimensions = null;
+}
+{
+    <REDUCE>
+    <LBRACE>
+        argument = expression()
+        <COMMA>
+        reducer = tensorReduceAggregator()
+        reducedDimensions = tagCommaLeadingList()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new Reduce(TensorFunctionNode.wrapArgument(argument), reducer, reducedDimensions));
+    }
+}
+
+// Parses the short reduce form where the aggregator keyword itself is the function name:
+// aggregator, then a braced tensor expression with zero or more comma-led dimension tags.
+ExpressionNode tensorReduceComposites() :
+{
+    ExpressionNode argument;
+    Reduce.Aggregator reducer;
+    List<String> reducedDimensions = null;
+}
+{
+    reducer = tensorReduceAggregator()
+    <LBRACE>
+        argument = expression()
+        reducedDimensions = tagCommaLeadingList()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new Reduce(TensorFunctionNode.wrapArgument(argument), reducer, reducedDimensions));
+    }
+}
+
+// Parses the JOIN keyword followed by a braced triple: two tensor expressions and a lambda function.
+// The lambda is handed to Join as a double binary operator.
+ExpressionNode tensorJoin() :
+{
ExpressionNode tensor1, tensor2;
- String dimension = null;
- TensorAddress address = null;
+    LambdaFunctionNode joiner;
}
{
- (
- <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE>
- { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); }
- ) |
- (
- <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE>
- { return new TensorMatchNode(tensor1, tensor2); }
- )
+    <JOIN>
+    <LBRACE>
+        tensor1 = expression() <COMMA>
+        tensor2 = expression() <COMMA>
+        joiner = lambdaFunction()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new Join(TensorFunctionNode.wrapArgument(tensor1),
+                         TensorFunctionNode.wrapArgument(tensor2),
+                         joiner.asDoubleBinaryOperator()));
+    }
+}
+
+// Parses the RENAME keyword followed by a braced triple: a tensor expression and two
+// identifier lists (each either a single identifier or a braced list).
+ExpressionNode tensorRename() :
+{
+    ExpressionNode argument;
+    List<String> sourceNames, targetNames;
+}
+{
+    <RENAME>
+    <LBRACE>
+        argument = expression() <COMMA>
+        sourceNames = bracedIdentifierList() <COMMA>
+        targetNames = bracedIdentifierList()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new Rename(TensorFunctionNode.wrapArgument(argument), sourceNames, targetNames));
+    }
+}
+
+// TODO: Placeholder production - no type or generator is parsed yet, so Generate is constructed with (null, null) below
+ExpressionNode tensorGenerate() :
+{
+ TensorType type; // declared for the future grammar; never assigned by this production
+ LambdaFunctionNode generator; // declared for the future grammar; never assigned by this production
+}
+{
+ <TENSOR> <LBRACE> <RBRACE> <LBRACE> // NOTE(review): the second LBRACE is never closed by this production - confirm whether a closing RBRACE is missing
+ { return new TensorFunctionNode(new Generate(null, null)); } // both Generate arguments are null until the TODO is resolved
+}
+
+// Parses the L1_NORMALIZE keyword followed by a braced pair: a tensor expression and a dimension identifier.
+ExpressionNode tensorL1Normalize() :
+{
+    ExpressionNode argument;
+    String dimensionName;
+}
+{
+    <L1_NORMALIZE>
+    <LBRACE>
+        argument = expression()
+        <COMMA>
+        dimensionName = identifier()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new L1Normalize(TensorFunctionNode.wrapArgument(argument), dimensionName));
+    }
+}
+
+// Parses the L2_NORMALIZE keyword followed by a braced pair: a tensor expression and a dimension identifier.
+ExpressionNode tensorL2Normalize() :
+{
+    ExpressionNode argument;
+    String dimensionName;
+}
+{
+    <L2_NORMALIZE>
+    <LBRACE>
+        argument = expression()
+        <COMMA>
+        dimensionName = identifier()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new L2Normalize(TensorFunctionNode.wrapArgument(argument), dimensionName));
+    }
+}
+
+// Parses the MATMUL keyword followed by a braced triple: two tensor expressions and a dimension identifier.
+ExpressionNode tensorMatmul() :
+{
+    ExpressionNode left, right;
+    String dimensionName;
+}
+{
+    <MATMUL>
+    <LBRACE>
+        left = expression() <COMMA>
+        right = expression() <COMMA>
+        dimensionName = identifier()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new Matmul(TensorFunctionNode.wrapArgument(left),
+                           TensorFunctionNode.wrapArgument(right),
+                           dimensionName));
+    }
+}
+
+// Parses the SOFTMAX keyword followed by a braced pair: a tensor expression and a dimension identifier.
+ExpressionNode tensorSoftmax() :
+{
+    ExpressionNode argument;
+    String dimensionName;
+}
+{
+    <SOFTMAX>
+    <LBRACE>
+        argument = expression()
+        <COMMA>
+        dimensionName = identifier()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new Softmax(TensorFunctionNode.wrapArgument(argument), dimensionName));
+    }
+}
+
+// Parses the XW_PLUS_B keyword followed by a braced quadruple: three tensor expressions
+// and a dimension identifier, passed to XwPlusB in the order they were written.
+ExpressionNode tensorXwPlusB() :
+{
+    ExpressionNode first, second, third;
+    String dimensionName;
+}
+{
+    <XW_PLUS_B>
+    <LBRACE>
+        first = expression() <COMMA>
+        second = expression() <COMMA>
+        third = expression() <COMMA>
+        dimensionName = identifier()
+    <RBRACE>
+    {
+        return new TensorFunctionNode(
+                new XwPlusB(TensorFunctionNode.wrapArgument(first),
+                            TensorFunctionNode.wrapArgument(second),
+                            TensorFunctionNode.wrapArgument(third),
+                            dimensionName));
+    }
+}
+
+// Parses a lambda: the F keyword, a braced list of argument names, then a braced body expression.
+LambdaFunctionNode lambdaFunction() :
+{
+    List<String> argumentNames;
+    ExpressionNode body;
+}
+{
+    <F>
+    <LBRACE> argumentNames = identifierList() <RBRACE>
+    <LBRACE> body = expression() <RBRACE>
+    { return new LambdaFunctionNode(argumentNames, body); }
+}
+
+// Consumes one aggregator keyword and looks up the Reduce.Aggregator constant named by the keyword text.
+Reduce.Aggregator tensorReduceAggregator() : { }
+{
+    ( <AVG>
+    | <COUNT>
+    | <PROD>
+    | <SUM>
+    | <MAX>
+    | <MIN> )
+    {
+        String aggregatorName = token.image;
+        return Reduce.Aggregator.valueOf(aggregatorName);
+    }
+}
// This is not needed to parse tensor functions, only for the "reserved names as literals" workaround kludge
+// Returns the text of any keyword token reserved for tensor functions (or the lambda keyword F).
+// Lists every tensor keyword that tensorFunction() can start with, including l1_normalize,
+// l2_normalize, matmul, softmax and xw_plus_b, so no reserved tensor keyword is left out.
String tensorFunctionName() :
{
+    Reduce.Aggregator aggregator;
}
{
- ( <SUM> | <MATCH> )
- { return token.image; }
+    ( <F> { return token.image; } ) |
+    ( <MAP> { return token.image; } ) |
+    ( <REDUCE> { return token.image; } ) |
+    ( <JOIN> { return token.image; } ) |
+    ( <RENAME> { return token.image; } ) |
+    ( <TENSOR> { return token.image; } ) |
+    ( <L1_NORMALIZE> { return token.image; } ) |
+    ( <L2_NORMALIZE> { return token.image; } ) |
+    ( <MATMUL> { return token.image; } ) |
+    ( <SOFTMAX> { return token.image; } ) |
+    ( <XW_PLUS_B> { return token.image; } ) |
+    // Aggregator keywords are returned via the enum constant's string form
+    ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
Function unaryFunctionName() : { } // Consumes one unary function keyword and returns the Function constant named like that keyword.
{
- <COS> { return Function.cos; } |
- <SIN> { return Function.sin; } |
- <TAN> { return Function.tan; } |
- <COSH> { return Function.cosh; } |
- <SINH> { return Function.sinh; } |
- <TANH> { return Function.tanh; } |
+ <ABS> { return Function.abs; } |
<ACOS> { return Function.acos; } |
<ASIN> { return Function.asin; } |
<ATAN> { return Function.atan; } |
- <EXP> { return Function.exp; } |
- <LOG10> { return Function.log10; } |
- <LOG> { return Function.log; } |
- <SQRT> { return Function.sqrt; } |
<CEIL> { return Function.ceil; } |
+ <COS> { return Function.cos; } |
+ <COSH> { return Function.cosh; } |
+ <ELU> { return Function.elu; } |
+ <EXP> { return Function.exp; } |
<FABS> { return Function.fabs; } |
<FLOOR> { return Function.floor; } |
<ISNAN> { return Function.isNan; } |
+ <LOG> { return Function.log; } |
+ <LOG10> { return Function.log10; } |
<RELU> { return Function.relu; } |
- <SIGMOID> { return Function.sigmoid; }
+ <ROUND> { return Function.round; } |
+ <SIGMOID> { return Function.sigmoid; } |
+ <SIGN> { return Function.sign; } |
+ <SIN> { return Function.sin; } |
+ <SINH> { return Function.sinh; } |
+ <SQUARE> { return Function.square; } |
+ <SQRT> { return Function.sqrt; } |
+ <TAN> { return Function.tan; } |
+ <TANH> { return Function.tanh; }
}
Function binaryFunctionName() : { } // Consumes one binary function keyword and returns the Function constant named like that keyword.
{
<ATAN2> { return Function.atan2; } |
- <LDEXP> { return Function.ldexp; } |
- <POW> { return Function.pow; } |
<FMOD> { return Function.fmod; } |
+ <LDEXP> { return Function.ldexp; } |
+ <MAX> { return Function.max; } | // MAX is also accepted as an aggregator keyword by tensorReduceAggregator()
<MIN> { return Function.min; } | // MIN is also accepted as an aggregator keyword by tensorReduceAggregator()
- <MAX> { return Function.max; }
+ <MOD> { return Function.mod; } |
+ <POW> { return Function.pow; }
}
List<ExpressionNode> expressionList() :
@@ -405,79 +597,64 @@ String identifier() :
<IDENTIFIER> { return token.image; }
}
-// An identifier or integer
-String tag() :
-{
- String name;
-}
-{
- name = identifier() { return name; } |
- <INTEGER> { return token.image; }
-}
-
-ConstantNode constantPrimitive() :
+// Parses a possibly empty, comma-separated list of identifiers and returns them in order.
+// A comma is only accepted between two identifiers, so the list can not start with a comma.
+List<String> identifierList() :
{
- String sign = "";
+    List<String> list = new ArrayList<String>();
+    String element;
}
{
- ( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return new ConstantNode(Value.parse(sign + token.image),sign + token.image); }
+    (
+        element = identifier() { list.add(element); }
+        // Further elements are only reachable after a first identifier was read
+        ( <COMMA> element = identifier() { list.add(element); } )*
+    )?
+    { return list; }
}
-Value primitiveValue() :
+// Parses either a single bare identifier (returned as a singleton list)
+// or a braced identifier list (returned as parsed by identifierList()).
+List<String> bracedIdentifierList() :
{
- String sign = "";
+    List<String> names = new ArrayList<String>();
+    String single;
}
{
- ( <SUB> { sign = "-";} ) ?
- ( <INTEGER> | <FLOAT> | <STRING> )
- { return Value.parse(sign + token.image); }
+    ( single = identifier() { return Collections.singletonList(single); }
+    | <LBRACE> names = identifierList() <RBRACE> { return names; } )
}
-ConstantNode constantTensor() :
+// Parses one tag - an identifier or an integer token - and returns its text
+String tag() :
{
- Value constantValue;
+    String identifierText;
}
{
- <LCURLY> constantValue = tensorContent() <RCURLY>
- { return new ConstantNode(constantValue); }
+    ( identifierText = identifier() { return identifierText; }
+    | <INTEGER> { return token.image; } )
}
-TensorValue tensorContent() :
+// Parses zero or more tags, each one preceded by a comma, and returns them in the order read.
+List<String> tagCommaLeadingList() :
{
- Map<TensorAddress, Double> cells = new LinkedHashMap<TensorAddress, Double>();
- TensorAddress address;
- Double value;
+    List<String> tags = new ArrayList<String>();
+    String current;
}
{
- ( address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) ?
- ( <COMMA> address = tensorAddress() <COLON> value = number() { cells.put(address, value); } ) *
- { return new TensorValue(new MapTensor(cells)); }
+    (
+        <COMMA> current = tag()
+        { tags.add(current); }
+    )*
+    { return tags; }
}
-TensorAddress tensorAddress() :
+// Parses an integer, float or string token with an optional leading minus into a ConstantNode.
+// The node receives both the parsed Value and the literal text (including the minus, if any).
+ConstantNode constantPrimitive() :
{
- List<TensorAddress.Element> elements = new ArrayList<TensorAddress.Element>();
- String dimension;
- String label;
+    String prefix = "";
}
{
- <LCURLY>
- ( dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) ?
- ( <COMMA> dimension = tag() <COLON> label = label() { elements.add(new TensorAddress.Element(dimension, label)); } ) *
- <RCURLY>
- { return TensorAddress.fromUnsorted(elements); }
+    [ <SUB> { prefix = "-"; } ]
+    ( <INTEGER> | <FLOAT> | <STRING> )
+    {
+        String image = prefix + token.image;
+        return new ConstantNode(Value.parse(image), image);
+    }
}
-String label() :
+// Parses an integer, float or string token with an optional leading minus into a Value.
+Value primitiveValue() :
{
- String label;
-
+    String prefix = "";
}
{
- ( label = tag() |
- ( "-" { label = "-"; } ) )
- { return label; }
+    [ <SUB> { prefix = "-"; } ]
+    ( <INTEGER> | <FLOAT> | <STRING> )
+    { return Value.parse(prefix + token.image); }
}
-