diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-14 12:54:07 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-14 12:59:16 +0000 |
commit | b99705e43c8e30fba0406ae450034edcfdc9ab52 (patch) | |
tree | 19d0aee2414a5d1747c7028e17f4e64cd3cdeabf /config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java | |
parent | 9553cd2709c0791aa530f5388cf156116e857795 (diff) |
add special handling of TensorFunctionNode containing Generate function
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 7124628be0c..6e7787c2dd1 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -10,7 +10,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.functions.Generate; import java.io.StringReader; import java.util.HashSet; @@ -21,7 +23,7 @@ import java.util.Set; * * @author arnej */ -public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> { +public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); @@ -30,12 +32,28 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon this.neededInputs = target; } + public void process(RankingExpression expression, RankProfileTransformContext context) { + transform(expression.getRoot(), new InputRecorderContext(context)); + } + @Override - public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) { if (node instanceof ReferenceNode r) { handle(r, context); return node; } + if (node instanceof TensorFunctionNode t) { + var f = t.function(); + if (f instanceof Generate) { + var childContext = new InputRecorderContext(context); + var tt = f.type(context.types()); + // expects only indexed dimensions, should we check? + for (var dim : tt.dimensions()) { + childContext.localVariables().add(dim.name()); + } + return transformChildren(t, childContext); + } + } if (node instanceof CompositeNode c) return transformChildren(c, context); if (node instanceof ConstantNode) { @@ -44,11 +62,14 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]"); } - private void handle(ReferenceNode feature, RankProfileTransformContext context) { + private void handle(ReferenceNode feature, InputRecorderContext context) { Reference ref = feature.reference(); String name = ref.name(); var args = ref.arguments(); boolean simpleFunctionOrIdentifier = (args.size() == 0) && (ref.output() == null); + if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) { + return; + } if (ref.isSimpleRankingExpressionWrapper()) { name = ref.simpleArgument().get(); simpleFunctionOrIdentifier = true; |