diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-04 04:52:16 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-04 04:52:16 -0800 |
commit | a71001d66ada9eaf4ae89d896fea60a39ea2056b (patch) | |
tree | 4cd063b18eb095407df8a4e244564d8bfc690de5 /searchlib/src | |
parent | 78b29e76f2bab63f7cec92f4c1fd9e7661602df7 (diff) |
Propagate binding context to/from tensor functions
Diffstat (limited to 'searchlib/src')
6 files changed, 98 insertions, 26 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index f531d77762d..69304a811b1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -69,7 +69,7 @@ public class MapContext extends Context { * Sets the value of a key. The value is frozen by this. */ @Override - public void put(String key,Value value) { + public void put(String key, Value value) { bindings.put(key, value.freeze()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java index 084bfe65e06..83aabada8f0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -20,8 +20,7 @@ public class FunctionReferenceContext { private final ImmutableMap<String, ExpressionFunction> functions; /** Mapping from argument names to the expressions they resolve to */ - // TODO: Make private - public final Map<String, String> bindings = new HashMap<>(); + private final Map<String, String> bindings = new HashMap<>(); /** Create a context for a single serialization task */ public FunctionReferenceContext() { @@ -56,14 +55,12 @@ public class FunctionReferenceContext { return mapBuilder.build(); } - /** - * Returns a function or null if it isn't defined in this context - */ + /** Returns a function or null if it isn't defined in this context */ public ExpressionFunction getFunction(String name) { return functions.get(name); } - protected final ImmutableMap<String, ExpressionFunction> functions() { return functions; } + protected ImmutableMap<String, ExpressionFunction> functions() { return functions; } - /** Returns the resolution of an argument, or null if it isn't defined in this context */ + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ public String getBinding(String name) { return bindings.get(name); } /** Returns a new context with the bindings replaced by the given bindings */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 62b3379f635..8fec3603f3e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -68,7 +68,7 @@ public final class ReferenceNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - // A reference to a function argument? + // A reference to an identifier (function argument or bound variable)? if (reference.isIdentifier() && context.getBinding(getName()) != null) { // a bound identifier: replace by the value it is bound to return string.append(context.getBinding(getName())); @@ -89,6 +89,8 @@ public final class ReferenceNode extends CompositeNode { context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); return string.append("rankingExpression(").append(instance.getName()).append(')'); } + + // Not resolved in this context: output as-is return reference.toString(string, context, path, parent); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 4acc1a85490..d7807caa2b6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -8,7 +8,6 @@ import com.yahoo.tensor.TensorType; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; /** @@ -37,7 +36,7 @@ public class SerializationContext extends FunctionReferenceContext { } /** Create a context for a single serialization task */ - public SerializationContext(List<ExpressionFunction> functions, Map<String, String> bindings) { + public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) { this(functions, bindings, new LinkedHashMap<>()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index a0c261ae9d3..f510f38d7a7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -2,6 +2,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableMap; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -16,8 +18,6 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; -import java.sql.Ref; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.LinkedHashMap; @@ -71,7 +71,9 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext<Reference> context) { return function.type(context); } + public TensorType type(TypeContext<Reference> context) { + return function.type(context); + } @Override public Value evaluate(Context context) { @@ -117,9 +119,16 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.wrapped() != null) + outermost = outermost.wrapped(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent).toString(); } else { return expression.toString(); @@ -180,9 +189,17 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.wrapped() != null) + outermost = outermost.wrapped(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent) + .toString(); } else { return expression.toString(); @@ -192,22 +209,73 @@ public class TensorFunctionNode extends CompositeNode { } /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionToStringContext implements ToStringContext { + private static class ExpressionToStringContext extends SerializationContext implements ToStringContext { - final SerializationContext context; - final Deque<String> path; - final CompositeNode parent; + private final ToStringContext wrappedToStringContext; + private final SerializationContext wrappedSerializationContext; + private final Deque<String> path; + private final CompositeNode parent; public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null); - public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { - this.context = context; + ExpressionToStringContext(SerializationContext wrappedSerializationContext, Deque<String> path, CompositeNode parent) { + this(wrappedSerializationContext, null, path, parent); + } + + ExpressionToStringContext(SerializationContext wrappedSerializationContext, + ToStringContext wrappedToStringContext, + Deque<String> path, + CompositeNode parent) { + this.wrappedSerializationContext = wrappedSerializationContext; + this.wrappedToStringContext = wrappedToStringContext; this.path = path; this.parent = parent; } + /** Adds the serialization of a function */ + public void addFunctionSerialization(String name, String expressionString) { + wrappedSerializationContext.addFunctionSerialization(name, expressionString); + } + + /** Adds the serialization of the an argument type to a function */ + public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { + wrappedSerializationContext.addArgumentTypeSerialization(functionName, argumentName, type); + } + + /** Adds the serialization of the return type of a function */ + public void addFunctionTypeSerialization(String functionName, TensorType type) { + wrappedSerializationContext.addFunctionTypeSerialization(functionName, type); + } + + public Map<String, String> serializedFunctions() { + return wrappedSerializationContext.serializedFunctions(); + } + + /** Returns a function or null if it isn't defined in this context */ + public ExpressionFunction getFunction(String name) { return wrappedSerializationContext.getFunction(name); } + + protected ImmutableMap<String, ExpressionFunction> functions() { return wrappedSerializationContext.functions(); } + + public ToStringContext wrapped() { return wrappedToStringContext; } + + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ + @Override + public String getBinding(String name) { + if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null) + return wrappedToStringContext.getBinding(name); + else + return wrappedSerializationContext.getBinding(name); + } + + /** Returns a new context with the bindings replaced by the given bindings */ + @Override + public ExpressionToStringContext withBindings(Map<String, String> bindings) { + return new ExpressionToStringContext(new SerializationContext(wrappedSerializationContext.functions().values(), bindings), + wrappedToStringContext, path, parent); + } + } /** Turns an EvaluationContext into a Context */ diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 7eb1fecc0cb..e3d3ac7b2e1 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -198,6 +198,12 @@ public class RankingExpressionTestCase { assertSerialization(List.of("tensor(x{}):{{x:foo}:rankingExpression(scalarFunction),{x:bar}:rankingExpression(scalarFunction)}"), "tensor(x{}):{{x:foo}:scalarFunction(), {x:bar}:scalarFunction()}", functions, false); + + // Shadowing + assertSerialization(List.of("tensor(scalarFunction[1])(rankingExpression(tensorFunction){x:scalarFunction + rankingExpression(scalarFunction)})"), + "tensor(scalarFunction[1])(tensorFunction{x: scalarFunction + scalarFunction()})", + functions, false); + } @Test |