diff options
9 files changed, 146 insertions, 39 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index a54e21aae68..2be3022ce6e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -108,7 +108,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement if (FeatureNames.isSimpleFeature(reference)) { // The argument may be a local identifier bound to the actual value String argument = reference.simpleArgument().get(); - reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); + String argumentBinding = getBinding(argument); + reference = Reference.simple(reference.name(), argumentBinding != null ? argumentBinding : argument); return featureTypes.get(reference); } @@ -152,7 +153,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement private Optional<String> boundIdentifier(Reference reference) { if ( ! reference.arguments().isEmpty()) return Optional.empty(); if ( reference.output() != null) return Optional.empty(); - return Optional.ofNullable(bindings.get(reference.name())); + return Optional.ofNullable(getBinding(reference.name())); } private Optional<ExpressionFunction> functionInvocation(Reference reference) { @@ -203,8 +204,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement Map<String, String> bindings = new HashMap<>(formalArguments.size()); for (int i = 0; i < formalArguments.size(); i++) { String identifier = invocationArguments.expressions().get(i).toString(); - identifier = super.bindings.getOrDefault(identifier, identifier); - bindings.put(formalArguments.get(i), identifier); + String identifierBinding = super.getBinding(identifier); + bindings.put(formalArguments.get(i), identifierBinding != null ? identifierBinding : identifier); } return bindings; } @@ -215,7 +216,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement @Override public MapEvaluationTypeContext withBindings(Map<String, String> bindings) { - if (bindings.isEmpty() && this.bindings.isEmpty()) return this; return new MapEvaluationTypeContext(functions(), bindings, featureTypes, currentResolutionCallStack); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index f531d77762d..69304a811b1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -69,7 +69,7 @@ public class MapContext extends Context { * Sets the value of a key. The value is frozen by this. */ @Override - public void put(String key,Value value) { + public void put(String key, Value value) { bindings.put(key, value.freeze()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java index 084bfe65e06..83aabada8f0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -20,8 +20,7 @@ public class FunctionReferenceContext { private final ImmutableMap<String, ExpressionFunction> functions; /** Mapping from argument names to the expressions they resolve to */ - // TODO: Make private - public final Map<String, String> bindings = new HashMap<>(); + private final Map<String, String> bindings = new HashMap<>(); /** Create a context for a single serialization task */ public FunctionReferenceContext() { @@ -56,14 +55,12 @@ public class FunctionReferenceContext { return mapBuilder.build(); } - /** - * Returns a function or null if it isn't defined in this context - */ + /** Returns a function or null if it isn't defined in this context */ public ExpressionFunction getFunction(String name) { return functions.get(name); } - protected final ImmutableMap<String, ExpressionFunction> functions() { return functions; } + protected ImmutableMap<String, ExpressionFunction> functions() { return functions; } - /** Returns the resolution of an argument, or null if it isn't defined in this context */ + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ public String getBinding(String name) { return bindings.get(name); } /** Returns a new context with the bindings replaced by the given bindings */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 62b3379f635..8fec3603f3e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -68,7 +68,7 @@ public final class ReferenceNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - // A reference to a function argument? + // A reference to an identifier (function argument or bound variable)? if (reference.isIdentifier() && context.getBinding(getName()) != null) { // a bound identifier: replace by the value it is bound to return string.append(context.getBinding(getName())); @@ -89,6 +89,8 @@ public final class ReferenceNode extends CompositeNode { context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); return string.append("rankingExpression(").append(instance.getName()).append(')'); } + + // Not resolved in this context: output as-is return reference.toString(string, context, path, parent); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 4acc1a85490..d7807caa2b6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -8,7 +8,6 @@ import com.yahoo.tensor.TensorType; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; /** @@ -37,7 +36,7 @@ public class SerializationContext extends FunctionReferenceContext { } /** Create a context for a single serialization task */ - public SerializationContext(List<ExpressionFunction> functions, Map<String, String> bindings) { + public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) { this(functions, bindings, new LinkedHashMap<>()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index a0c261ae9d3..f510f38d7a7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -2,6 +2,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableMap; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -16,8 +18,6 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; -import java.sql.Ref; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.LinkedHashMap; @@ -71,7 +71,9 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext<Reference> context) { return function.type(context); } + public TensorType type(TypeContext<Reference> context) { + return function.type(context); + } @Override public Value evaluate(Context context) { @@ -117,9 +119,16 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.wrapped() != null) + outermost = outermost.wrapped(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent).toString(); } else { return expression.toString(); @@ -180,9 +189,17 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(), context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.wrapped() != null) + outermost = outermost.wrapped(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent) + .toString(); } else { return expression.toString(); @@ -192,22 +209,73 @@ public class TensorFunctionNode extends CompositeNode { } /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionToStringContext implements ToStringContext { + private static class ExpressionToStringContext extends SerializationContext implements ToStringContext { - final SerializationContext context; - final Deque<String> path; - final CompositeNode parent; + private final ToStringContext wrappedToStringContext; + private final SerializationContext wrappedSerializationContext; + private final Deque<String> path; + private final CompositeNode parent; public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null); - public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { - this.context = context; + ExpressionToStringContext(SerializationContext wrappedSerializationContext, Deque<String> path, CompositeNode parent) { + this(wrappedSerializationContext, null, path, parent); + } + + ExpressionToStringContext(SerializationContext wrappedSerializationContext, + ToStringContext wrappedToStringContext, + Deque<String> path, + CompositeNode parent) { + this.wrappedSerializationContext = wrappedSerializationContext; + this.wrappedToStringContext = wrappedToStringContext; this.path = path; this.parent = parent; } + /** Adds the serialization of a function */ + public void addFunctionSerialization(String name, String expressionString) { + wrappedSerializationContext.addFunctionSerialization(name, expressionString); + } + + /** Adds the serialization of the an argument type to a function */ + public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { + wrappedSerializationContext.addArgumentTypeSerialization(functionName, argumentName, type); + } + + /** Adds the serialization of the return type of a function */ + public void addFunctionTypeSerialization(String functionName, TensorType type) { + wrappedSerializationContext.addFunctionTypeSerialization(functionName, type); + } + + public Map<String, String> serializedFunctions() { + return wrappedSerializationContext.serializedFunctions(); + } + + /** Returns a function or null if it isn't defined in this context */ + public ExpressionFunction getFunction(String name) { return wrappedSerializationContext.getFunction(name); } + + protected ImmutableMap<String, ExpressionFunction> functions() { return wrappedSerializationContext.functions(); } + + public ToStringContext wrapped() { return wrappedToStringContext; } + + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ + @Override + public String getBinding(String name) { + if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null) + return wrappedToStringContext.getBinding(name); + else + return wrappedSerializationContext.getBinding(name); + } + + /** Returns a new context with the bindings replaced by the given bindings */ + @Override + public ExpressionToStringContext withBindings(Map<String, String> bindings) { + return new ExpressionToStringContext(new SerializationContext(wrappedSerializationContext.functions().values(), bindings), + wrappedToStringContext, path, parent); + } + } /** Turns an EvaluationContext into a Context */ diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 7eb1fecc0cb..e3d3ac7b2e1 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -198,6 +198,12 @@ public class RankingExpressionTestCase { assertSerialization(List.of("tensor(x{}):{{x:foo}:rankingExpression(scalarFunction),{x:bar}:rankingExpression(scalarFunction)}"), "tensor(x{}):{{x:foo}:scalarFunction(), {x:bar}:scalarFunction()}", functions, false); + + // Shadowing + assertSerialization(List.of("tensor(scalarFunction[1])(rankingExpression(tensorFunction){x:scalarFunction + rankingExpression(scalarFunction)})"), + "tensor(scalarFunction[1])(tensorFunction{x: scalarFunction + scalarFunction()})", + functions, false); + } @Test diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e5095178be7..ac6621ce78b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -91,7 +91,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); - GenerateContext generateContext = new GenerateContext(type, context); + GenerateEvaluationContext generateContext = new GenerateEvaluationContext(type, context); for (int i = 0; i < indexes.size(); i++) { indexes.next(); builder.cell(generateContext.apply(indexes), indexes.indexesForReading()); @@ -113,7 +113,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM if (freeGenerator != null) return freeGenerator.toString(); else - return boundGenerator.toString(context); + return boundGenerator.toString(new GenerateToStringContext(context)); } /** @@ -121,19 +121,18 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM * This returns all the current index values as variables and falls back to delivering from the given * evaluation context. */ - private class GenerateContext implements EvaluationContext<NAMETYPE> { + private class GenerateEvaluationContext implements EvaluationContext<NAMETYPE> { private final TensorType type; private final EvaluationContext<NAMETYPE> context; private IndexedTensor.Indexes indexes; - GenerateContext(TensorType type, EvaluationContext<NAMETYPE> context) { + GenerateEvaluationContext(TensorType type, EvaluationContext<NAMETYPE> context) { this.type = type; this.context = context; } - @SuppressWarnings("unchecked") double apply(IndexedTensor.Indexes indexes) { if (freeGenerator != null) { return freeGenerator.apply(indexes.toList()); @@ -173,4 +172,26 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } + /** A context which adds the bindings of the generate dimension names to the given context. */ + private class GenerateToStringContext implements ToStringContext { + + private final ToStringContext context; + + public GenerateToStringContext(ToStringContext context) { + this.context = context; + } + + @Override + public String getBinding(String identifier) { + if (type.dimension(identifier).isPresent()) + return identifier; // dimension names are bound but not substituted in the generate context + else + return context.getBinding(identifier); + } + + @Override + public ToStringContext wrapped() { return context; } + + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java index cb7f376c365..c09631d36d7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -3,13 +3,27 @@ package com.yahoo.tensor.functions; /** * A context which is passed down to all nested functions when returning a string representation. - * The default implementation is empty as this library does not in itself have any need for a - * context. * * @author bratseth */ public interface ToStringContext { - static ToStringContext empty() { return new ToStringContext() {}; } + static ToStringContext empty() { return new EmptyStringContext(); } + + /** Returns the name an identifier is bound to, or null if not bound in this context */ + String getBinding(String name); + + /** Returns another context this wraps, or null if none is wrapped */ + ToStringContext wrapped(); + + class EmptyStringContext implements ToStringContext { + + @Override + public String getBinding(String name) { return null; } + + @Override + public ToStringContext wrapped() { return null; } + + } } |