diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 13:25:02 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 13:25:02 +0100 |
commit | 015cedfb6dbd15dec60602ba3082198502d1c5d9 (patch) | |
tree | 2b546af79cc157e12b4300e358e8869fe003f409 | |
parent | 2b4e552165c18544e1ae702175d632e1e39a6e46 (diff) |
Parse lambda
5 files changed, 217 insertions, 30 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java new file mode 100644 index 00000000000..7ac763ef4c4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -0,0 +1,82 @@ +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.function.DoubleUnaryOperator; + +/** + * A free, parametrized function + * + * @author bratseth + */ +public class LambdaFunctionNode extends CompositeNode { + + private final ImmutableList<String> arguments; + private final ExpressionNode functionExpression; + + public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { + // TODO: Verify that the function only accesses the arguments in mapperVariables + this.arguments = ImmutableList.copyOf(arguments); + this.functionExpression = functionExpression; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(functionExpression); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if ( children.size() != 1) + throw new IllegalArgumentException("A lambda function must have a single child expression"); + return new LambdaFunctionNode(arguments, children.get(0)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")"; + } + + private String commaSeparated(List<String> list) { + StringBuilder b = new StringBuilder(); + for (String element : list) + b.append(", ").append(element); + return b.toString(); + } + + /** Evaluate this in a context which must have the arguments bound */ + @Override + public Value evaluate(Context context) { + return functionExpression.evaluate(context); + } + + /** + * Returns this as a double unary operator + * + * @throws IllegalStateException if this does not have exactly one argument + */ + public DoubleUnaryOperator asDoubleUnaryOperator() { + if (arguments.size() != 1) + throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + + "Must have one argument " + " but has " + arguments); + return new DoubleUnaryLambda(); + } + + private class DoubleUnaryLambda implements DoubleUnaryOperator { + + @Override + public double applyAsDouble(double operand) { + MapContext context = new MapContext(); + context.put(arguments.get(0), operand); + return evaluate(context).asDouble(); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java new file mode 100644 index 00000000000..0cb0da150b4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java @@ -0,0 +1,58 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Deque; +import java.util.List; + +/** + * A node which maps the values of a tensor + * + * @author bratseth + */ + @Beta +public class TensorMapNode extends CompositeNode { + + /** The tensor to aggregate over */ + private final ExpressionNode argument; + + private final LambdaFunctionNode doubleMapper; + + public TensorMapNode(ExpressionNode argument, LambdaFunctionNode doubleMapper) { + this.argument = argument; + this.doubleMapper = doubleMapper; + } + + @Override + public List<ExpressionNode> children() { + return ImmutableList.of(argument, doubleMapper); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if (children.size() != 2) + throw new IllegalArgumentException("A tensor map node must have one tensor and one mapper"); + return new TensorMapNode(children.get(0), (LambdaFunctionNode)children.get(1)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "map(" + argument.toString(context, path, parent) + ", " + doubleMapper.toString() + ")"; + } + + @Override + public Value evaluate(Context context) { + Value argumentValue = argument.evaluate(context); + if ( ! ( argumentValue instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to map '" + argument + "', " + + "but this returns " + argumentValue + ", not a tensor"); + TensorValue tensorArgument = (TensorValue)argumentValue; + return new TensorValue(tensorArgument.asTensor().map(doubleMapper.asDoubleUnaryOperator())); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java index 4e4095cb86e..65a1802c72d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java @@ -12,7 +12,7 @@ import java.util.Deque; import java.util.List; /** - * A node which sums over all cells in the argument tensor + * A node which performs a dimension reduction over a tensor * * @author bratseth */ diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index a800028d00b..2cebbdf7d75 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -100,7 +100,10 @@ TOKEN : <FMOD: "fmod"> | <ISNAN: "isNan"> | <IN: "in"> | + <F: "f"> | + <MAP: "map"> | <REDUCE: "reduce"> | + <JOIN: "join"> | <AVG: "avg" > | <COUNT: "count"> | <PROD: "prod"> | @@ -317,31 +320,55 @@ ExpressionNode tensorFunction() : ExpressionNode tensorExpression; } { - ( tensorExpression = tensorPrimitiveReduce() | tensorExpression = tensorReduce() ) + ( + tensorExpression = tensorMap() | + tensorExpression = tensorReduce() | + tensorExpression = tensorReduceComposites() + ) { return tensorExpression; } } -ExpressionNode tensorPrimitiveReduce() : +ExpressionNode tensorMap() : +{ + ExpressionNode tensor; + LambdaFunctionNode doubleMapper; +} +{ + <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> + { return new TensorMapNode(tensor, doubleMapper); } +} + +LambdaFunctionNode lambdaFunction() : +{ + List<String> variables; + ExpressionNode functionExpression; +} +{ + ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) + { return new LambdaFunctionNode(variables, functionExpression); } +} + +ExpressionNode tensorReduce() : { - ExpressionNode tensor1; + ExpressionNode tensor; ReduceFunction.Aggregator aggregator; List<String> dimensions = null; } { - <REDUCE> <LBRACE> tensor1 = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorReduceNode(tensor1, aggregator, dimensions); } + <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorReduceNode(tensor, aggregator, dimensions); } } -ExpressionNode tensorReduce() : +ExpressionNode tensorReduceComposites() : { - ExpressionNode tensor1; + ExpressionNode tensor; ReduceFunction.Aggregator aggregator; List<String> dimensions = null; } { aggregator = tensorReduceAggregator() - <LBRACE> tensor1 = expression() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorReduceNode(tensor1, aggregator, dimensions); } + <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> + { return new TensorReduceNode(tensor, aggregator, dimensions); } } ReduceFunction.Aggregator tensorReduceAggregator() : @@ -358,8 +385,10 @@ String tensorFunctionName() : ReduceFunction.Aggregator aggregator; } { - ( <REDUCE> { return token.image; } ) - | + ( <F> { return token.image; } ) | + ( <MAP> { return token.image; } ) | + ( <REDUCE> { return token.image; } ) | + ( <JOIN> { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } @@ -432,6 +461,18 @@ String identifier() : <IDENTIFIER> { return token.image; } } +List<String> identifierList() : +{ + List<String> list = new ArrayList<String>(); + String element; +} +{ + ( element = identifier() { list.add(element); } )? + ( <COMMA> element = identifier() { list.add(element); } ) * + { return list; } +} + + // An identifier or integer String tag() : { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 50e79b301f4..c793e203b23 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -100,9 +100,26 @@ public class EvaluationTestCase extends junit.framework.TestCase { @Test public void testTensorEvaluation() { - assertEvaluates("{}", "tensor0", "{}"); // empty + assertEvaluates("{}", "tensor0", "{}"); + + // tensor map + assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", + "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + // tensor map derivatives + assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", + "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }", + "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }", + "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); + assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }", + "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); + assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }"); // tensor reduce + assertEvaluates("{ {}:16 }", + "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + // reduce composites assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0"); assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0"); assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }"); @@ -114,23 +131,6 @@ public class EvaluationTestCase extends junit.framework.TestCase { assertEvaluates("{ {}:16 }", "sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); - // tensor map - assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", - "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", - "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", - "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", - "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }", - "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }", - "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); - assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }", - "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); - assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }"); - // tensor join assertEvaluates("{ }", "tensor0 * tensor0", "{}"); assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )", @@ -156,6 +156,12 @@ public class EvaluationTestCase extends junit.framework.TestCase { assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }", "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }"); + assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", + "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", + "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", + "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); // Combined assertEvaluates(String.valueOf(7.5 + 45 + 1.7), |