summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-16 11:00:36 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-16 11:05:35 +0000
commitd883fa1f9fff36baf7231e6e9f8958017a705657 (patch)
treea99d74c943de8eb2d0b6c47746e39b612aa4c5c6 /model-evaluation
parent9ec2f1a8888610e961ba4c6894abb096a8373850 (diff)
* make InputRecorder handle slice with embedded expressions
* special handling for Generate and Slice in BindingExtractor
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java28
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java2
2 files changed, 29 insertions, 1 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
index 6b1f60df6f4..126e9f9f4e6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/BindingExtractor.java
@@ -7,6 +7,9 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Slice;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
@@ -45,6 +48,11 @@ class BindingExtractor {
arguments.addAll(other.arguments);
onnxModelsInUse.putAll(other.onnxModelsInUse);
}
+
+ void removeTarget(String name) {
+ bindTargets.remove(name);
+ arguments.remove(name);
+ }
}
private final Map<FunctionReference, FunctionInfo> functionsInfo = new LinkedHashMap<>();
@@ -83,6 +91,26 @@ class BindingExtractor {
}
return result;
}
+ else if (node instanceof TensorFunctionNode tfn) {
+ for (ExpressionNode child : tfn.children()) {
+ result.merge(extractBindTargets(child));
+ }
+ var f = tfn.function();
+ if (f instanceof Generate) {
+ var tt = f.type(null);
+ for (var dim : tt.dimensions()) {
+ result.removeTarget(dim.name());
+ }
+ }
+ else if (f instanceof Slice<?> s) {
+ for (var selectorFunc : s.selectorFunctions()) {
+ if (selectorFunc instanceof TensorFunctionNode.ExpressionTensorFunction expr) {
+ result.merge(extractBindTargets(expr.wrappedExpression()));
+ }
+ }
+ }
+ return result;
+ }
else if (isOnnx(node)) {
return extractOnnxTargets(node);
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 666c3a103b5..a34898fe4e9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -109,7 +109,7 @@ class FunctionReference {
/**
* Returns a function reference from the given return type serial form,
- * or empty if the string is not a valid function return typoe serial form
+ * or empty if the string is not a valid function return type serial form
*/
static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) {
Matcher expressionMatcher = returnTypePattern.matcher(serialForm);