diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 13:25:02 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 13:25:02 +0100 |
commit | 015cedfb6dbd15dec60602ba3082198502d1c5d9 (patch) | |
tree | 2b546af79cc157e12b4300e358e8869fe003f409 /searchlib/src/main | |
parent | 2b4e552165c18544e1ae702175d632e1e39a6e46 (diff) |
Parse lambda
Diffstat (limited to 'searchlib/src/main')
4 files changed, 193 insertions, 12 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java new file mode 100644 index 00000000000..7ac763ef4c4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -0,0 +1,82 @@ +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.function.DoubleUnaryOperator; + +/** + * A free, parametrized function + * + * @author bratseth + */ +public class LambdaFunctionNode extends CompositeNode { + + private final ImmutableList<String> arguments; + private final ExpressionNode functionExpression; + + public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { + // TODO: Verify that the function only accesses the arguments in mapperVariables + this.arguments = ImmutableList.copyOf(arguments); + this.functionExpression = functionExpression; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(functionExpression); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if ( children.size() != 1) + throw new IllegalArgumentException("A lambda function must have a single child expression"); + return new LambdaFunctionNode(arguments, children.get(0)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")"; + } + + private String commaSeparated(List<String> list) { + StringBuilder b = new StringBuilder(); + for (String element : list) + b.append(", ").append(element); + return b.toString(); + } + + /** Evaluate this in a context which must have the arguments bound */ + @Override + public Value evaluate(Context context) { + return functionExpression.evaluate(context); + } + + /** + * Returns this as a double unary operator + * + * @throws IllegalStateException if this does not have exactly one argument + */ + public DoubleUnaryOperator asDoubleUnaryOperator() { + if (arguments.size() != 1) + throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + + "Must have one argument " + " but has " + arguments); + return new DoubleUnaryLambda(); + } + + private class DoubleUnaryLambda implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { + MapContext context = new MapContext(); + context.put(arguments.get(0), operand); + return evaluate(context).asDouble(); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java new file mode 100644 index 00000000000..0cb0da150b4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java @@ -0,0 +1,58 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Deque; +import java.util.List; + +/** + * A node which maps the values of a tensor + * + * @author bratseth + */ + @Beta +public class TensorMapNode extends CompositeNode { + + /** The tensor to aggregate over */ + private final ExpressionNode argument; + + private final LambdaFunctionNode doubleMapper; + + public TensorMapNode(ExpressionNode argument, LambdaFunctionNode doubleMapper) { + this.argument = argument; + this.doubleMapper = doubleMapper; + } + + @Override + public List<ExpressionNode> children() { + return ImmutableList.of(argument, doubleMapper); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if (children.size() != 2) + throw new IllegalArgumentException("A tensor map node must have one tensor and one mapper"); + return new TensorMapNode(children.get(0), (LambdaFunctionNode)children.get(1)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "map(" + argument.toString(context, path, parent) + ", " + doubleMapper.toString() + ")"; + } + + @Override + public Value evaluate(Context context) { + Value argumentValue = argument.evaluate(context); + if ( ! ( argumentValue instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to map '" + argument + "', " + + "but this returns " + argumentValue + ", not a tensor"); + TensorValue tensorArgument = (TensorValue)argumentValue; + return new TensorValue(tensorArgument.asTensor().map(doubleMapper.asDoubleUnaryOperator())); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java index 4e4095cb86e..65a1802c72d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java @@ -12,7 +12,7 @@ import java.util.Deque; import java.util.List; /** - * A node which sums over all cells in the argument tensor + * A node which performs a dimension reduction over a tensor * * @author bratseth */ diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index a800028d00b..2cebbdf7d75 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -100,7 +100,10 @@ TOKEN : <FMOD: "fmod"> | <ISNAN: "isNan"> | <IN: "in"> | + <F: "f"> | + <MAP: "map"> | <REDUCE: "reduce"> | + <JOIN: "join"> | <AVG: "avg" > | <COUNT: "count"> | <PROD: "prod"> | @@ -317,31 +320,55 @@ ExpressionNode tensorFunction() : ExpressionNode tensorExpression; } { - ( tensorExpression = tensorPrimitiveReduce() | tensorExpression = tensorReduce() ) + ( + tensorExpression = tensorMap() | + tensorExpression = tensorReduce() | + tensorExpression = tensorReduceComposites() + ) { return tensorExpression; } } -ExpressionNode tensorPrimitiveReduce() : +ExpressionNode tensorMap() : +{ + ExpressionNode tensor; + LambdaFunctionNode doubleMapper; +} +{ + <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> + { return new TensorMapNode(tensor, doubleMapper); } +} + +LambdaFunctionNode lambdaFunction() : +{ + List<String> variables; + ExpressionNode functionExpression; +} +{ + ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) + { return new LambdaFunctionNode(variables, functionExpression); } +} + +ExpressionNode tensorReduce() : { - ExpressionNode tensor1; + ExpressionNode tensor; ReduceFunction.Aggregator aggregator; List<String> dimensions = null; } { - <REDUCE> <LBRACE> tensor1 = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorReduceNode(tensor1, aggregator, dimensions); } + <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorReduceNode(tensor, aggregator, dimensions); } } -ExpressionNode tensorReduce() : +ExpressionNode tensorReduceComposites() : { - ExpressionNode tensor1; + ExpressionNode tensor; ReduceFunction.Aggregator aggregator; List<String> dimensions = null; } { aggregator = tensorReduceAggregator() - <LBRACE> tensor1 = expression() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorReduceNode(tensor1, aggregator, dimensions); } + <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorReduceNode(tensor, aggregator, dimensions); } } ReduceFunction.Aggregator tensorReduceAggregator() : @@ -358,8 +385,10 @@ String tensorFunctionName() : ReduceFunction.Aggregator aggregator; } { - ( <REDUCE> { return token.image; } ) - | + ( <F> { return token.image; } ) | + ( <MAP> { return token.image; } ) | + ( <REDUCE> { return token.image; } ) | + ( <JOIN> { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } @@ -432,6 +461,18 @@ String identifier() : <IDENTIFIER> { return token.image; } } +List<String> identifierList() : +{ + List<String> list = new ArrayList<String>(); + String element; +} +{ + ( element = identifier() { list.add(element); } )? + ( <COMMA> element = identifier() { list.add(element); } ) * + { return list; } +} + + // An identifier or integer String tag() : { |