summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@gmail.com> 2023-03-31 17:37:57 +0200
committer GitHub <noreply@github.com> 2023-03-31 17:37:57 +0200
commit e0db5db519c291dc9ea9ec994b51fb9499f1e246 (patch)
tree 9237dbb8f3c630ecda4e3d59b5ed0c6e29a5f411
parent 73c02aff48aedb76453a2a73ac104c1d5a163282 (diff)
parent 26736895dfd2dafa5f20f1633a569078e67876c3 (diff)
Merge pull request #26660 from vespa-engine/bjorncs/onnx-model-initializers
Ignore input also listed in initializers
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/OnnxModel.java 7
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java 14
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java 24
3 files changed, 37 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 272b668b5fb..90a27d1f036 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -1,15 +1,17 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
-import com.yahoo.searchlib.rankingexpression.Reference;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource {
private OnnxModelInfo modelInfo = null;
private final Map<String, String> inputMap = new HashMap<>();
private final Map<String, String> outputMap = new HashMap<>();
+ private final Set<String> initializers = new HashSet<>();
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
@@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource {
for (String onnxName : modelInfo.getOutputs()) {
addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
}
+ initializers.addAll(modelInfo.getInitializers());
this.modelInfo = modelInfo;
}
public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
+ public Set<String> getInitializers() { return Set.copyOf(initializers); }
public String getDefaultOutput() {
return modelInfo != null ? modelInfo.getDefaultOutput() : "";
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index af072c5b59a..7f578f07fe3 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -2,7 +2,6 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
-import com.yahoo.schema.RankProfile;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -12,13 +11,12 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.functions.DynamicTensor;
import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Slice;
import java.io.StringReader;
import java.util.HashSet;
import java.util.Set;
+import java.util.logging.Logger;
/**
* Analyzes expression to figure out what inputs it needs
@@ -27,6 +25,8 @@ import java.util.Set;
*/
public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
+ private static final Logger log = Logger.getLogger(InputRecorder.class.getName());
+
private final Set<String> neededInputs;
private final Set<String> handled = new HashSet<>();
@@ -120,7 +120,11 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
if (model == null) {
throw new IllegalArgumentException("missing onnx model: " + arg);
}
- for (String onnxInput : model.getInputMap().values()) {
+ model.getInputMap().forEach((onnxName, onnxInput) -> {
+ if (model.getInitializers().contains(onnxName)) {
+ log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName));
+ return;
+ }
var reader = new StringReader(onnxInput);
try {
var asExpression = new RankingExpression(reader);
@@ -128,7 +132,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
} catch (ParseException e) {
throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage());
}
- }
+ });
return;
}
neededInputs.add(feature.toString());
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 2742dc59fcd..7c89a349d7d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -42,13 +42,16 @@ public class OnnxModelInfo {
private final Map<String, OnnxTypeInfo> inputs;
private final Map<String, OnnxTypeInfo> outputs;
private final Map<String, TensorType> vespaTypes = new HashMap<>();
+ private final Set<String> initializers;
- private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs,
+ Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) {
this.app = app;
this.modelPath = path;
this.inputs = Collections.unmodifiableMap(inputs);
this.outputs = Collections.unmodifiableMap(outputs);
this.defaultOutput = defaultOutput;
+ this.initializers = Set.copyOf(initializers);
}
public String getModelPath() {
@@ -63,6 +66,8 @@ public class OnnxModelInfo {
return outputs.keySet();
}
+ public Set<String> getInitializers() { return initializers; }
+
public String getDefaultOutput() {
return defaultOutput;
}
@@ -208,6 +213,14 @@ public class OnnxModelInfo {
}
g.writeEndArray();
+ g.writeArrayFieldStart("initializers");
+ for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) {
+ g.writeStartObject();
+ g.writeStringField("name", initializers.getName());
+ g.writeEndObject();
+ }
+ g.writeEndArray();
+
g.writeEndObject();
g.close();
return out.toString();
@@ -218,6 +231,7 @@ public class OnnxModelInfo {
JsonNode root = m.readTree(json);
Map<String, OnnxTypeInfo> inputs = new HashMap<>();
Map<String, OnnxTypeInfo> outputs = new HashMap<>();
+ Set<String> initializers = new HashSet<>();
String defaultOutput = "";
String path = null;
@@ -233,7 +247,13 @@ public class OnnxModelInfo {
if (root.get("outputs").has(0)) {
defaultOutput = root.get("outputs").get(0).get("name").textValue();
}
- return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput);
+ var initializerRoot = root.get("initializers");
+ if (initializerRoot != null) {
+ for (JsonNode initializer : initializerRoot) {
+ initializers.add(initializer.get("name").textValue());
+ }
+ }
+ return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput);
}
static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {