diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-22 10:46:03 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-22 10:46:03 +0100 |
commit | 0303d397a8b00c48a576a5b0f011c2d6986b571c (patch) | |
tree | cb21ccab57a215bf1405c465f6ff924d4f98c3e2 /searchlib/src/main | |
parent | db153836b44fefb04bfa3c32ec5e0499c7141e37 (diff) |
More functions
Diffstat (limited to 'searchlib/src/main')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java) | 36 | ||||
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 71 |
2 files changed, 75 insertions, 32 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java index a1f83157e20..4e4095cb86e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java @@ -5,11 +5,11 @@ import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.functions.ReduceFunction; import java.util.Collections; import java.util.Deque; import java.util.List; -import java.util.Optional; /** * A node which sums over all cells in the argument tensor @@ -17,17 +17,20 @@ import java.util.Optional; * @author bratseth */ @Beta -public class TensorSumNode extends CompositeNode { +public class TensorReduceNode extends CompositeNode { - /** The tensor to sum */ + /** The tensor to aggregate over */ private final ExpressionNode argument; - /** The dimension to sum over, or empty to sum all cells to a scalar */ - private final Optional<String> dimension; + private final ReduceFunction.Aggregator aggregator; - public TensorSumNode(ExpressionNode argument, Optional<String> dimension) { + /** The dimensions to sum over, or empty to sum all cells */ + private final List<String> dimensions; + + public TensorReduceNode(ExpressionNode argument, ReduceFunction.Aggregator aggregator, List<String> dimensions) { this.argument = argument; - this.dimension = dimension; + this.aggregator = aggregator; + this.dimensions = dimensions; } @Override @@ -38,15 +41,19 @@ public class TensorSumNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument"); - return new TensorSumNode(children.get(0), dimension); + return new TensorReduceNode(children.get(0), aggregator, dimensions); } @Override public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - return "sum(" + - argument.toString(context, path, parent) + - ( dimension.isPresent() ? ", " + dimension.get() : "" ) + - ")"; + return "reduce(" + argument.toString(context, path, parent) + ", \"" + aggregator + "\"" + commaSeparated(dimensions) + ")"; + } + + private String commaSeparated(List<String> list) { + StringBuilder b = new StringBuilder(); + for (String element : list) + b.append(", ").append(element); + return b.toString(); } @Override @@ -56,10 +63,7 @@ public class TensorSumNode extends CompositeNode { throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " + "but this returns " + argumentValue + ", not a tensor"); TensorValue tensorArgument = (TensorValue)argumentValue; - if (dimension.isPresent()) - return tensorArgument.sum(dimension.get()); - else - return tensorArgument.sum(); + return new TensorValue(tensorArgument.asTensor().reduce(aggregator, dimensions)); } } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 17aefd5f0a0..a800028d00b 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -23,6 +23,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.MapTensor; import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.functions.*; import java.util.Collections; import java.util.Map; import java.util.LinkedHashMap; @@ -97,12 +98,15 @@ TOKEN : <FABS: "fabs"> | <FLOOR: "floor"> | <FMOD: "fmod"> | - <MIN: "min"> | - <MAX: "max"> | <ISNAN: "isNan"> | <IN: "in"> | + <REDUCE: "reduce"> | + <AVG: "avg" > | + <COUNT: "count"> | + <PROD: "prod"> | <SUM: "sum"> | - <MATCH: "match"> | + <MAX: "max"> | + <MIN: "min"> | <RELU: "relu"> | <SIGMOID: "sigmoid"> | <IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)> @@ -310,28 +314,53 @@ FunctionNode scalarFunction() : ExpressionNode tensorFunction() : { - ExpressionNode tensor1, tensor2; - String dimension = null; - TensorAddress address = null; + ExpressionNode tensorExpression; } { - ( - <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE> - { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); } - ) | - ( - <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE> - { return new TensorMatchNode(tensor1, tensor2); } - ) + ( tensorExpression = tensorPrimitiveReduce() | tensorExpression = tensorReduce() ) + { return tensorExpression; } +} + +ExpressionNode tensorPrimitiveReduce() : +{ + ExpressionNode tensor1; + ReduceFunction.Aggregator aggregator; + List<String> dimensions = null; +} +{ + <REDUCE> <LBRACE> tensor1 = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorReduceNode(tensor1, aggregator, dimensions); } +} + +ExpressionNode tensorReduce() : +{ + ExpressionNode tensor1; + ReduceFunction.Aggregator aggregator; + List<String> dimensions = null; +} +{ + aggregator = tensorReduceAggregator() + <LBRACE> tensor1 = expression() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorReduceNode(tensor1, aggregator, dimensions); } +} + +ReduceFunction.Aggregator tensorReduceAggregator() : +{ +} +{ + ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> ) + { return ReduceFunction.Aggregator.valueOf(token.image); } } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { + ReduceFunction.Aggregator aggregator; } { - ( <SUM> | <MATCH> ) - { return token.image; } + ( <REDUCE> { return token.image; } ) + | + ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } Function unaryFunctionName() : { } @@ -413,6 +442,16 @@ String tag() : <INTEGER> { return token.image; } } +List<String> tagCommaLeadingList() : +{ + List<String> list = new ArrayList<String>(); + String element; +} +{ + ( <COMMA> element = tag() { list.add(element); } ) * + { return list; } +} + ConstantNode constantPrimitive() : { String sign = ""; |