summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java82
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java58
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj63
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java42
5 files changed, 217 insertions, 30 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
new file mode 100644
index 00000000000..7ac763ef4c4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -0,0 +1,82 @@
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.DoubleUnaryOperator;
+
+/**
+ * A free, parametrized function
+ *
+ * @author bratseth
+ */
+public class LambdaFunctionNode extends CompositeNode {
+
+ private final ImmutableList<String> arguments;
+ private final ExpressionNode functionExpression;
+
+ public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
+ // TODO: Verify that the function only accesses the arguments in mapperVariables
+ this.arguments = ImmutableList.copyOf(arguments);
+ this.functionExpression = functionExpression;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(functionExpression);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if ( children.size() != 1)
+ throw new IllegalArgumentException("A lambda function must have a single child expression");
+ return new LambdaFunctionNode(arguments, children.get(0));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return ("f(" + commaSeparated(arguments) + ")(" + functionExpression.toString(context, path, this)) + ")";
+ }
+
+ private String commaSeparated(List<String> list) {
+ StringBuilder b = new StringBuilder();
+ for (String element : list)
+ b.append(", ").append(element);
+ return b.toString();
+ }
+
+ /** Evaluate this in a context which must have the arguments bound */
+ @Override
+ public Value evaluate(Context context) {
+ return functionExpression.evaluate(context);
+ }
+
+ /**
+ * Returns this as a double unary operator
+ *
+ * @throws IllegalStateException if this does not have exactly one argument
+ */
+ public DoubleUnaryOperator asDoubleUnaryOperator() {
+ if (arguments.size() != 1)
+ throw new IllegalStateException("Cannot apply " + this + " as a DoubleUnaryOperator: " +
+ "Must have one argument " + " but has " + arguments);
+ return new DoubleUnaryLambda();
+ }
+
+ private class DoubleUnaryLambda implements DoubleUnaryOperator {
+
+ @Override
+ public double applyAsDouble(double operand) {
+ MapContext context = new MapContext();
+ context.put(arguments.get(0), operand);
+ return evaluate(context).asDouble();
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java
new file mode 100644
index 00000000000..0cb0da150b4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorMapNode.java
@@ -0,0 +1,58 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * A node which maps the values of a tensor
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorMapNode extends CompositeNode {
+
+ /** The tensor to aggregate over */
+ private final ExpressionNode argument;
+
+ private final LambdaFunctionNode doubleMapper;
+
+ public TensorMapNode(ExpressionNode argument, LambdaFunctionNode doubleMapper) {
+ this.argument = argument;
+ this.doubleMapper = doubleMapper;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return ImmutableList.of(argument, doubleMapper);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 2)
+ throw new IllegalArgumentException("A tensor map node must have one tensor and one mapper");
+ return new TensorMapNode(children.get(0), (LambdaFunctionNode)children.get(1));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "map(" + argument.toString(context, path, parent) + ", " + doubleMapper.toString() + ")";
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Value argumentValue = argument.evaluate(context);
+ if ( ! ( argumentValue instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to map '" + argument + "', " +
+ "but this returns " + argumentValue + ", not a tensor");
+ TensorValue tensorArgument = (TensorValue)argumentValue;
+ return new TensorValue(tensorArgument.asTensor().map(doubleMapper.asDoubleUnaryOperator()));
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
index 4e4095cb86e..65a1802c72d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
@@ -12,7 +12,7 @@ import java.util.Deque;
import java.util.List;
/**
- * A node which sums over all cells in the argument tensor
+ * A node which performs a dimension reduction over a tensor
*
* @author bratseth
*/
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index a800028d00b..2cebbdf7d75 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -100,7 +100,10 @@ TOKEN :
<FMOD: "fmod"> |
<ISNAN: "isNan"> |
<IN: "in"> |
+ <F: "f"> |
+ <MAP: "map"> |
<REDUCE: "reduce"> |
+ <JOIN: "join"> |
<AVG: "avg" > |
<COUNT: "count"> |
<PROD: "prod"> |
@@ -317,31 +320,55 @@ ExpressionNode tensorFunction() :
ExpressionNode tensorExpression;
}
{
- ( tensorExpression = tensorPrimitiveReduce() | tensorExpression = tensorReduce() )
+ (
+ tensorExpression = tensorMap() |
+ tensorExpression = tensorReduce() |
+ tensorExpression = tensorReduceComposites()
+ )
{ return tensorExpression; }
}
-ExpressionNode tensorPrimitiveReduce() :
+ExpressionNode tensorMap() :
+{
+ ExpressionNode tensor;
+ LambdaFunctionNode doubleMapper;
+}
+{
+ <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
+ { return new TensorMapNode(tensor, doubleMapper); }
+}
+
+LambdaFunctionNode lambdaFunction() :
+{
+ List<String> variables;
+ ExpressionNode functionExpression;
+}
+{
+ ( <F> <LBRACE> variables = identifierList() <RBRACE> <LBRACE> functionExpression = expression() <RBRACE> )
+ { return new LambdaFunctionNode(variables, functionExpression); }
+}
+
+ExpressionNode tensorReduce() :
{
- ExpressionNode tensor1;
+ ExpressionNode tensor;
ReduceFunction.Aggregator aggregator;
List<String> dimensions = null;
}
{
- <REDUCE> <LBRACE> tensor1 = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorReduceNode(tensor1, aggregator, dimensions); }
+ <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorReduceNode(tensor, aggregator, dimensions); }
}
-ExpressionNode tensorReduce() :
+ExpressionNode tensorReduceComposites() :
{
- ExpressionNode tensor1;
+ ExpressionNode tensor;
ReduceFunction.Aggregator aggregator;
List<String> dimensions = null;
}
{
aggregator = tensorReduceAggregator()
- <LBRACE> tensor1 = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorReduceNode(tensor1, aggregator, dimensions); }
+ <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
+ { return new TensorReduceNode(tensor, aggregator, dimensions); }
}
ReduceFunction.Aggregator tensorReduceAggregator() :
@@ -358,8 +385,10 @@ String tensorFunctionName() :
ReduceFunction.Aggregator aggregator;
}
{
- ( <REDUCE> { return token.image; } )
- |
+ ( <F> { return token.image; } ) |
+ ( <MAP> { return token.image; } ) |
+ ( <REDUCE> { return token.image; } ) |
+ ( <JOIN> { return token.image; } ) |
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
@@ -432,6 +461,18 @@ String identifier() :
<IDENTIFIER> { return token.image; }
}
+List<String> identifierList() :
+{
+ List<String> list = new ArrayList<String>();
+ String element;
+}
+{
+ ( element = identifier() { list.add(element); } )?
+ ( <COMMA> element = identifier() { list.add(element); } ) *
+ { return list; }
+}
+
+
// An identifier or integer
String tag() :
{
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 50e79b301f4..c793e203b23 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -100,9 +100,26 @@ public class EvaluationTestCase extends junit.framework.TestCase {
@Test
public void testTensorEvaluation() {
- assertEvaluates("{}", "tensor0", "{}"); // empty
+ assertEvaluates("{}", "tensor0", "{}");
+
+ // tensor map
+ assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
+ "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ // tensor map derivatives
+ assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
+ "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }",
+ "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }",
+ "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
+ assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }",
+ "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
+ assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }");
// tensor reduce
+ assertEvaluates("{ {}:16 }",
+ "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ // reduce composites
assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0");
assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0");
assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }");
@@ -114,23 +131,6 @@ public class EvaluationTestCase extends junit.framework.TestCase {
assertEvaluates("{ {}:16 }",
"sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
- // tensor map
- assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
- "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
- "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
- "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
- "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }",
- "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }",
- "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
- assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }",
- "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
- assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }");
-
// tensor join
assertEvaluates("{ }", "tensor0 * tensor0", "{}");
assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )",
@@ -156,6 +156,12 @@ public class EvaluationTestCase extends junit.framework.TestCase {
assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }",
"tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }");
+ assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
+ "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
+ "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
+ "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
// Combined
assertEvaluates(String.valueOf(7.5 + 45 + 1.7),