diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-14 22:34:07 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-14 22:34:07 +0100 |
commit | 772568949a67767f06784e8e9350c487f36bf2e2 (patch) | |
tree | b5b6d8caa8df54f5e09f8701a4a552b98f4b2817 /config-model/src/main/java/com | |
parent | f5ac4a5022cd3768d423fecc1632e2735cf083f5 (diff) | |
parent | b99705e43c8e30fba0406ae450034edcfdc9ab52 (diff) |
Merge pull request #26437 from vespa-engine/arnej/handle-tensor-generate
add special handling of TensorFunctionNode containing Generate function
Diffstat (limited to 'config-model/src/main/java/com')
3 files changed, 53 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index a00bbb682a8..f9b3bc77040 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1017,7 +1017,7 @@ public class RankProfile implements Cloneable { inlineFunctions); var needInputs = new HashSet<String>(); var recorder = new InputRecorder(needInputs); - recorder.transform(globalPhaseRanking.function().getBody(), context); + recorder.process(globalPhaseRanking.function().getBody(), context); for (String input : needInputs) { if (input.startsWith("constant(") || input.startsWith("query(")) { continue; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 7124628be0c..6e7787c2dd1 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -10,7 +10,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.functions.Generate; import java.io.StringReader; import java.util.HashSet; @@ -21,7 +23,7 @@ import java.util.Set; * * @author arnej */ -public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> { +public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); @@ -30,12 +32,28 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon this.neededInputs = target; } + public void process(RankingExpression expression, RankProfileTransformContext context) { + transform(expression.getRoot(), new InputRecorderContext(context)); + } + @Override - public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) { if (node instanceof ReferenceNode r) { handle(r, context); return node; } + if (node instanceof TensorFunctionNode t) { + var f = t.function(); + if (f instanceof Generate) { + var childContext = new InputRecorderContext(context); + var tt = f.type(context.types()); + // expects only indexed dimensions, should we check? + for (var dim : tt.dimensions()) { + childContext.localVariables().add(dim.name()); + } + return transformChildren(t, childContext); + } + } if (node instanceof CompositeNode c) return transformChildren(c, context); if (node instanceof ConstantNode) { @@ -44,11 +62,14 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]"); } - private void handle(ReferenceNode feature, RankProfileTransformContext context) { + private void handle(ReferenceNode feature, InputRecorderContext context) { Reference ref = feature.reference(); String name = ref.name(); var args = ref.arguments(); boolean simpleFunctionOrIdentifier = (args.size() == 0) && (ref.output() == null); + if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) { + return; + } if (ref.isSimpleRankingExpressionWrapper()) { name = ref.simpleArgument().get(); simpleFunctionOrIdentifier = true; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java new file mode 100644 index 00000000000..54617374b67 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java @@ -0,0 +1,28 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.expressiontransforms; + +import com.yahoo.schema.RankProfile; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; + +import java.util.HashSet; +import java.util.Set; + +class InputRecorderContext extends TransformContext { + + private final RankProfileTransformContext parent; + private final Set<String> localVariables = new HashSet<>(); + + public RankProfile rankProfile() { return parent.rankProfile(); } + public Set<String> localVariables() { return localVariables; } + + public InputRecorderContext(RankProfileTransformContext parent) { + super(parent.constants(), parent.types()); + this.parent = parent; + } + + public InputRecorderContext(InputRecorderContext parent) { + super(parent.constants(), parent.types()); + this.parent = parent.parent; + this.localVariables.addAll(parent.localVariables); + } +} |