diff options
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java | 15 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java | 19 |
2 files changed, 6 insertions, 28 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 84d0997b12e..af072c5b59a 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -55,20 +55,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } return transformChildren(t, childContext); } - if (f instanceof DynamicTensor d) { - for (var tf : d.cellGeneratorFunctions()) { - if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - transform(expr.wrappedExpression(), context); - } - } - } - if (f instanceof Slice s) { - for (var tf : s.selectorFunctions()) { - if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - transform(expr.wrappedExpression(), context); - } - } - } + node = t.withTransformedExpressions(expr -> transform(expr, context)); } if (node instanceof CompositeNode c) return transformChildren(c, context); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java index bc71e51655c..8cdd5387acd 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java @@ -96,6 +96,11 @@ class BindingExtractor { for (ExpressionNode child : tfn.children()) { result.merge(extractBindTargets(child)); } + // ignore return value: + tfn.withTransformedExpressions(expr -> { + result.merge(extractBindTargets(expr)); + return expr; + }); var f = tfn.function(); if (f instanceof Generate) { var tt = f.type(null); @@ -103,20 +108,6 @@ class BindingExtractor { result.removeTarget(dim.name()); } } - else if (f instanceof Slice<?> s) { - for (var selectorFunc : s.selectorFunctions()) { - if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - result.merge(extractBindTargets(expr.wrappedExpression())); - } - } - } - else if (f instanceof DynamicTensor<?> d) { - for (var tf : d.cellGeneratorFunctions()) { - if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) { - result.merge(extractBindTargets(expr.wrappedExpression())); - } - } - } return result; } else if (isOnnx(node)) { |