aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-05 22:49:08 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-05 22:49:08 +0100
commited8c274dc76794efa692efba6cf509b058b13648 (patch)
treec1dcb9fbc70b851be5cfdb8c335089283715f698 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
parent64c5daa351557869e64786188afa75ed3b59991b (diff)
Literal tensors with value expressions
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java41
1 files changed, 40 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index c1732aabf0b..e6e49e07c34 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -14,9 +15,13 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
import java.util.stream.Collectors;
/**
@@ -72,10 +77,44 @@ public class TensorFunctionNode extends CompositeNode {
return new TensorValue(function.evaluate(context));
}
- public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
+ public static TensorFunctionExpressionNode wrap(ExpressionNode node) {
return new TensorFunctionExpressionNode(node);
}
+ public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>();
+ for (var entry : nodes.entrySet())
+ closures.put(entry.getKey(), new ExpressionClosure(entry.getValue()));
+ return closures;
+ }
+
+ public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) {
+ List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>();
+ for (var entry : nodes)
+ closures.add(new ExpressionClosure(entry));
+ return closures;
+ }
+
+ private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> {
+
+ private final ExpressionNode expression;
+
+ public ExpressionClosure(ExpressionNode expression) {
+ this.expression = expression;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<?> context) {
+ return expression.evaluate((Context)context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return expression.toString();
+ }
+
+ }
+
/**
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.