diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-05 22:49:08 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-05 22:49:08 +0100 |
commit | ed8c274dc76794efa692efba6cf509b058b13648 (patch) | |
tree | c1dcb9fbc70b851be5cfdb8c335089283715f698 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | |
parent | 64c5daa351557869e64786188afa75ed3b59991b (diff) |
Literal tensors with value expressions
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | 41 |
1 files changed, 40 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index c1732aabf0b..e6e49e07c34 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.TypeContext; @@ -14,9 +15,13 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -72,10 +77,44 @@ public class TensorFunctionNode extends CompositeNode { return new TensorValue(function.evaluate(context)); } - public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { + public static TensorFunctionExpressionNode wrap(ExpressionNode node) { return new TensorFunctionExpressionNode(node); } + public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>(); + for (var entry : nodes.entrySet()) + closures.put(entry.getKey(), new ExpressionClosure(entry.getValue())); + return closures; + } + + public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) { + List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>(); + for (var entry : nodes) + closures.add(new ExpressionClosure(entry)); + return closures; + } + + private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> { + + private final ExpressionNode expression; + + public ExpressionClosure(ExpressionNode expression) { + this.expression = expression; + } + + @Override + public Double apply(EvaluationContext<?> context) { + return expression.evaluate((Context)context).asDouble(); + } + + @Override + public String toString() { + return expression.toString(); + } + + } + /** * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. |