diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-04 04:52:16 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-04 04:52:16 -0800 |
commit | a71001d66ada9eaf4ae89d896fea60a39ea2056b (patch) | |
tree | 4cd063b18eb095407df8a4e244564d8bfc690de5 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | |
parent | 78b29e76f2bab63f7cec92f4c1fd9e7661602df7 (diff) |
Propagate binding context to/from tensor functions
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | 98 |
1 files changed, 83 insertions, 15 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index a0c261ae9d3..f510f38d7a7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -2,6 +2,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableMap; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -16,8 +18,6 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; -import java.sql.Ref; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.LinkedHashMap; @@ -71,7 +71,9 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext<Reference> context) { return function.type(context); } + public TensorType type(TypeContext<Reference> context) { + return function.type(context); + } @Override public Value evaluate(Context context) { @@ -117,9 +119,16 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.wrapped() != null) + outermost = outermost.wrapped(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent).toString(); } else { return expression.toString(); @@ -180,9 +189,17 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.wrapped() != null) + outermost = outermost.wrapped(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent) + .toString(); } else { return expression.toString(); @@ -192,22 +209,73 @@ public class TensorFunctionNode extends CompositeNode { } /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionToStringContext implements ToStringContext { + private static class ExpressionToStringContext extends SerializationContext implements ToStringContext { - final SerializationContext context; - final Deque<String> path; - final CompositeNode parent; + private final ToStringContext wrappedToStringContext; + private final SerializationContext wrappedSerializationContext; + private final Deque<String> path; + private final CompositeNode parent; public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null); - public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { - this.context = context; + ExpressionToStringContext(SerializationContext wrappedSerializationContext, Deque<String> path, CompositeNode parent) { + this(wrappedSerializationContext, null, path, parent); + } + + ExpressionToStringContext(SerializationContext wrappedSerializationContext, + ToStringContext wrappedToStringContext, + Deque<String> path, + CompositeNode parent) { + this.wrappedSerializationContext = wrappedSerializationContext; + this.wrappedToStringContext = wrappedToStringContext; this.path = path; this.parent = parent; } + /** Adds the serialization of a function */ + public void addFunctionSerialization(String name, String expressionString) { + wrappedSerializationContext.addFunctionSerialization(name, expressionString); + } + + /** Adds the serialization of the an argument type to a function */ + public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { + wrappedSerializationContext.addArgumentTypeSerialization(functionName, argumentName, type); + } + + /** Adds the serialization of the return type of a function */ + public void addFunctionTypeSerialization(String functionName, TensorType type) { + wrappedSerializationContext.addFunctionTypeSerialization(functionName, type); + } + + public Map<String, String> serializedFunctions() { + return wrappedSerializationContext.serializedFunctions(); + } + + /** Returns a function or null if it isn't defined in this context */ + public ExpressionFunction getFunction(String name) { return wrappedSerializationContext.getFunction(name); } + + protected ImmutableMap<String, ExpressionFunction> functions() { return wrappedSerializationContext.functions(); } + + public ToStringContext wrapped() { return wrappedToStringContext; } + + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ + @Override + public String getBinding(String name) { + if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null) + return wrappedToStringContext.getBinding(name); + else + return wrappedSerializationContext.getBinding(name); + } + + /** Returns a new context with the bindings replaced by the given bindings */ + @Override + public ExpressionToStringContext withBindings(Map<String, String> bindings) { + return new ExpressionToStringContext(new SerializationContext(wrappedSerializationContext.functions().values(), bindings), + wrappedToStringContext, path, parent); + } + } /** Turns an EvaluationContext into a Context */ |