diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-31 17:37:57 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-31 17:37:57 +0200 |
commit | e0db5db519c291dc9ea9ec994b51fb9499f1e246 (patch) | |
tree | 9237dbb8f3c630ecda4e3d59b5ed0c6e29a5f411 | |
parent | 73c02aff48aedb76453a2a73ac104c1d5a163282 (diff) | |
parent | 26736895dfd2dafa5f20f1633a569078e67876c3 (diff) |
Merge pull request #26660 from vespa-engine/bjorncs/onnx-model-initializers
Ignore input also listed in initializers
3 files changed, 37 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 272b668b5fb..90a27d1f036 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,15 +1,17 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.ml.OnnxModelInfo; -import com.yahoo.searchlib.rankingexpression.Reference; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource { private OnnxModelInfo modelInfo = null; private final Map<String, String> inputMap = new HashMap<>(); private final Map<String, String> outputMap = new HashMap<>(); + private final Set<String> initializers = new HashSet<>(); private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; @@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource { for (String onnxName : modelInfo.getOutputs()) { addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); } + initializers.addAll(modelInfo.getInitializers()); this.modelInfo = modelInfo; } public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + public Set<String> getInitializers() { return Set.copyOf(initializers); } public String getDefaultOutput() { return modelInfo != null ? modelInfo.getDefaultOutput() : ""; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index af072c5b59a..7f578f07fe3 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -2,7 +2,6 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; -import com.yahoo.schema.RankProfile; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -12,13 +11,12 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.functions.DynamicTensor; import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Slice; import java.io.StringReader; import java.util.HashSet; import java.util.Set; +import java.util.logging.Logger; /** * Analyzes expression to figure out what inputs it needs @@ -27,6 +25,8 @@ import java.util.Set; */ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { + private static final Logger log = Logger.getLogger(InputRecorder.class.getName()); + private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); @@ -120,7 +120,11 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } - for (String onnxInput : model.getInputMap().values()) { + model.getInputMap().forEach((onnxName, onnxInput) -> { + if (model.getInitializers().contains(onnxName)) { + log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName)); + return; + } var reader = new StringReader(onnxInput); try { var asExpression = new RankingExpression(reader); @@ -128,7 +132,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } catch (ParseException e) { throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage()); } - } + }); return; } neededInputs.add(feature.toString()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 2742dc59fcd..7c89a349d7d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -42,13 +42,16 @@ public class OnnxModelInfo { private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); + private final Set<String> initializers; - private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, + Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) { this.app = app; this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); this.defaultOutput = defaultOutput; + this.initializers = Set.copyOf(initializers); } public String getModelPath() { @@ -63,6 +66,8 @@ public class OnnxModelInfo { return outputs.keySet(); } + public Set<String> getInitializers() { return initializers; } + public String getDefaultOutput() { return defaultOutput; } @@ -208,6 +213,14 @@ public class OnnxModelInfo { } g.writeEndArray(); + g.writeArrayFieldStart("initializers"); + for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) { + g.writeStartObject(); + g.writeStringField("name", initializers.getName()); + g.writeEndObject(); + } + g.writeEndArray(); + g.writeEndObject(); g.close(); return out.toString(); @@ -218,6 +231,7 @@ public class OnnxModelInfo { JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); Map<String, OnnxTypeInfo> outputs = new HashMap<>(); + Set<String> initializers = new HashSet<>(); String defaultOutput = ""; String path = null; @@ -233,7 +247,13 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); + var initializerRoot = root.get("initializers"); + if (initializerRoot != null) { + for (JsonNode initializer : initializerRoot) { + initializers.add(initializer.get("name").textValue()); + } + } + return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { |