diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-02-09 22:21:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-09 22:21:04 +0100 |
commit | 6ff34fa24a55630e61ca56b0cd0c299ed0f5bb53 (patch) | |
tree | 952e991b69cca00928610b4c1d06958aec2a5e2a /config-model/src/main | |
parent | 8b6f2a6f7ce171833dcc139e46f5c1f9cab236e5 (diff) | |
parent | 7ec416403481c47903ed36a44ce901c5a1c61a9e (diff) |
Merge pull request #25927 from vespa-engine/arnej/xp-input-recorder
add InputRecorder for global-phase expressions
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/RankProfile.java | 24 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java | 92 |
2 files changed, 114 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index c1c0ad4f044..ad6eb038058 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -15,6 +15,7 @@ import com.yahoo.schema.document.ImmutableSDField; import com.yahoo.schema.document.SDDocumentType; import com.yahoo.schema.expressiontransforms.ExpressionTransforms; import com.yahoo.schema.expressiontransforms.RankProfileTransformContext; +import com.yahoo.schema.expressiontransforms.InputRecorder; import com.yahoo.schema.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; @@ -876,7 +877,7 @@ public class RankProfile implements Cloneable { } } - private Map<String, RankingExpressionFunction> gatherAllFunctions() { + private Map<String, RankingExpressionFunction> gatherAllFunctions() { if (functions.isEmpty() && inherited().isEmpty()) return Map.of(); if (inherited().isEmpty()) return Collections.unmodifiableMap(new LinkedHashMap<>(functions)); @@ -1006,6 +1007,25 @@ public class RankProfile implements Cloneable { // TODO: This merges all functions from inherited profiles too and erases inheritance information. Not good. functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); allFunctionsCached = null; + + if (globalPhaseRanking != null) { + var context = new RankProfileTransformContext(this, + queryProfiles, + featureTypes, + importedModels, + constants(), + inlineFunctions); + var needInputs = new HashSet<String>(); + var recorder = new InputRecorder(needInputs); + recorder.transform(globalPhaseRanking.function().getBody(), context); + for (String input : needInputs) { + try { + addMatchFeatures(new FeatureList(input)); + } catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) { + throw new IllegalArgumentException("invalid input in global-phase expression: "+input); + } + } + } } private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<Reference, Constant> constants) { @@ -1102,7 +1122,7 @@ public class RankProfile implements Cloneable { for (FieldDescription field : queryProfileType.declaredFields().values()) { TensorType type = field.getType().asTensorType(); Optional<Reference> feature = Reference.simple(field.getName()); - if ( feature.isEmpty() || ! feature.get().name().equals("query")) continue; + if (feature.isEmpty() || ! feature.get().name().equals("query")) continue; if (featureTypes.containsKey(feature.get())) continue; // Explicit feature types (from inputs) overrides TensorType existingType = context.getType(feature.get()); diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java new file mode 100644 index 00000000000..4e7988a2006 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -0,0 +1,92 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.expressiontransforms; + +import com.yahoo.schema.FeatureNames; +import com.yahoo.schema.RankProfile; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + +import java.util.Set; + +/** + * Analyzes expression to figure out what inputs it needs + * + * @author arnej + */ +public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> { + + private final Set<String> neededInputs; + + public InputRecorder(Set<String> target) { + this.neededInputs = target; + } + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode r) { + handle(r, context); + return node; + } + if (node instanceof CompositeNode c) + return transformChildren(c, context); + if (node instanceof ConstantNode) { + return node; + } + throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]"); + } + + private void handle(ReferenceNode feature, RankProfileTransformContext context) { + Reference ref = feature.reference(); + String name = ref.name(); + var args = ref.arguments(); + if (args.size() == 0) { + var f = context.rankProfile().getFunctions().get(name); + if (f != null && f.function().arguments().size() == 0) { + transform(f.function().getBody().getRoot(), context); + return; + } + neededInputs.add(feature.toString()); + return; + } + if (args.size() == 1) { + if (FeatureNames.isAttributeFeature(ref)) { + neededInputs.add(feature.toString()); + return; + } + if (FeatureNames.isQueryFeature(ref)) { + // get rid of this later, we should be able + // to get it from the query + neededInputs.add(feature.toString()); + return; + } + if (FeatureNames.isConstantFeature(ref)) { + var allConstants = context.rankProfile().constants(); + if (allConstants.containsKey(ref)) { + // assumes we have the constant available during evaluation without any more wiring + return; + } + throw new IllegalArgumentException("unknown constant: " + feature); + } + } + if ("onnx".equals(name)) { + if (args.size() != 1) { + throw new IllegalArgumentException("expected name of ONNX model as argument: " + feature); + } + var arg = args.expressions().get(0); + var models = context.rankProfile().onnxModels(); + var model = models.get(arg.toString()); + if (model == null) { + throw new IllegalArgumentException("missing onnx model: " + arg); + } + for (String onnxInput : model.getInputMap().values()) { + neededInputs.add(onnxInput); + } + return; + } + throw new IllegalArgumentException("cannot handle feature: " + feature); + } +} |