summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-23 13:25:02 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-23 13:25:02 +0100
commit015cedfb6dbd15dec60602ba3082198502d1c5d9 (patch)
tree2b546af79cc157e12b4300e358e8869fe003f409 /searchlib/src/main
parent2b4e552165c18544e1ae702175d632e1e39a6e46 (diff)
Parse lambda
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java82
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java58
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj63
4 files changed, 193 insertions, 12 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
new file mode 100644
index 00000000000..7ac763ef4c4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -0,0 +1,82 @@
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * A free, parametrized function
+ *
+ * @author bratseth
+ */
+public class LambdaFunctionNode extends CompositeNode {
+
+ private final ImmutableList<String> arguments;
+ private final ExpressionNode functionExpression;
+
+ public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
+ // TODO: Verify that the function only accesses the arguments in mapperVariables
+ this.arguments = ImmutableList.copyOf(arguments);
+ this.functionExpression = functionExpression;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(functionExpression);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if ( children.size() != 1)
+ throw new IllegalArgumentException("A lambda function must have a single child expression");
+ return new LambdaFunctionNode(arguments, children.get(0));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")";
+ }
+
+ private String commaSeparated(List<String> list) {
+ StringBuilder b = new StringBuilder();
+ for (String element : list)
+ b.append(", ").append(element);
+ return b.toString();
+ }
+
+ /** Evaluate this in a context which must have the arguments bound */
+ @Override
+ public Value evaluate(Context context) {
+ return functionExpression.evaluate(context);
+ }
+
+ /**
+ * Returns this as a double unary operator
+ *
+ * @throws IllegalStateException if this does not have exactly one argument
+ */
+ public DoubleUnaryOperator asDoubleUnaryOperator() {
+ if (arguments.size() != 1)
+ throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
+ "Must have one argument " + " but has " + arguments);
+ return new DoubleUnaryLambda();
+ }
+
+ private class DoubleUnaryLambda implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) {
+ MapContext context = new MapContext();
+ context.put(arguments.get(0), operand);
+ return evaluate(context).asDouble();
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java
new file mode 100644
index 00000000000..0cb0da150b4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java
@@ -0,0 +1,58 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * A node which maps the values of a tensor
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorMapNode extends CompositeNode {
+
+ /** The tensor to aggregate over */
+ private final ExpressionNode argument;
+
+ private final LambdaFunctionNode doubleMapper;
+
+ public TensorMapNode(ExpressionNode argument, LambdaFunctionNode doubleMapper) {
+ this.argument = argument;
+ this.doubleMapper = doubleMapper;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return ImmutableList.of(argument, doubleMapper);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 2)
+ throw new IllegalArgumentException("A tensor map node must have one tensor and one mapper");
+ return new TensorMapNode(children.get(0), (LambdaFunctionNode)children.get(1));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "map(" + argument.toString(context, path, parent) + ", " + doubleMapper.toString() + ")";
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Value argumentValue = argument.evaluate(context);
+ if ( ! ( argumentValue instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to map '" + argument + "', " +
+ "but this returns " + argumentValue + ", not a tensor");
+ TensorValue tensorArgument = (TensorValue)argumentValue;
+ return new TensorValue(tensorArgument.asTensor().map(doubleMapper.asDoubleUnaryOperator()));
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
index 4e4095cb86e..65a1802c72d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
@@ -12,7 +12,7 @@ import java.util.Deque;
import java.util.List;
/**
- * A node which sums over all cells in the argument tensor
+ * A node which performs a dimension reduction over a tensor
*
* @author bratseth
*/
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index a800028d00b..2cebbdf7d75 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -100,7 +100,10 @@ TOKEN :
<FMOD: "fmod"> |
<ISNAN: "isNan"> |
<IN: "in"> |
+ <F: "f"> |
+ <MAP: "map"> |
<REDUCE: "reduce"> |
+ <JOIN: "join"> |
<AVG: "avg" > |
<COUNT: "count"> |
<PROD: "prod"> |
@@ -317,31 +320,55 @@ ExpressionNode tensorFunction() :
ExpressionNode tensorExpression;
}
{
- ( tensorExpression = tensorPrimitiveReduce() | tensorExpression = tensorReduce() )
+ (
+ tensorExpression = tensorMap() |
+ tensorExpression = tensorReduce() |
+ tensorExpression = tensorReduceComposites()
+ )
{ return tensorExpression; }
}
-ExpressionNode tensorPrimitiveReduce() :
+ExpressionNode tensorMap() :
+{
+ ExpressionNode tensor;
+ LambdaFunctionNode doubleMapper;
+}
+{
+ <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
+ { return new TensorMapNode(tensor, doubleMapper); }
+}
+
+LambdaFunctionNode lambdaFunction() :
+{
+ List<String> variables;
+ ExpressionNode functionExpression;
+}
+{
+ ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
+ { return new LambdaFunctionNode(variables, functionExpression); }
+}
+
+ExpressionNode tensorReduce() :
{
- ExpressionNode tensor1;
+ ExpressionNode tensor;
ReduceFunction.Aggregator aggregator;
List<String> dimensions = null;
}
{
- <REDUCE> <LBRACE> tensor1 = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorReduceNode(tensor1, aggregator, dimensions); }
+ <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorReduceNode(tensor, aggregator, dimensions); }
}
-ExpressionNode tensorReduce() :
+ExpressionNode tensorReduceComposites() :
{
- ExpressionNode tensor1;
+ ExpressionNode tensor;
ReduceFunction.Aggregator aggregator;
List<String> dimensions = null;
}
{
aggregator = tensorReduceAggregator()
- <LBRACE> tensor1 = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorReduceNode(tensor1, aggregator, dimensions); }
+ <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorReduceNode(tensor, aggregator, dimensions); }
}
ReduceFunction.Aggregator tensorReduceAggregator() :
@@ -358,8 +385,10 @@ String tensorFunctionName() :
ReduceFunction.Aggregator aggregator;
}
{
- ( <REDUCE> { return token.image; } )
- |
+ ( <F> { return token.image; } ) |
+ ( <MAP> { return token.image; } ) |
+ ( <REDUCE> { return token.image; } ) |
+ ( <JOIN> { return token.image; } ) |
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
@@ -432,6 +461,18 @@ String identifier() :
<IDENTIFIER> { return token.image; }
}
+List<String> identifierList() :
+{
+ List<String> list = new ArrayList<String>();
+ String element;
+}
+{
+ ( element = identifier() { list.add(element); } )?
+ ( <COMMA> element = identifier() { list.add(element); } ) *
+ { return list; }
+}
+
+
// An identifier or integer
String tag() :
{