summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-22 10:46:03 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-22 10:46:03 +0100
commit0303d397a8b00c48a576a5b0f011c2d6986b571c (patch)
treecb21ccab57a215bf1405c465f6ff924d4f98c3e2 /searchlib/src/main
parentdb153836b44fefb04bfa3c32ec5e0499c7141e37 (diff)
More functions
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java)36
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj71
2 files changed, 75 insertions, 32 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
index a1f83157e20..4e4095cb86e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorSumNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
@@ -5,11 +5,11 @@ import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.functions.ReduceFunction;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
-import java.util.Optional;
/**
* A node which sums over all cells in the argument tensor
@@ -17,17 +17,20 @@ import java.util.Optional;
* @author bratseth
*/
@Beta
-public class TensorSumNode extends CompositeNode {
+public class TensorReduceNode extends CompositeNode {
- /** The tensor to sum */
+ /** The tensor to aggregate over */
private final ExpressionNode argument;
- /** The dimension to sum over, or empty to sum all cells to a scalar */
- private final Optional<String> dimension;
+ private final ReduceFunction.Aggregator aggregator;
- public TensorSumNode(ExpressionNode argument, Optional<String> dimension) {
+ /** The dimensions to sum over, or empty to sum all cells */
+ private final List<String> dimensions;
+
+ public TensorReduceNode(ExpressionNode argument, ReduceFunction.Aggregator aggregator, List<String> dimensions) {
this.argument = argument;
- this.dimension = dimension;
+ this.aggregator = aggregator;
+ this.dimensions = dimensions;
}
@Override
@@ -38,15 +41,19 @@ public class TensorSumNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
if (children.size() != 1) throw new IllegalArgumentException("A tensor sum node must have one tensor argument");
- return new TensorSumNode(children.get(0), dimension);
+ return new TensorReduceNode(children.get(0), aggregator, dimensions);
}
@Override
public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
- return "sum(" +
- argument.toString(context, path, parent) +
- ( dimension.isPresent() ? ", " + dimension.get() : "" ) +
- ")";
+ return "reduce(" + argument.toString(context, path, parent) + ", \"" + aggregator + "\"" + commaSeparated(dimensions) + ")";
+ }
+
+ private String commaSeparated(List<String> list) {
+ StringBuilder b = new StringBuilder();
+ for (String element : list)
+ b.append(", ").append(element);
+ return b.toString();
}
@Override
@@ -56,10 +63,7 @@ public class TensorSumNode extends CompositeNode {
throw new IllegalArgumentException("Attempted to take the tensor sum of argument '" + argument + "', " +
"but this returns " + argumentValue + ", not a tensor");
TensorValue tensorArgument = (TensorValue)argumentValue;
- if (dimension.isPresent())
- return tensorArgument.sum(dimension.get());
- else
- return tensorArgument.sum();
+ return new TensorValue(tensorArgument.asTensor().reduce(aggregator, dimensions));
}
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 17aefd5f0a0..a800028d00b 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -23,6 +23,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.MapTensor;
import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.functions.*;
import java.util.Collections;
import java.util.Map;
import java.util.LinkedHashMap;
@@ -97,12 +98,15 @@ TOKEN :
<FABS: "fabs"> |
<FLOOR: "floor"> |
<FMOD: "fmod"> |
- <MIN: "min"> |
- <MAX: "max"> |
<ISNAN: "isNan"> |
<IN: "in"> |
+ <REDUCE: "reduce"> |
+ <AVG: "avg" > |
+ <COUNT: "count"> |
+ <PROD: "prod"> |
<SUM: "sum"> |
- <MATCH: "match"> |
+ <MAX: "max"> |
+ <MIN: "min"> |
<RELU: "relu"> |
<SIGMOID: "sigmoid"> |
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
@@ -310,28 +314,53 @@ FunctionNode scalarFunction() :
ExpressionNode tensorFunction() :
{
- ExpressionNode tensor1, tensor2;
- String dimension = null;
- TensorAddress address = null;
+ ExpressionNode tensorExpression;
}
{
- (
- <SUM> <LBRACE> tensor1 = expression() ( <COMMA> dimension = identifier() )? <RBRACE>
- { return new TensorSumNode(tensor1, Optional.ofNullable(dimension)); }
- ) |
- (
- <MATCH> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <RBRACE>
- { return new TensorMatchNode(tensor1, tensor2); }
- )
+ ( tensorExpression = tensorPrimitiveReduce() | tensorExpression = tensorReduce() )
+ { return tensorExpression; }
+}
+
+ExpressionNode tensorPrimitiveReduce() :
+{
+ ExpressionNode tensor1;
+ ReduceFunction.Aggregator aggregator;
+ List<String> dimensions = null;
+}
+{
+ <REDUCE> <LBRACE> tensor1 = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorReduceNode(tensor1, aggregator, dimensions); }
+}
+
+ExpressionNode tensorReduce() :
+{
+ ExpressionNode tensor1;
+ ReduceFunction.Aggregator aggregator;
+ List<String> dimensions = null;
+}
+{
+ aggregator = tensorReduceAggregator()
+ <LBRACE> tensor1 = expression() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorReduceNode(tensor1, aggregator, dimensions); }
+}
+
+ReduceFunction.Aggregator tensorReduceAggregator() :
+{
+}
+{
+ ( <AVG> | <COUNT> | <PROD> | <SUM> | <MAX> | <MIN> )
+ { return ReduceFunction.Aggregator.valueOf(token.image); }
}
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
String tensorFunctionName() :
{
+ ReduceFunction.Aggregator aggregator;
}
{
- ( <SUM> | <MATCH> )
- { return token.image; }
+ ( <REDUCE> { return token.image; } )
+ |
+ ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
Function unaryFunctionName() : { }
@@ -413,6 +442,16 @@ String tag() :
<INTEGER> { return token.image; }
}
+List<String> tagCommaLeadingList() :
+{
+ List<String> list = new ArrayList<String>();
+ String element;
+}
+{
+ ( <COMMA> element = tag() { list.add(element); } ) *
+ { return list; }
+}
+
ConstantNode constantPrimitive() :
{
String sign = "";