From 26736895dfd2dafa5f20f1633a569078e67876c3 Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Fri, 31 Mar 2023 12:54:02 +0200 Subject: Ignore input also listed in initializers Ignore inputs which are already initialized when determining required features as input to ONNX models in global-phase. --- .../src/main/java/com/yahoo/schema/OnnxModel.java | 7 ++++++- .../schema/expressiontransforms/InputRecorder.java | 14 ++++++++----- .../com/yahoo/vespa/model/ml/OnnxModelInfo.java | 24 ++++++++++++++++++++-- 3 files changed, 37 insertions(+), 8 deletions(-) (limited to 'config-model/src/main/java') diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 272b668b5fb..90a27d1f036 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,15 +1,17 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.ml.OnnxModelInfo; -import com.yahoo.searchlib.rankingexpression.Reference; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. 
@@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource { private OnnxModelInfo modelInfo = null; private final Map inputMap = new HashMap<>(); private final Map outputMap = new HashMap<>(); + private final Set initializers = new HashSet<>(); private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; @@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource { for (String onnxName : modelInfo.getOutputs()) { addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); } + initializers.addAll(modelInfo.getInitializers()); this.modelInfo = modelInfo; } public Map getInputMap() { return Collections.unmodifiableMap(inputMap); } public Map getOutputMap() { return Collections.unmodifiableMap(outputMap); } + public Set getInitializers() { return Set.copyOf(initializers); } public String getDefaultOutput() { return modelInfo != null ? modelInfo.getDefaultOutput() : ""; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index af072c5b59a..7f578f07fe3 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -2,7 +2,6 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; -import com.yahoo.schema.RankProfile; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -12,13 +11,12 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import 
com.yahoo.tensor.functions.DynamicTensor; import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Slice; import java.io.StringReader; import java.util.HashSet; import java.util.Set; +import java.util.logging.Logger; /** * Analyzes expression to figure out what inputs it needs @@ -27,6 +25,8 @@ import java.util.Set; */ public class InputRecorder extends ExpressionTransformer { + private static final Logger log = Logger.getLogger(InputRecorder.class.getName()); + private final Set neededInputs; private final Set handled = new HashSet<>(); @@ -120,7 +120,11 @@ public class InputRecorder extends ExpressionTransformer { if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } - for (String onnxInput : model.getInputMap().values()) { + model.getInputMap().forEach((onnxName, onnxInput) -> { + if (model.getInitializers().contains(onnxName)) { + log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName)); + return; + } var reader = new StringReader(onnxInput); try { var asExpression = new RankingExpression(reader); @@ -128,7 +132,7 @@ public class InputRecorder extends ExpressionTransformer { } catch (ParseException e) { throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage()); } - } + }); return; } neededInputs.add(feature.toString()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 2742dc59fcd..7c89a349d7d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -42,13 +42,16 @@ public class OnnxModelInfo { private final Map inputs; private final Map outputs; private final Map vespaTypes = new HashMap<>(); + private final Set initializers; - private OnnxModelInfo(ApplicationPackage app, String path, Map inputs, Map 
outputs, String defaultOutput) { + private OnnxModelInfo(ApplicationPackage app, String path, Map inputs, + Map outputs, Set initializers, String defaultOutput) { this.app = app; this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); this.defaultOutput = defaultOutput; + this.initializers = Set.copyOf(initializers); } public String getModelPath() { @@ -63,6 +66,8 @@ public class OnnxModelInfo { return outputs.keySet(); } + public Set getInitializers() { return initializers; } + public String getDefaultOutput() { return defaultOutput; } @@ -208,6 +213,14 @@ public class OnnxModelInfo { } g.writeEndArray(); + g.writeArrayFieldStart("initializers"); + for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) { + g.writeStartObject(); + g.writeStringField("name", initializers.getName()); + g.writeEndObject(); + } + g.writeEndArray(); + g.writeEndObject(); g.close(); return out.toString(); @@ -218,6 +231,7 @@ public class OnnxModelInfo { JsonNode root = m.readTree(json); Map inputs = new HashMap<>(); Map outputs = new HashMap<>(); + Set initializers = new HashSet<>(); String defaultOutput = ""; String path = null; @@ -233,7 +247,13 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); + var initializerRoot = root.get("initializers"); + if (initializerRoot != null) { + for (JsonNode initializer : initializerRoot) { + initializers.add(initializer.get("name").textValue()); + } + } + return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { -- cgit v1.2.3