diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-14 12:54:07 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-14 12:59:16 +0000 |
commit | b99705e43c8e30fba0406ae450034edcfdc9ab52 (patch) | |
tree | 19d0aee2414a5d1747c7028e17f4e64cd3cdeabf /config-model/src/main/java/com/yahoo/schema | |
parent | 9553cd2709c0791aa530f5388cf156116e857795 (diff) |
add special handling of TensorFunctionNode containing Generate function
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema')
3 files changed, 53 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index a00bbb682a8..f9b3bc77040 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1017,7 +1017,7 @@ public class RankProfile implements Cloneable { inlineFunctions); var needInputs = new HashSet<String>(); var recorder = new InputRecorder(needInputs); - recorder.transform(globalPhaseRanking.function().getBody(), context); + recorder.process(globalPhaseRanking.function().getBody(), context); for (String input : needInputs) { if (input.startsWith("constant(") || input.startsWith("query(")) { continue; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 7124628be0c..6e7787c2dd1 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -10,7 +10,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.functions.Generate; import java.io.StringReader; import java.util.HashSet; @@ -21,7 +23,7 @@ import java.util.Set; * * @author arnej */ -public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> { +public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); @@ -30,12 +32,28 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon this.neededInputs = target; } + public void process(RankingExpression expression, RankProfileTransformContext context) { + transform(expression.getRoot(), new InputRecorderContext(context)); + } + @Override - public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) { if (node instanceof ReferenceNode r) { handle(r, context); return node; } + if (node instanceof TensorFunctionNode t) { + var f = t.function(); + if (f instanceof Generate) { + var childContext = new InputRecorderContext(context); + var tt = f.type(context.types()); + // expects only indexed dimensions, should we check? + for (var dim : tt.dimensions()) { + childContext.localVariables().add(dim.name()); + } + return transformChildren(t, childContext); + } + } if (node instanceof CompositeNode c) return transformChildren(c, context); if (node instanceof ConstantNode) { @@ -44,11 +62,14 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]"); } - private void handle(ReferenceNode feature, RankProfileTransformContext context) { + private void handle(ReferenceNode feature, InputRecorderContext context) { Reference ref = feature.reference(); String name = ref.name(); var args = ref.arguments(); boolean simpleFunctionOrIdentifier = (args.size() == 0) && (ref.output() == null); + if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) { + return; + } if (ref.isSimpleRankingExpressionWrapper()) { name = ref.simpleArgument().get(); simpleFunctionOrIdentifier = true; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java new file mode 100644 index 00000000000..54617374b67 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorderContext.java @@ -0,0 +1,28 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.schema.expressiontransforms; + +import com.yahoo.schema.RankProfile; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; + +import java.util.HashSet; +import java.util.Set; + +class InputRecorderContext extends TransformContext { + + private final RankProfileTransformContext parent; + private final Set<String> localVariables = new HashSet<>(); + + public RankProfile rankProfile() { return parent.rankProfile(); } + public Set<String> localVariables() { return localVariables; } + + public InputRecorderContext(RankProfileTransformContext parent) { + super(parent.constants(), parent.types()); + this.parent = parent; + } + + public InputRecorderContext(InputRecorderContext parent) { + super(parent.constants(), parent.types()); + this.parent = parent.parent; + this.localVariables.addAll(parent.localVariables); + } +} |