aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java31
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java49
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java46
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java20
6 files changed, 132 insertions, 24 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 674571ff73e..f2f8799b342 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -134,6 +134,8 @@ public class ExpressionFunction {
for (int i = 0; i < arguments.size() && i < argumentValues.size(); ++i) {
argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(new StringBuilder(), context, path, null).toString());
}
+ String symbol = toSymbol(argumentBindings);
+ System.out.println("Expanding function " + symbol);
return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
index 83aabada8f0..9d094ce06f4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
@@ -22,6 +22,8 @@ public class FunctionReferenceContext {
/** Mapping from argument names to the expressions they resolve to */
private final Map<String, String> bindings = new HashMap<>();
+ private final FunctionReferenceContext parent;
+
/** Create a context for a single serialization task */
public FunctionReferenceContext() {
this(Collections.emptyList());
@@ -43,9 +45,14 @@ public class FunctionReferenceContext {
/** Create a context for a single serialization task */
public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) {
+ this(functions, bindings, null);
+ }
+
+ public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings, FunctionReferenceContext parent) {
this.functions = ImmutableMap.copyOf(functions);
if (bindings != null)
this.bindings.putAll(bindings);
+ this.parent = parent;
}
private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) {
@@ -56,16 +63,34 @@ public class FunctionReferenceContext {
}
/** Returns a function or null if it isn't defined in this context */
- public ExpressionFunction getFunction(String name) { return functions.get(name); }
+ public ExpressionFunction getFunction(String name) {
+ ExpressionFunction function = functions.get(name);
+ if (function != null) {
+ return function;
+ }
+ if (parent != null) {
+ return parent.getFunction(name);
+ }
+ return null;
+ }
protected ImmutableMap<String, ExpressionFunction> functions() { return functions; }
/** Returns the resolution of an identifier, or null if it isn't defined in this context */
- public String getBinding(String name) { return bindings.get(name); }
+ public String getBinding(String name) {
+ String binding = bindings.get(name);
+ if (binding != null) {
+ return binding;
+ }
+ if (parent != null) {
+ return parent.getBinding(name);
+ }
+ return null;
+ }
/** Returns a new context with the bindings replaced by the given bindings */
public FunctionReferenceContext withBindings(Map<String, String> bindings) {
- return new FunctionReferenceContext(this.functions, bindings);
+ return new FunctionReferenceContext(this.functions, bindings, this);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 8fec3603f3e..a994f5247b7 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -74,20 +74,49 @@ public final class ReferenceNode extends CompositeNode {
return string.append(context.getBinding(getName()));
}
+ String name = getName();
// A reference to a function?
ExpressionFunction function = context.getFunction(getName());
if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) {
// a function reference: replace by the referenced function wrapped in rankingExpression
- if (path == null)
- path = new ArrayDeque<>();
- String myPath = getName() + getArguments().expressions();
- if (path.contains(myPath))
- throw new IllegalStateException("Cycle in ranking expression function: " + path);
- path.addLast(myPath);
- ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path);
- path.removeLast();
- context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
- return string.append("rankingExpression(").append(instance.getName()).append(')');
+// if (path == null)
+// path = new ArrayDeque<>();
+// String myPath = getName() + getArguments().expressions();
+// if (path.contains(myPath))
+// throw new IllegalStateException("Cycle in ranking expression function: " + path);
+// path.addLast(myPath);
+// ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path);
+// path.removeLast();
+// context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
+// return string.append("rankingExpression(").append(instance.getName()).append(')');
+
+// return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString());
+
+ // hack for testing:
+ // So, this worked. Meaning that when expanding we could probably cut down on the context tree?
+// String expression = function.getBody().toString();
+// context.addFunctionSerialization(RankingExpression.propertyName(getName()), expression); // <- actually set by deriveFunctionProperties - this will only overwrite
+
+ String prefix = string.toString(); // incredibly ugly hack - for testing this
+
+ // so problem here with input values
+ if (prefix.startsWith("attribute")) {
+ if (name.equals("segment_ids") || name.equals("input_mask") || name.equals("input_ids")) {
+ return string.append(getName());
+ // TODO: divine this!
+ }
+ }
+
+ // so, in one case
+// rankprofile[2].fef.property[35].name "rankingExpression(imported_ml_function_bertsquad8_input_ids).rankingScript"
+// rankprofile[2].fef.property[35].value "input_ids"
+ // vs
+// rankprofile[2].fef.property[2].name "rankingExpression(input_ids).rankingScript"
+// rankprofile[2].fef.property[2].value "attribute(input_ids)"
+ // uppermost is wrong, then we need the below
+
+ return string.append("rankingExpression(").append(getName()).append(')');
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index d7807caa2b6..c79f5556e03 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -50,7 +50,7 @@ public class SerializationContext extends FunctionReferenceContext {
*/
public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings,
Map<String, String> serializedFunctions) {
- this(toMap(functions), bindings, serializedFunctions);
+ this(toMap(functions), bindings, serializedFunctions, null);
}
private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) {
@@ -69,8 +69,8 @@ public class SerializationContext extends FunctionReferenceContext {
* is <b>transferred</b> to this and will be modified in it
*/
public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings,
- Map<String, String> serializedFunctions) {
- super(functions, bindings);
+ Map<String, String> serializedFunctions, FunctionReferenceContext root) {
+ super(functions, bindings, root);
this.serializedFunctions = serializedFunctions;
}
@@ -92,7 +92,7 @@ public class SerializationContext extends FunctionReferenceContext {
@Override
public SerializationContext withBindings(Map<String, String> bindings) {
- return new SerializationContext(functions(), bindings, this.serializedFunctions);
+ return new SerializationContext(functions(), bindings, this.serializedFunctions, this);
}
public Map<String, String> serializedFunctions() { return serializedFunctions; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 6e1cdf52158..1ab9702367a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -143,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode {
return new ExpressionScalarFunction(node);
}
- private static class ExpressionScalarFunction implements ScalarFunction<Reference> {
+ public static class ExpressionScalarFunction implements ScalarFunction<Reference> {
private final ExpressionNode expression;
@@ -151,6 +151,10 @@ public class TensorFunctionNode extends CompositeNode {
this.expression = expression;
}
+ public ExpressionNode getExpression() {
+ return expression;
+ }
+
@Override
public Double apply(EvaluationContext<Reference> context) {
return expression.evaluate(new ContextWrapper(context)).asDouble();
@@ -321,13 +325,45 @@ public class TensorFunctionNode extends CompositeNode {
public ToStringContext parent() { return wrappedToStringContext; }
+ private int contextNodes() {
+ int nodes = 0;
+ if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) {
+ nodes += ((ExpressionToStringContext)wrappedToStringContext).contextNodes();
+ } else if (wrappedToStringContext != null) {
+ nodes += 1;
+ }
+ if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) {
+ nodes += ((ExpressionToStringContext)wrappedSerializationContext).contextNodes();
+ } else if (wrappedSerializationContext != null) {
+ nodes += 1;
+ }
+ return nodes + 1;
+ }
+
+ private int contextDepth() {
+ int depth = 0;
+ if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) {
+ depth += ((ExpressionToStringContext)wrappedToStringContext).contextDepth();
+ }
+ if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) {
+ int d = ((ExpressionToStringContext)wrappedSerializationContext).contextDepth();
+ depth = Math.max(depth, d);
+ }
+ return depth + 1;
+ }
+
/** Returns the resolution of an identifier, or null if it isn't defined in this context */
@Override
public String getBinding(String name) {
- if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null)
- return wrappedToStringContext.getBinding(name);
- else
- return wrappedSerializationContext.getBinding(name);
+// System.out.println("getBinding for " + name + " with node count " + contextNodes() + " and max depth " + contextDepth());
+ String binding;
+ if (wrappedToStringContext != null) {
+ binding = wrappedToStringContext.getBinding(name);
+ if (binding != null) {
+ return binding;
+ }
+ }
+ return wrappedSerializationContext.getBinding(name);
}
/** Returns a new context with the bindings replaced by the given bindings */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
index a541eac2421..95652bb0e15 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.transform;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -28,8 +29,16 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext
return node;
}
+ /** Returns true if the given reference is an attribute, constant or query feature */
+ // TEMP: from config-model module
+ public static boolean isSimpleFeature(Reference reference) {
+ if ( ! reference.isSimple()) return false;
+ String name = reference.name();
+ return name.equals("attribute") || name.equals("constant") || name.equals("query");
+ }
+
private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) {
- if (!node.getArguments().isEmpty())
+ if ( ! node.getArguments().isEmpty() && ! isSimpleFeature(node.reference()))
return transformArguments(node, context);
else
return transformConstantReference(node, context);
@@ -44,7 +53,14 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext
}
private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) {
- Value value = context.constants().get(node.getName());
+ String name = node.getName();
+ if (node.reference().name().equals("constant")) {
+ ExpressionNode arg = node.getArguments().expressions().get(0);
+ if (arg instanceof ReferenceNode) {
+ name = ((ReferenceNode)arg).getName();
+ }
+ }
+ Value value = context.constants().get(name); // works if "constant(...)" is added
if (value == null || value.type().rank() > 0) {
return node; // not a number constant reference
}