summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-23 14:08:35 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-23 14:08:35 +0100
commitf65c80a1fb5fdc285ce0db63b3b1f039f5201505 (patch)
treec767fe82963a3c29276078e0d8f588a126a7cca6 /searchlib/src/main
parent015cedfb6dbd15dec60602ba3082198502d1c5d9 (diff)
Parse join
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java42
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java66
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj33
3 files changed, 124 insertions, 17 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 7ac763ef4c4..593fa4bc45e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
/**
@@ -46,7 +47,9 @@ public class LambdaFunctionNode extends CompositeNode {
private String commaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
- b.append(", ").append(element);
+ b.append(element).append(", ");
+ if (b.length() > 0)
+ b.setLength(b.length() -1);
return b.toString();
}
@@ -59,21 +62,48 @@ public class LambdaFunctionNode extends CompositeNode {
/**
* Returns this as a double unary operator
*
- * @throws IllegalStateException if this does not have exactly one argument
+ * @throws IllegalStateException if this has more than one argument
*/
public DoubleUnaryOperator asDoubleUnaryOperator() {
- if (arguments.size() != 1)
+ if (arguments.size() > 1)
throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
- "Must have one argument " + " but has " + arguments);
+ "Must have at most one argument " + " but has " + arguments);
return new DoubleUnaryLambda();
}
-
+
+ /**
+ * Returns this as a double binary operator
+ *
+ * @throws IllegalStateException if this has more than two arguments
+ */
+ public DoubleBinaryOperator asDoubleBinaryOperator() {
+ if (arguments.size() > 2)
+ throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
+ "Must have at most two argument " + " but has " + arguments);
+ return new DoubleBinaryLambda();
+ }
+
private class DoubleUnaryLambda implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) {
MapContext context = new MapContext();
- context.put(arguments.get(0), operand);
+ if (arguments.size() > 0)
+ context.put(arguments.get(0), operand);
+ return evaluate(context).asDouble();
+ }
+
+ }
+
+ private class DoubleBinaryLambda implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) {
+ MapContext context = new MapContext();
+ if (arguments.size() > 0)
+ context.put(arguments.get(0), left);
+ if (arguments.size() > 1)
+ context.put(arguments.get(1), right);
return evaluate(context).asDouble();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java
new file mode 100644
index 00000000000..21455113578
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorJoinNode.java
@@ -0,0 +1,66 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
+
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * A node which joins two tensors
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorJoinNode extends CompositeNode {
+
+ /** The tensor to aggregate over */
+ private final ExpressionNode argument1, argument2;
+
+ private final LambdaFunctionNode doubleJoiner;
+
+ public TensorJoinNode(ExpressionNode argument1, ExpressionNode argument2, LambdaFunctionNode doubleJoiner) {
+ this.argument1 = argument1;
+ this.argument2 = argument2;
+ this.doubleJoiner = doubleJoiner;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return ImmutableList.of(argument1, argument2, doubleJoiner);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 3)
+ throw new IllegalArgumentException("A tensor join node must have two tensors and one joiner");
+ return new TensorJoinNode(children.get(0), children.get(1), (LambdaFunctionNode)children.get(2));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "join(" + argument1.toString(context, path, parent) + ", " +
+ argument2.toString(context, path, parent) + ", " +
+ doubleJoiner.toString() + ")";
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Tensor argument1Value = asTensor(argument1.evaluate(context), argument1);
+ Tensor argument2Value = asTensor(argument2.evaluate(context), argument2);
+ return new TensorValue(argument1Value.join(argument2Value, doubleJoiner.asDoubleBinaryOperator()));
+ }
+
+ private Tensor asTensor(Value value, ExpressionNode producingNode) {
+ if ( ! ( value instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to join '" + producingNode + "', " +
+ "but this returns " + value + ", not a tensor");
+ return ((TensorValue)value).asTensor();
+ }
+
+}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 2cebbdf7d75..5a5d916f7e7 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -323,7 +323,8 @@ ExpressionNode tensorFunction() :
(
tensorExpression = tensorMap() |
tensorExpression = tensorReduce() |
- tensorExpression = tensorReduceComposites()
+ tensorExpression = tensorReduceComposites() |
+ tensorExpression = tensorJoin()
)
{ return tensorExpression; }
}
@@ -338,16 +339,6 @@ ExpressionNode tensorMap() :
{ return new TensorMapNode(tensor, doubleMapper); }
}
-LambdaFunctionNode lambdaFunction() :
-{
- List<String> variables;
- ExpressionNode functionExpression;
-}
-{
- ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
- { return new LambdaFunctionNode(variables, functionExpression); }
-}
-
ExpressionNode tensorReduce() :
{
ExpressionNode tensor;
@@ -371,6 +362,26 @@ ExpressionNode tensorReduceComposites() :
{ return new TensorReduceNode(tensor, aggregator, dimensions); }
}
+ExpressionNode tensorJoin() :
+{
+ ExpressionNode tensor1, tensor2;
+ LambdaFunctionNode doubleJoiner;
+}
+{
+ <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
+ { return new TensorJoinNode(tensor1, tensor2, doubleJoiner); }
+}
+
+LambdaFunctionNode lambdaFunction() :
+{
+ List<String> variables;
+ ExpressionNode functionExpression;
+}
+{
+ ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
+ { return new LambdaFunctionNode(variables, functionExpression); }
+}
+
ReduceFunction.Aggregator tensorReduceAggregator() :
{
}