diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 14:08:35 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-11-23 14:08:35 +0100 |
commit | f65c80a1fb5fdc285ce0db63b3b1f039f5201505 (patch) | |
tree | c767fe82963a3c29276078e0d8f588a126a7cca6 /searchlib/src/main | |
parent | 015cedfb6dbd15dec60602ba3082198502d1c5d9 (diff) |
Parse join
Diffstat (limited to 'searchlib/src/main')
3 files changed, 124 insertions, 17 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 7ac763ef4c4..593fa4bc45e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import java.util.Collections; import java.util.Deque; import java.util.List; +import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; /** @@ -46,7 +47,9 @@ public class LambdaFunctionNode extends CompositeNode { private String commaSeparated(List<String> list) { StringBuilder b = new StringBuilder(); for (String element : list) - b.append(", ").append(element); + b.append(element).append(", "); + if (b.length() > 0) + b.setLength(b.length() -1); return b.toString(); } @@ -59,21 +62,48 @@ public class LambdaFunctionNode extends CompositeNode { /** * Returns this as a double unary operator * - * @throws IllegalStateException if this does not have exactly one argument + * @throws IllegalStateException if this has more than one argument */ public DoubleUnaryOperator asDoubleUnaryOperator() { - if (arguments.size() != 1) + if (arguments.size() > 1) throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " + - "Must have one argument " + " but has " + arguments); + "Must have at most one argument " + " but has " + arguments); return new DoubleUnaryLambda(); } - + + /** + * Returns this as a double binary operator + * + * @throws IllegalStateException if this has more than two arguments + */ + public DoubleBinaryOperator asDoubleBinaryOperator() { + if (arguments.size() > 2) + throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " + + "Must have at most two argument " + " but has " + arguments); + return new DoubleBinaryLambda(); + } + private class DoubleUnaryLambda implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { MapContext context = new MapContext(); - context.put(arguments.get(0), operand); + if (arguments.size() > 0) + context.put(arguments.get(0), operand); + return evaluate(context).asDouble(); + } + + } + + private class DoubleBinaryLambda implements DoubleBinaryOperator { + + @Override + public double applyAsDouble(double left, double right) { + MapContext context = new MapContext(); + if (arguments.size() > 0) + context.put(arguments.get(0), left); + if (arguments.size() > 1) + context.put(arguments.get(1), right); return evaluate(context).asDouble(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java new file mode 100644 index 00000000000..21455113578 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java @@ -0,0 +1,66 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; + +import java.util.Deque; +import java.util.List; + +/** + * A node which joins two tensors + * + * @author bratseth + */ + @Beta +public class TensorJoinNode extends CompositeNode { + + /** The tensor to aggregate over */ + private final ExpressionNode argument1, argument2; + + private final LambdaFunctionNode doubleJoiner; + + public TensorJoinNode(ExpressionNode argument1, ExpressionNode argument2, LambdaFunctionNode doubleJoiner) { + this.argument1 = argument1; + this.argument2 = argument2; + this.doubleJoiner = doubleJoiner; + } + + @Override + public List<ExpressionNode> children() { + return ImmutableList.of(argument1, argument2, doubleJoiner); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if (children.size() != 3) + throw new IllegalArgumentException("A tensor join node must have two tensors and one joiner"); + return new TensorJoinNode(children.get(0), children.get(1), (LambdaFunctionNode)children.get(2)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "join(" + argument1.toString(context, path, parent) + ", " + + argument2.toString(context, path, parent) + ", " + + doubleJoiner.toString() + ")"; + } + + @Override + public Value evaluate(Context context) { + Tensor argument1Value = asTensor(argument1.evaluate(context), argument1); + Tensor argument2Value = asTensor(argument2.evaluate(context), argument2); + return new TensorValue(argument1Value.join(argument2Value, doubleJoiner.asDoubleBinaryOperator())); + } + + private Tensor asTensor(Value value, ExpressionNode producingNode) { + if ( ! ( value instanceof TensorValue)) + throw new IllegalArgumentException("Attempted to join '" + producingNode + "', " + + "but this returns " + value + ", not a tensor"); + return ((TensorValue)value).asTensor(); + } + +} diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 2cebbdf7d75..5a5d916f7e7 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -323,7 +323,8 @@ ExpressionNode tensorFunction() : ( tensorExpression = tensorMap() | tensorExpression = tensorReduce() | - tensorExpression = tensorReduceComposites() + tensorExpression = tensorReduceComposites() | + tensorExpression = tensorJoin() ) { return tensorExpression; } } @@ -338,16 +339,6 @@ ExpressionNode tensorMap() : { return new TensorMapNode(tensor, doubleMapper); } } -LambdaFunctionNode lambdaFunction() : -{ - List<String> variables; - ExpressionNode functionExpression; -} -{ - ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) - { return new LambdaFunctionNode(variables, functionExpression); } -} - ExpressionNode tensorReduce() : { ExpressionNode tensor; @@ -371,6 +362,26 @@ ExpressionNode tensorReduceComposites() : { return new TensorReduceNode(tensor, aggregator, dimensions); } } +ExpressionNode tensorJoin() : +{ + ExpressionNode tensor1, tensor2; + LambdaFunctionNode doubleJoiner; +} +{ + <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> + { return new TensorJoinNode(tensor1, tensor2, doubleJoiner); } +} + +LambdaFunctionNode lambdaFunction() : +{ + List<String> variables; + ExpressionNode functionExpression; +} +{ + ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> ) + { return new LambdaFunctionNode(variables, functionExpression); } +} + ReduceFunction.Aggregator tensorReduceAggregator() : { } |