summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:18:01 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-24 14:18:01 +0100
commitcb2dc3460fa31dffb51e54847283038e8a0ae93c (patch)
treee96497fe6b167f8867ad9cb225ea979a6e09dab8 /searchlib/src/main/java/com
parent437a2dc519cc991302c01acb8cd1df1e96b1283d (diff)
Implement composite functions
Diffstat (limited to 'searchlib/src/main/java/com')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java112
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java1
6 files changed, 131 insertions, 6 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index 4cd1e1fc0ee..620c6fad0b4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.tensor.functions.EvaluationContext;
import java.util.Set;
@@ -10,7 +11,7 @@ import java.util.Set;
*
* @author bratseth
*/
-public abstract class Context {
+public abstract class Context implements EvaluationContext {
/**
* <p>Returns the value of a simple variable name.</p>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index f559f8adaf5..dc422f2c8da 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -97,6 +97,14 @@ public class TensorValue extends Value {
return new TensorValue(value.max(asTensor(argument, "max")));
}
+ public Value atan2(Value argument) {
+ return new TensorValue(value.atan2(asTensor(argument, "atan2")));
+ }
+
+ public Value equal(Value argument) {
+ return new TensorValue(value.equal(asTensor(argument, "equal")));
+ }
+
public Value sum(String dimension) {
return new TensorValue(value.sum(Collections.singletonList(dimension)));
}
@@ -129,6 +137,10 @@ public class TensorValue extends Value {
return min(argument);
else if (function.equals(Function.max) && argument instanceof TensorValue)
return max(argument);
+ else if (function.equals(Function.atan2) && argument instanceof TensorValue)
+ return atan2(argument);
+ else if (function.equals(Function.equal) && argument instanceof TensorValue)
+ return equal(argument);
else
return new TensorValue(value.map((value) -> function.evaluate(value, argument.asDouble())));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index cbea2ad627e..b8e48dc2f05 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -43,7 +43,8 @@ public enum Function implements Serializable {
max(2) { public double evaluate(double x, double y) { return max(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
mod(2) { public double evaluate(double x, double y) { return x % y; } },
- pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
+ pow(2) { public double evaluate(double x, double y) { return pow(x,y); } },
+ equal(2) { public double evaluate(double x, double y) { return x==y ? 1.0 : 0.0; } };
private final int arity;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
new file mode 100644
index 00000000000..ebd79f65578
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -0,0 +1,112 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.EvaluationContext;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * A node which performs a tensor function
+ *
+ * @author bratseth
+ */
+ @Beta
+public class TensorFunctionNode extends CompositeNode {
+
+ private final TensorFunction function;
+
+ public TensorFunctionNode(TensorFunction function) {
+ this.function = function;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return function.functionArguments().stream()
+ .map(f -> ((TensorFunctionExpressionNode)f).expression)
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ throw new UnsupportedOperationException("Not implemented");
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ // Serialize as primitive
+ return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ return new TensorValue(function.evaluate(context));
+ }
+
+ public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
+ return new TensorFunctionExpressionNode(node);
+ }
+
+ /**
+ * A tensor function implemented by an expression.
+ * This allows us to pass expressions as tensor function arguments.
+ */
+ public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {
+
+ /** An expression which produces a tensor */
+ private final ExpressionNode expression;
+
+ public TensorFunctionExpressionNode(ExpressionNode expression) {
+ this.expression = expression;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext context) {
+ Value result = expression.evaluate((Context)context);
+ if ( ! ( result instanceof TensorValue))
+ throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
+ "but this returns " + result + ", not a tensor");
+ return ((TensorValue)result).asTensor();
+ }
+
+ @Override
+ public String toString(ToStringContext c) {
+ ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c;
+ return expression.toString(context.context, context.path, context.parent);
+ }
+
+ }
+
+ /** Allows passing serialization context arguments through TensorFunctions */
+ private static class ExpressionNodeToStringContext implements ToStringContext {
+
+ final SerializationContext context;
+ final Deque<String> path;
+ final CompositeNode parent;
+
+ public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ this.context = context;
+ this.path = path;
+ this.parent = parent;
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
index 4f73c632422..d4b95d12fdd 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorReduceNode.java
@@ -6,7 +6,7 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.functions.ReduceFunction;
+import com.yahoo.tensor.functions.Reduce;
import java.util.Collections;
import java.util.Deque;
@@ -23,12 +23,12 @@ public class TensorReduceNode extends CompositeNode {
/** The tensor to aggregate over */
private final ExpressionNode argument;
- private final ReduceFunction.Aggregator aggregator;
+ private final Reduce.Aggregator aggregator;
/** The dimensions to sum over, or empty to sum all cells */
private final ImmutableList<String> dimensions;
- public TensorReduceNode(ExpressionNode argument, ReduceFunction.Aggregator aggregator, List<String> dimensions) {
+ public TensorReduceNode(ExpressionNode argument, Reduce.Aggregator aggregator, List<String> dimensions) {
this.argument = argument;
this.aggregator = aggregator;
this.dimensions = ImmutableList.copyOf(dimensions);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java
index 17a08beba8b..b7f21c215dc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorRenameNode.java
@@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.functions.ReduceFunction;
import java.util.Collections;
import java.util.Deque;