diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-04-13 11:34:41 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-04-13 11:34:41 +0200 |
commit | a68180b364aa4bbb08a8cc0220a6e20fe7b4404d (patch) | |
tree | 52397579824e17163c22a498221a537fe35c5465 /config-model | |
parent | 7e1b7baba3f2f723405985d636089650a521f5d7 (diff) |
Ignore input also listed in intializers when parsing metadata from model
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java | 6 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java | 15 |
2 files changed, 16 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 7f578f07fe3..5d3624cd3d3 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -120,11 +120,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } - model.getInputMap().forEach((onnxName, onnxInput) -> { - if (model.getInitializers().contains(onnxName)) { - log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName)); - return; - } + model.getInputMap().forEach((__, onnxInput) -> { var reader = new StringReader(onnxInput); try { var asExpression = new RankingExpression(reader); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 7c89a349d7d..1984ceadac6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.logging.Logger; import java.util.stream.Collectors; /** @@ -36,6 +37,8 @@ import java.util.stream.Collectors; */ public class OnnxModelInfo { + private static final Logger log = Logger.getLogger(OnnxModelInfo.class.getName()); + private final ApplicationPackage app; private final String modelPath; private final String defaultOutput; @@ -196,15 +199,27 @@ public class OnnxModelInfo { } static private String onnxModelToJson(Onnx.ModelProto model, Path path) throws IOException { + var initializerNames = model.getGraph().getInitializerList().stream() + .map(Onnx.TensorProto::getName).collect(Collectors.toSet()); ByteArrayOutputStream out = new ByteArrayOutputStream(); JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); g.writeStartObject(); g.writeStringField("path", path.toString()); g.writeArrayFieldStart("inputs"); + int skippedInput = 0; for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { + if (initializerNames.contains(valueInfo.getName())) { + log.fine(() -> "For '%s': skipping name '%s' as it's an initializer" + .formatted(path.getName(), valueInfo.getName())); + ++skippedInput; + continue; + } onnxTypeToJson(g, valueInfo); } + if (skippedInput > 0) + log.info("For '%s': skipped %d inputs that were also listed in initializers" + .formatted(path.getName(), skippedInput)); g.writeEndArray(); g.writeArrayFieldStart("outputs"); |