summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/schema/RankProfile.java20
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java87
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java2
3 files changed, 106 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index c1c0ad4f044..6d5202164dc 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -15,6 +15,7 @@ import com.yahoo.schema.document.ImmutableSDField;
import com.yahoo.schema.document.SDDocumentType;
import com.yahoo.schema.expressiontransforms.ExpressionTransforms;
import com.yahoo.schema.expressiontransforms.RankProfileTransformContext;
+import com.yahoo.schema.expressiontransforms.InputRecorder;
import com.yahoo.schema.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.FeatureList;
@@ -876,7 +877,7 @@ public class RankProfile implements Cloneable {
}
}
- private Map<String, RankingExpressionFunction> gatherAllFunctions() {
+ private Map<String, RankingExpressionFunction> gatherAllFunctions() {
if (functions.isEmpty() && inherited().isEmpty()) return Map.of();
if (inherited().isEmpty()) return Collections.unmodifiableMap(new LinkedHashMap<>(functions));
@@ -1006,6 +1007,21 @@ public class RankProfile implements Cloneable {
// TODO: This merges all functions from inherited profiles too and erases inheritance information. Not good.
functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms);
allFunctionsCached = null;
+
+ if (globalPhaseRanking != null) {
+ var context = new RankProfileTransformContext(this,
+ queryProfiles,
+ featureTypes,
+ importedModels,
+ constants(),
+ inlineFunctions);
+ var needInputs = new HashSet<String>();
+ var recorder = new InputRecorder(needInputs);
+ recorder.transform(globalPhaseRanking.function().getBody(), context);
+ for (String input : needInputs) {
+ System.err.println("need input => " + input);
+ }
+ }
}
private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<Reference, Constant> constants) {
@@ -1102,7 +1118,7 @@ public class RankProfile implements Cloneable {
for (FieldDescription field : queryProfileType.declaredFields().values()) {
TensorType type = field.getType().asTensorType();
Optional<Reference> feature = Reference.simple(field.getName());
- if ( feature.isEmpty() || ! feature.get().name().equals("query")) continue;
+ if (feature.isEmpty() || ! feature.get().name().equals("query")) continue;
if (featureTypes.containsKey(feature.get())) continue; // Explicit feature types (from inputs) overrides
TensorType existingType = context.getType(feature.get());
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
new file mode 100644
index 00000000000..d2cbdda959f
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -0,0 +1,87 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.expressiontransforms;
+
+import com.yahoo.schema.RankProfile;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+
+import java.util.Set;
+
+/**
+ * Analyzes expression to figure out what inputs it needs
+ *
+ * @author arnej
+ */
+public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> {
+
+ private final Set<String> neededInputs;
+
+ public InputRecorder(Set<String> target) {
+ this.neededInputs = target;
+ }
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode r) {
+ handle(r, context);
+ return node;
+ }
+ if (node instanceof CompositeNode c)
+ return transformChildren(c, context);
+ if (node instanceof ConstantNode) {
+ return node;
+ }
+ throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]");
+ }
+
+ private void handle(ReferenceNode feature, RankProfileTransformContext context) {
+ String name = feature.getName();
+ var args = feature.getArguments();
+ if (args.size() == 0) {
+ var f = context.rankProfile().getFunctions().get(name);
+ if (f != null && f.function().arguments().size() == 0) {
+ transform(f.function().getBody().getRoot(), context);
+ return;
+ }
+ neededInputs.add(feature.toString());
+ return;
+ }
+ if (args.size() == 1) {
+ if ("value".equals(name)) {
+ transform(args.expressions().get(0), context);
+ return;
+ }
+ if ("attribute".equals(name) || "query".equals(name)) {
+ neededInputs.add(feature.toString());
+ return;
+ }
+ if ("constant".equals(name)) {
+ var allConstants = context.rankProfile().constants();
+ if (allConstants.containsKey(feature.reference())) {
+ neededInputs.add(feature.toString());
+ return;
+ }
+ throw new IllegalArgumentException("unknown constant: " + feature);
+ }
+ }
+ if ("onnx".equals(name)) {
+ if (args.size() != 1) {
+ throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature);
+ }
+ var arg = args.expressions().get(0);
+ var models = context.rankProfile().onnxModels();
+ var model = models.get(arg.toString());
+ if (model == null) {
+ throw new IllegalArgumentException("missing onnx model: " + arg);
+ }
+ for (String onnxInput : model.getInputMap().values()) {
+ neededInputs.add(onnxInput);
+ }
+ return;
+ }
+ throw new IllegalArgumentException("cannot handle feature: " + feature);
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index 5e8bfc245a7..db14c66ed6d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -88,7 +88,7 @@ public final class FunctionNode extends CompositeNode {
@Override
public Value evaluate(Context context) {
if (arguments.expressions().size() == 0)
- return DoubleValue.zero.function(function ,DoubleValue.zero);
+ return DoubleValue.zero.function(function, DoubleValue.zero);
Value argument1 = arguments.expressions().get(0).evaluate(context);
if (arguments.expressions().size() == 1)