diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java index 6b1f60df6f4..126e9f9f4e6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java @@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -45,6 +48,11 @@ class BindingExtractor { arguments.addAll(other.arguments); onnxModelsInUse.putAll(other.onnxModelsInUse); } + + void removeTarget(String name) { + bindTargets.remove(name); + arguments.remove(name); + } } private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>(); @@ -83,6 +91,26 @@ class BindingExtractor { } return result; } + else if (node instanceof TensorFunctionNode tfn) { + for (ExpressionNode child : tfn.children()) { + result.merge(extractBindTargets(child)); + } + var f = tfn.function(); + if (f instanceof Generate) { + var tt = f.type(null); + for (var dim : tt.dimensions()) { + result.removeTarget(dim.name()); + } + } + else if (f instanceof Slice<?> s) { + for (var selectorFunc : s.selectorFunctions()) { + if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) { + result.merge(extractBindTargets(expr.wrappedExpression())); + } + } + } + return result; + } else if (isOnnx(node)) { return extractOnnxTargets(node); } |