summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/javacc/RankingExpressionParser.jj
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/javacc/RankingExpressionParser.jj')
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj83
1 files changed, 76 insertions, 7 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 6a2ce356722..0d290bf7688 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -24,7 +24,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.*;
import com.yahoo.tensor.functions.*;
import java.util.Collections;
-import java.util.Map;
import java.util.LinkedHashMap;
import java.util.Arrays;
import java.util.ArrayList;
@@ -111,6 +110,7 @@ TOKEN :
<TANH: "tanh"> |
<ATAN2: "atan2"> |
+ <EQUAL: "equal"> |
<FMOD: "fmod"> |
<LDEXP: "ldexp"> |
// MAX
@@ -123,6 +123,11 @@ TOKEN :
<JOIN: "join"> |
<RENAME: "rename"> |
<TENSOR: "tensor"> |
+ <L1_NORMALIZE: "l1_normalize"> |
+ <L2_NORMALIZE: "l2_normalize"> |
+ <MATMUL: "matmul"> |
+ <SOFTMAX: "softmax"> |
+ <XW_PLUS_B: "xw_plus_b"> |
<AVG: "avg" > |
<COUNT: "count"> |
@@ -345,7 +350,12 @@ ExpressionNode tensorFunction() :
tensorExpression = tensorReduceComposites() |
tensorExpression = tensorJoin() |
tensorExpression = tensorRename() |
- tensorExpression = tensorGenerate()
+ tensorExpression = tensorGenerate() |
+ tensorExpression = tensorL1Normalize() |
+ tensorExpression = tensorL2Normalize() |
+ tensorExpression = tensorMatmul() |
+ tensorExpression = tensorSoftmax() |
+ tensorExpression = tensorXwPlusB()
)
{ return tensorExpression; }
}
@@ -363,7 +373,7 @@ ExpressionNode tensorMap() :
ExpressionNode tensorReduce() :
{
ExpressionNode tensor;
- ReduceFunction.Aggregator aggregator;
+ Reduce.Aggregator aggregator;
List<String> dimensions = null;
}
{
@@ -374,7 +384,7 @@ ExpressionNode tensorReduce() :
ExpressionNode tensorReduceComposites() :
{
ExpressionNode tensor;
- ReduceFunction.Aggregator aggregator;
+ Reduce.Aggregator aggregator;
List<String> dimensions = null;
}
{
@@ -417,6 +427,64 @@ ExpressionNode tensorGenerate() :
{ return null; }
}
+ExpressionNode tensorL1Normalize() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorL2Normalize() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorMatmul() :
+{
+ ExpressionNode tensor1, tensor2;
+ String dimension;
+}
+{
+ <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ dimension)); }
+}
+
+ExpressionNode tensorSoftmax() :
+{
+ ExpressionNode tensor;
+ String dimension;
+}
+{
+ <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+}
+
+ExpressionNode tensorXwPlusB() :
+{
+ ExpressionNode tensor1, tensor2, tensor3;
+ String dimension;
+}
+{
+ <XW_PLUS_B> <LBRACE> tensor1 = expression() <COMMA>
+ tensor2 = expression() <COMMA>
+ tensor3 = expression() <COMMA>
+ dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
+ TensorFunctionNode.wrapArgument(tensor2),
+ TensorFunctionNode.wrapArgument(tensor3),
+ dimension)); }
+}
+
LambdaFunctionNode lambdaFunction() :
{
List<String> variables;
@@ -427,18 +495,18 @@ LambdaFunctionNode lambdaFunction() :
{ return new LambdaFunctionNode(variables, functionExpression); }
}
-ReduceFunction.Aggregator tensorReduceAggregator() :
+Reduce.Aggregator tensorReduceAggregator() :
{
}
{
( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> )
- { return ReduceFunction.Aggregator.valueOf(token.image); }
+ { return Reduce.Aggregator.valueOf(token.image); }
}
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
String tensorFunctionName() :
{
- ReduceFunction.Aggregator aggregator;
+ Reduce.Aggregator aggregator;
}
{
( <F> { return token.image; } ) |
@@ -481,6 +549,7 @@ Function unaryFunctionName() : { }
Function binaryFunctionName() : { }
{
<ATAN2> { return Function.atan2; } |
+ <EQUAL> { return Function.equal; } |
<FMOD> { return Function.fmod; } |
<LDEXP> { return Function.ldexp; } |
<MAX> { return Function.max; } |