aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-04-13 11:34:41 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-04-13 11:34:41 +0200
commita68180b364aa4bbb08a8cc0220a6e20fe7b4404d (patch)
tree52397579824e17163c22a498221a537fe35c5465 /config-model
parent7e1b7baba3f2f723405985d636089650a521f5d7 (diff)
Ignore input also listed in intializers when parsing metadata from model
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java6
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java15
2 files changed, 16 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index 7f578f07fe3..5d3624cd3d3 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -120,11 +120,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
if (model == null) {
throw new IllegalArgumentException("missing onnx model: " + arg);
}
- model.getInputMap().forEach((onnxName, onnxInput) -> {
- if (model.getInitializers().contains(onnxName)) {
- log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName));
- return;
- }
+ model.getInputMap().forEach((__, onnxInput) -> {
var reader = new StringReader(onnxInput);
try {
var asExpression = new RankingExpression(reader);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 7c89a349d7d..1984ceadac6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -23,6 +23,7 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
/**
@@ -36,6 +37,8 @@ import java.util.stream.Collectors;
*/
public class OnnxModelInfo {
+ private static final Logger log = Logger.getLogger(OnnxModelInfo.class.getName());
+
private final ApplicationPackage app;
private final String modelPath;
private final String defaultOutput;
@@ -196,15 +199,27 @@ public class OnnxModelInfo {
}
static private String onnxModelToJson(Onnx.ModelProto model, Path path) throws IOException {
+ var initializerNames = model.getGraph().getInitializerList().stream()
+ .map(Onnx.TensorProto::getName).collect(Collectors.toSet());
ByteArrayOutputStream out = new ByteArrayOutputStream();
JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
g.writeStartObject();
g.writeStringField("path", path.toString());
g.writeArrayFieldStart("inputs");
+ int skippedInput = 0;
for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
+ if (initializerNames.contains(valueInfo.getName())) {
+ log.fine(() -> "For '%s': skipping name '%s' as it's an initializer"
+ .formatted(path.getName(), valueInfo.getName()));
+ ++skippedInput;
+ continue;
+ }
onnxTypeToJson(g, valueInfo);
}
+ if (skippedInput > 0)
+ log.info("For '%s': skipped %d inputs that were also listed in initializers"
+ .formatted(path.getName(), skippedInput));
g.writeEndArray();
g.writeArrayFieldStart("outputs");