summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-20 12:19:34 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-20 12:22:52 +0000
commitbaa9c48be8732564e00730efe680df26a8f47f4c (patch)
tree9c83c10f0a225cdd5828013b56ffee6556d76a62
parent5791e60dfcd5f83d0e77e45498318eeb3dd33ee3 (diff)
use withTransformedExpressions for wiring
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java15
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java19
2 files changed, 6 insertions, 28 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 84d0997b12e..af072c5b59a 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -55,20 +55,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
}
return transformChildren(t, childContext);
}
- if (f instanceof DynamicTensor d) {
- for (var tf : d.cellGeneratorFunctions()) {
- if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- transform(expr.wrappedExpression(), context);
- }
- }
- }
- if (f instanceof Slice s) {
- for (var tf : s.selectorFunctions()) {
- if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- transform(expr.wrappedExpression(), context);
- }
- }
- }
+ node = t.withTransformedExpressions(expr -> transform(expr, context));
}
if (node instanceof CompositeNode c)
return transformChildren(c, context);
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
index bc71e51655c..8cdd5387acd 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
@@ -96,6 +96,11 @@ class BindingExtractor {
for (ExpressionNode child : tfn.children()) {
result.merge(extractBindTargets(child));
}
+ // ignore return value:
+ tfn.withTransformedExpressions(expr -> {
+ result.merge(extractBindTargets(expr));
+ return expr;
+ });
var f = tfn.function();
if (f instanceof Generate) {
var tt = f.type(null);
@@ -103,20 +108,6 @@ class BindingExtractor {
result.removeTarget(dim.name());
}
}
- else if (f instanceof Slice<?> s) {
- for (var selectorFunc : s.selectorFunctions()) {
- if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- result.merge(extractBindTargets(expr.wrappedExpression()));
- }
- }
- }
- else if (f instanceof DynamicTensor<?> d) {
- for (var tf : d.cellGeneratorFunctions()) {
- if (tf instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
- result.merge(extractBindTargets(expr.wrappedExpression()));
- }
- }
- }
return result;
}
else if (isOnnx(node)) {