aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-14 12:54:07 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-14 12:59:16 +0000
commitb99705e43c8e30fba0406ae450034edcfdc9ab52 (patch)
tree19d0aee2414a5d1747c7028e17f4e64cd3cdeabf /config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
parent9553cd2709c0791aa530f5388cf156116e857795 (diff)
add special handling of TensorFunctionNode containing Generate function
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java27
1 files changed, 24 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 7124628be0c..6e7787c2dd1 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -10,7 +10,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.functions.Generate;
import java.io.StringReader;
import java.util.HashSet;
@@ -21,7 +23,7 @@ import java.util.Set;
*
* @author arnej
*/
-public class InputRecorder extends ExpressionTransformer<RankProfileTransformContext> {
+public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
private final Set<String> neededInputs;
private final Set<String> handled = new HashSet<>();
@@ -30,12 +32,28 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon
this.neededInputs = target;
}
+ public void process(RankingExpression expression, RankProfileTransformContext context) {
+ transform(expression.getRoot(), new InputRecorderContext(context));
+ }
+
@Override
- public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ public ExpressionNode transform(ExpressionNode node, InputRecorderContext context) {
if (node instanceof ReferenceNode r) {
handle(r, context);
return node;
}
+ if (node instanceof TensorFunctionNode t) {
+ var f = t.function();
+ if (f instanceof Generate) {
+ var childContext = new InputRecorderContext(context);
+ var tt = f.type(context.types());
+ // expects only indexed dimensions, should we check?
+ for (var dim : tt.dimensions()) {
+ childContext.localVariables().add(dim.name());
+ }
+ return transformChildren(t, childContext);
+ }
+ }
if (node instanceof CompositeNode c)
return transformChildren(c, context);
if (node instanceof ConstantNode) {
@@ -44,11 +62,14 @@ public class InputRecorder extends ExpressionTransformer<RankProfileTransformCon
throw new IllegalArgumentException("Cannot handle node type: "+ node + " [" + node.getClass() + "]");
}
- private void handle(ReferenceNode feature, RankProfileTransformContext context) {
+ private void handle(ReferenceNode feature, InputRecorderContext context) {
Reference ref = feature.reference();
String name = ref.name();
var args = ref.arguments();
boolean simpleFunctionOrIdentifier = (args.size() == 0) && (ref.output() == null);
+ if (simpleFunctionOrIdentifier && context.localVariables().contains(name)) {
+ return;
+ }
if (ref.isSimpleRankingExpressionWrapper()) {
name = ref.simpleArgument().get();
simpleFunctionOrIdentifier = true;