diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-16 11:00:36 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-16 11:05:35 +0000 |
commit | d883fa1f9fff36baf7231e6e9f8958017a705657 (patch) | |
tree | a99d74c943de8eb2d0b6c47746e39b612aa4c5c6 /model-evaluation | |
parent | 9ec2f1a8888610e961ba4c6894abb096a8373850 (diff) |
* make InputRecorder handle slice with embedded expressions
* special handling for Generate and Slice in BindingExtractor
Diffstat (limited to 'model-evaluation')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java | 28 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java | 2 |
2 files changed, 29 insertions, 1 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java index 6b1f60df6f4..126e9f9f4e6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java @@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -45,6 +48,11 @@ class BindingExtractor { arguments.addAll(other.arguments); onnxModelsInUse.putAll(other.onnxModelsInUse); } + + void removeTarget(String name) { + bindTargets.remove(name); + arguments.remove(name); + } } private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>(); @@ -83,6 +91,26 @@ class BindingExtractor { } return result; } + else if (node instanceof TensorFunctionNode tfn) { + for (ExpressionNode child : tfn.children()) { + result.merge(extractBindTargets(child)); + } + var f = tfn.function(); + if (f instanceof Generate) { + var tt = f.type(null); + for (var dim : tt.dimensions()) { + result.removeTarget(dim.name()); + } + } + else if (f instanceof Slice<?> s) { + for (var selectorFunc : s.selectorFunctions()) { + if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) { + result.merge(extractBindTargets(expr.wrappedExpression())); + } + } + } + return result; + } else if (isOnnx(node)) { return extractOnnxTargets(node); } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 666c3a103b5..a34898fe4e9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -109,7 +109,7 @@ class FunctionReference { /** * Returns a function reference from the given return type serial form, - * or empty if the string is not a valid function return typoe serial form + * or empty if the string is not a valid function return type serial form */ static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) { Matcher expressionMatcher = returnTypePattern.matcher(serialForm); |