diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-04-14 12:23:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-14 12:23:25 +0200 |
commit | d722e77bdf0f135ba39b53930abf1fa8f933752d (patch) | |
tree | 2bcbabd37fec4654386c378193007d482213ed8e | |
parent | b66f35888ad413800ac16f841582da5bf067cb7f (diff) | |
parent | a68180b364aa4bbb08a8cc0220a6e20fe7b4404d (diff) |
Merge pull request #26737 from vespa-engine/bjorncs/onnx-model-initializers
Ignore input also listed in intializers when parsing metadata from model
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java | 6 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java | 15 |
2 files changed, 16 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index 7f578f07fe3..5d3624cd3d3 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -120,11 +120,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } - model.getInputMap().forEach((onnxName, onnxInput) -> { - if (model.getInitializers().contains(onnxName)) { - log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName)); - return; - } + model.getInputMap().forEach((__, onnxInput) -> { var reader = new StringReader(onnxInput); try { var asExpression = new RankingExpression(reader); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 7c89a349d7d..1984ceadac6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.logging.Logger; import java.util.stream.Collectors; /** @@ -36,6 +37,8 @@ import java.util.stream.Collectors; */ public class OnnxModelInfo { + private static final Logger log = Logger.getLogger(OnnxModelInfo.class.getName()); + private final ApplicationPackage app; private final String modelPath; private final String defaultOutput; @@ -196,15 +199,27 @@ public class OnnxModelInfo { } static private String onnxModelToJson(Onnx.ModelProto model, Path path) throws IOException { + var initializerNames = model.getGraph().getInitializerList().stream() + .map(Onnx.TensorProto::getName).collect(Collectors.toSet()); ByteArrayOutputStream out = new ByteArrayOutputStream(); JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); g.writeStartObject(); g.writeStringField("path", path.toString()); g.writeArrayFieldStart("inputs"); + int skippedInput = 0; for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { + if (initializerNames.contains(valueInfo.getName())) { + log.fine(() -> "For '%s': skipping name '%s' as it's an initializer" + .formatted(path.getName(), valueInfo.getName())); + ++skippedInput; + continue; + } onnxTypeToJson(g, valueInfo); } + if (skippedInput > 0) + log.info("For '%s': skipped %d inputs that were also listed in initializers" + .formatted(path.getName(), skippedInput)); g.writeEndArray(); g.writeArrayFieldStart("outputs"); |