diff options
Diffstat (limited to 'searchlib/src/main/java/com/yahoo')
6 files changed, 132 insertions, 24 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 674571ff73e..f2f8799b342 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -134,6 +134,8 @@ public class ExpressionFunction { for (int i = 0; i < arguments.size() && i < argumentValues.size(); ++i) { argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(new StringBuilder(), context, path, null).toString()); } + String symbol = toSymbol(argumentBindings); + System.out.println("Expanding function " + symbol); return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java index 83aabada8f0..9d094ce06f4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -22,6 +22,8 @@ public class FunctionReferenceContext { /** Mapping from argument names to the expressions they resolve to */ private final Map<String, String> bindings = new HashMap<>(); + private final FunctionReferenceContext parent; + /** Create a context for a single serialization task */ public FunctionReferenceContext() { this(Collections.emptyList()); @@ -43,9 +45,14 @@ public class FunctionReferenceContext { /** Create a context for a single serialization task */ public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) { + this(functions, bindings, null); + } + + public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings, FunctionReferenceContext parent) { this.functions = ImmutableMap.copyOf(functions); if (bindings != null) this.bindings.putAll(bindings); + this.parent = parent; } private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) { @@ -56,16 +63,34 @@ public class FunctionReferenceContext { } /** Returns a function or null if it isn't defined in this context */ - public ExpressionFunction getFunction(String name) { return functions.get(name); } + public ExpressionFunction getFunction(String name) { + ExpressionFunction function = functions.get(name); + if (function != null) { + return function; + } + if (parent != null) { + return parent.getFunction(name); + } + return null; + } protected ImmutableMap<String, ExpressionFunction> functions() { return functions; } /** Returns the resolution of an identifier, or null if it isn't defined in this context */ - public String getBinding(String name) { return bindings.get(name); } + public String getBinding(String name) { + String binding = bindings.get(name); + if (binding != null) { + return binding; + } + if (parent != null) { + return parent.getBinding(name); + } + return null; + } /** Returns a new context with the bindings replaced by the given bindings */ public FunctionReferenceContext withBindings(Map<String, String> bindings) { - return new FunctionReferenceContext(this.functions, bindings); + return new FunctionReferenceContext(this.functions, bindings, this); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 8fec3603f3e..a994f5247b7 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -74,20 +74,49 @@ public final class ReferenceNode extends CompositeNode { return string.append(context.getBinding(getName())); } + String name = getName(); // A reference to a function? ExpressionFunction function = context.getFunction(getName()); if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) { // a function reference: replace by the referenced function wrapped in rankingExpression - if (path == null) - path = new ArrayDeque<>(); - String myPath = getName() + getArguments().expressions(); - if (path.contains(myPath)) - throw new IllegalStateException("Cycle in ranking expression function: " + path); - path.addLast(myPath); - ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); - path.removeLast(); - context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); - return string.append("rankingExpression(").append(instance.getName()).append(')'); +// if (path == null) +// path = new ArrayDeque<>(); +// String myPath = getName() + getArguments().expressions(); +// if (path.contains(myPath)) +// throw new IllegalStateException("Cycle in ranking expression function: " + path); +// path.addLast(myPath); +// ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); +// path.removeLast(); +// context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); +// return string.append("rankingExpression(").append(instance.getName()).append(')'); + +// return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString()); + + // hack for testing: + // So, this worked. Meaning that when expanding we could probably cut down on the context tree? +// String expression = function.getBody().toString(); +// context.addFunctionSerialization(RankingExpression.propertyName(getName()), expression); // <- actually set by deriveFunctionProperties - this will only overwrite + + String prefix = string.toString(); // incredibly ugly hack - for testing this + + // so problem here with input values + if (prefix.startsWith("attribute")) { + if (name.equals("segment_ids") || name.equals("input_mask") || name.equals("input_ids")) { + return string.append(getName()); + // TODO: divine this! + } + } + + // so, in one case +// rankprofile[2].fef.property[35].name "rankingExpression(imported_ml_function_bertsquad8_input_ids).rankingScript" +// rankprofile[2].fef.property[35].value "input_ids" + // vs +// rankprofile[2].fef.property[2].name "rankingExpression(input_ids).rankingScript" +// rankprofile[2].fef.property[2].value "attribute(input_ids)" + // uppermost is wrong, then we need the below + + return string.append("rankingExpression(").append(getName()).append(')'); + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index d7807caa2b6..c79f5556e03 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -50,7 +50,7 @@ public class SerializationContext extends FunctionReferenceContext { */ public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings, Map<String, String> serializedFunctions) { - this(toMap(functions), bindings, serializedFunctions); + this(toMap(functions), bindings, serializedFunctions, null); } private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) { @@ -69,8 +69,8 @@ public class SerializationContext extends FunctionReferenceContext { * is <b>transferred</b> to this and will be modified in it */ public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings, - Map<String, String> serializedFunctions) { - super(functions, bindings); + Map<String, String> serializedFunctions, FunctionReferenceContext root) { + super(functions, bindings, root); this.serializedFunctions = serializedFunctions; } @@ -92,7 +92,7 @@ public class SerializationContext extends FunctionReferenceContext { @Override public SerializationContext withBindings(Map<String, String> bindings) { - return new SerializationContext(functions(), bindings, this.serializedFunctions); + return new SerializationContext(functions(), bindings, this.serializedFunctions, this); } public Map<String, String> serializedFunctions() { return serializedFunctions; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 6e1cdf52158..1ab9702367a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -143,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode { return new ExpressionScalarFunction(node); } - private static class ExpressionScalarFunction implements ScalarFunction<Reference> { + public static class ExpressionScalarFunction implements ScalarFunction<Reference> { private final ExpressionNode expression; @@ -151,6 +151,10 @@ public class TensorFunctionNode extends CompositeNode { this.expression = expression; } + public ExpressionNode getExpression() { + return expression; + } + @Override public Double apply(EvaluationContext<Reference> context) { return expression.evaluate(new ContextWrapper(context)).asDouble(); @@ -321,13 +325,45 @@ public class TensorFunctionNode extends CompositeNode { public ToStringContext parent() { return wrappedToStringContext; } + private int contextNodes() { + int nodes = 0; + if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) { + nodes += ((ExpressionToStringContext)wrappedToStringContext).contextNodes(); + } else if (wrappedToStringContext != null) { + nodes += 1; + } + if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) { + nodes += ((ExpressionToStringContext)wrappedSerializationContext).contextNodes(); + } else if (wrappedSerializationContext != null) { + nodes += 1; + } + return nodes + 1; + } + + private int contextDepth() { + int depth = 0; + if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) { + depth += ((ExpressionToStringContext)wrappedToStringContext).contextDepth(); + } + if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) { + int d = ((ExpressionToStringContext)wrappedSerializationContext).contextDepth(); + depth = Math.max(depth, d); + } + return depth + 1; + } + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ @Override public String getBinding(String name) { - if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null) - return wrappedToStringContext.getBinding(name); - else - return wrappedSerializationContext.getBinding(name); +// System.out.println("getBinding for " + name + " with node count " + contextNodes() + " and max depth " + contextDepth()); + String binding; + if (wrappedToStringContext != null) { + binding = wrappedToStringContext.getBinding(name); + if (binding != null) { + return binding; + } + } + return wrappedSerializationContext.getBinding(name); } /** Returns a new context with the bindings replaced by the given bindings */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index a541eac2421..95652bb0e15 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -28,8 +29,16 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext return node; } + /** Returns true if the given reference is an attribute, constant or query feature */ + // TEMP: from config-model module + public static boolean isSimpleFeature(Reference reference) { + if ( ! reference.isSimple()) return false; + String name = reference.name(); + return name.equals("attribute") || name.equals("constant") || name.equals("query"); + } + private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) { - if (!node.getArguments().isEmpty()) + if ( ! node.getArguments().isEmpty() && ! isSimpleFeature(node.reference())) return transformArguments(node, context); else return transformConstantReference(node, context); @@ -44,7 +53,14 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext } private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) { - Value value = context.constants().get(node.getName()); + String name = node.getName(); + if (node.reference().name().equals("constant")) { + ExpressionNode arg = node.getArguments().expressions().get(0); + if (arg instanceof ReferenceNode) { + name = ((ReferenceNode)arg).getName(); + } + } + Value value = context.constants().get(name); // works if "constant(...)" is added if (value == null || value.type().rank() > 0) { return node; // not a number constant reference } |