aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-02-09 22:21:04 +0100
committerGitHub <noreply@github.com>2023-02-09 22:21:04 +0100
commit6ff34fa24a55630e61ca56b0cd0c299ed0f5bb53 (patch)
tree952e991b69cca00928610b4c1d06958aec2a5e2a /config-model/src/main/java/com/yahoo/schema
parent8b6f2a6f7ce171833dcc139e46f5c1f9cab236e5 (diff)
parent7ec416403481c47903ed36a44ce901c5a1c61a9e (diff)
Merge pull request #25927 from vespa-engine/arnej/xp-input-recorder
add InputRecorder for global-phase expressions
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/RankProfile.java24
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java92
2 files changed, 114 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index c1c0ad4f044..ad6eb038058 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -15,6 +15,7 @@ import com.yahoo.schema.document.ImmutableSDField;
import com.yahoo.schema.document.SDDocumentType;
import com.yahoo.schema.expressiontransforms.ExpressionTransforms;
import com.yahoo.schema.expressiontransforms.RankProfileTransformContext;
+import com.yahoo.schema.expressiontransforms.InputRecorder;
import com.yahoo.schema.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.FeatureList;
@@ -876,7 +877,7 @@ public class RankProfile implements Cloneable {
}
}
- private Map<String, RankingExpressionFunction> gatherAllFunctions() {
+ private Map<String, RankingExpressionFunction> gatherAllFunctions() {
if (functions.isEmpty() && inherited().isEmpty()) return Map.of();
if (inherited().isEmpty()) return Collections.unmodifiableMap(new LinkedHashMap<>(functions));
@@ -1006,6 +1007,25 @@ public class RankProfile implements Cloneable {
// TODO: This merges all functions from inherited profiles too and erases inheritance information. Not good.
functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms);
allFunctionsCached = null;
+
+ if (globalPhaseRanking != null) {
+ var context = new RankProfileTransformContext(this,
+ queryProfiles,
+ featureTypes,
+ importedModels,
+ constants(),
+ inlineFunctions);
+ var needInputs = new HashSet<String>();
+ var recorder = new InputRecorder(needInputs);
+ recorder.transform(globalPhaseRanking.function().getBody(), context);
+ for (String input : needInputs) {
+ try {
+ addMatchFeatures(new FeatureList(input));
+ } catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) {
+ throw new IllegalArgumentException("invalid input in global-phase expression: "+input);
+ }
+ }
+ }
}
private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<Reference, Constant> constants) {
@@ -1102,7 +1122,7 @@ public class RankProfile implements Cloneable {
for (FieldDescription field : queryProfileType.declaredFields().values()) {
TensorType type = field.getType().asTensorType();
Optional<Reference> feature = Reference.simple(field.getName());
- if ( feature.isEmpty() || ! feature.get().name().equals("query")) continue;
+ if (feature.isEmpty() || ! feature.get().name().equals("query")) continue;
if (featureTypes.containsKey(feature.get())) continue; // Explicit feature types (from inputs) overrides
TensorType existingType = context.getType(feature.get());
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
new file mode 100644
index 00000000000..4e7988a2006
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -0,0 +1,92 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.expressiontransforms;
+
+import com.yahoo.schema.FeatureNames;
+import com.yahoo.schema.RankProfile;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+
+import java.util.Set;
+
+/**
+ * Analyzes expression to figure out what inputs it needs
+ *
+ * @author arnej
+ */
+public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> {
+
+ private final Set<String> neededInputs;
+
+ public InputRecorder(Set<String> target) {
+ this.neededInputs = target;
+ }
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode r) {
+ handle(r, context);
+ return node;
+ }
+ if (node instanceof CompositeNode c)
+ return transformChildren(c, context);
+ if (node instanceof ConstantNode) {
+ return node;
+ }
+ throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]");
+ }
+
+ private void handle(ReferenceNode feature, RankProfileTransformContext context) {
+ Reference ref = feature.reference();
+ String name = ref.name();
+ var args = ref.arguments();
+ if (args.size() == 0) {
+ var f = context.rankProfile().getFunctions().get(name);
+ if (f != null && f.function().arguments().size() == 0) {
+ transform(f.function().getBody().getRoot(), context);
+ return;
+ }
+ neededInputs.add(feature.toString());
+ return;
+ }
+ if (args.size() == 1) {
+ if (FeatureNames.isAttributeFeature(ref)) {
+ neededInputs.add(feature.toString());
+ return;
+ }
+ if (FeatureNames.isQueryFeature(ref)) {
+ // get rid of this later, we should be able
+ // to get it from the query
+ neededInputs.add(feature.toString());
+ return;
+ }
+ if (FeatureNames.isConstantFeature(ref)) {
+ var allConstants = context.rankProfile().constants();
+ if (allConstants.containsKey(ref)) {
+ // assumes we have the constant available during evaluation without any more wiring
+ return;
+ }
+ throw new IllegalArgumentException("unknown constant: " + feature);
+ }
+ }
+ if ("onnx".equals(name)) {
+ if (args.size() != 1) {
+ throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature);
+ }
+ var arg = args.expressions().get(0);
+ var models = context.rankProfile().onnxModels();
+ var model = models.get(arg.toString());
+ if (model == null) {
+ throw new IllegalArgumentException("missing onnx model: " + arg);
+ }
+ for (String onnxInput : model.getInputMap().values()) {
+ neededInputs.add(onnxInput);
+ }
+ return;
+ }
+ throw new IllegalArgumentException("cannot handle feature: " + feature);
+ }
+}