aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java27
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java219
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java19
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java141
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java97
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java154
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java2
-rw-r--r--config-model/src/main/protobuf/onnx.proto464
13 files changed, 102 insertions, 1046 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index b153ff62e7d..4011ce43841 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -159,12 +158,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
}
- // A reference to an ONNX model?
- Optional<TensorType> onnxFeatureType = onnxFeatureType(reference);
- if (onnxFeatureType.isPresent()) {
- return onnxFeatureType.get();
- }
-
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -217,26 +210,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(function);
}
- private Optional<TensorType> onnxFeatureType(Reference reference) {
- if ( ! reference.name().equals("onnxModel"))
- return Optional.empty();
-
- if ( ! featureTypes.containsKey(reference)) {
- String configOrFileName = reference.arguments().expressions().get(0).toString();
-
- // Look up standardized format as added in RankProfile
- String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
- String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
-
- reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
- if ( ! featureTypes.containsKey(reference)) {
- throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
- }
- }
-
- return Optional.of(featureTypes.get(reference));
- }
-
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 58213186f78..c2fb2107604 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -2,22 +2,14 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
-import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
-import onnx.Onnx;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
+import java.util.List;
import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -29,16 +21,16 @@ public class OnnxModel {
public enum PathType {FILE, URI};
private final String name;
- private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private String defaultOutput = null;
- private Map<String, String> inputMap = new HashMap<>();
- private Map<String, String> outputMap = new HashMap<>();
+ private List<OnnxNameMapping> inputMap = new ArrayList<>();
+ private List<OnnxNameMapping> outputMap = new ArrayList<>();
- private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
- private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
- private Map<String, TensorType> vespaTypes = new HashMap<>();
+ public PathType getPathType() {
+ return pathType;
+ }
+
+ private PathType pathType = PathType.FILE;
public OnnxModel(String name) {
this.name = name;
@@ -57,52 +49,21 @@ public class OnnxModel {
}
public void setUri(String uri) {
- throw new IllegalArgumentException("URI for ONNX models are not currently supported");
- }
-
- public PathType getPathType() {
- return pathType;
- }
-
- public void setDefaultOutput(String onnxName) {
- Objects.requireNonNull(onnxName, "Name cannot be null");
- this.defaultOutput = onnxName;
+ Objects.requireNonNull(uri, "uri cannot be null");
+ this.path = uri;
+ this.pathType = PathType.URI;
}
public void addInputNameMapping(String onnxName, String vespaName) {
- addInputNameMapping(onnxName, vespaName, true);
- }
-
- public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! inputMap.containsKey(onnxName)) {
- inputMap.put(onnxName, vespaName);
- }
+ this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
public void addOutputNameMapping(String onnxName, String vespaName) {
- addOutputNameMapping(onnxName, vespaName, true);
- }
-
- public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! outputMap.containsKey(onnxName)) {
- outputMap.put(onnxName, vespaName);
- }
- }
-
- public void addInputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- inputTypes.put(onnxName, type);
- }
-
- public void addOutputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- outputTypes.put(onnxName, type);
+ this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
/** Initiate sending of this constant to some services over file distribution */
@@ -115,16 +76,11 @@ public class OnnxModel {
public String getName() { return name; }
public String getFileName() { return path; }
- public Path getFilePath() { return Path.fromString(path); }
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
- public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
- public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
-
- public String getDefaultOutput() {
- return defaultOutput;
- }
+ public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
+ public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
public void validate() {
if (path == null || path.isEmpty())
@@ -134,142 +90,23 @@ public class OnnxModel {
public String toString() {
StringBuilder b = new StringBuilder();
b.append("onnx-model '").append(name)
- .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
- .append("' with ref '").append(fileReference)
- .append("'");
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
return b.toString();
}
- /**
- * Return the tensor type for an ONNX model for the given context.
- * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
- * type depends on the input types for the given context (rank profile).
- */
- public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
- Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
- if (onnxOutputType == null) {
- throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
- }
- if (containsSymbolicDimensionSizes(onnxOutputType)) {
- return getTensorTypeWithSymbolicDimensions(onnxOutputType, context);
- }
- return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
- }
-
- private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
- Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context);
- if (symbolicSizes.isEmpty()) {
- return TensorType.empty; // Context is probably a rank profile not using this ONNX model
- }
- return typeFrom(onnxOutputType, symbolicSizes);
- }
-
- private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) {
- Map<String, Long> symbolicSizes = new HashMap<>();
- for (String onnxInputName : inputTypes.keySet()) {
-
- Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
- if ( ! containsSymbolicDimensionSizes(onnxType)) {
- continue;
- }
-
- Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
- if (vespaType.isEmpty()) {
- return Collections.emptyMap();
- }
-
- var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
- var vespaDimensions = vespaType.get().dimensions();
- if (vespaDimensions.size() != onnxDimensions.size()) {
- return Collections.emptyMap();
- }
-
- for (int i = 0; i < vespaDimensions.size(); ++i) {
- if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) {
- continue;
- }
- String symbolicName = onnxDimensions.get(i).getDimParam();
- Long size = vespaDimensions.get(i).size().get();
- if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
- throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " +
- "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
- }
- symbolicSizes.put(symbolicName, size);
- }
- }
- return symbolicSizes;
- }
-
- private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
- String source = inputMap.get(onnxInputName);
- if (source != null) {
- // Source is either a simple reference (query/attribute/constant)...
- Optional<Reference> reference = Reference.simple(source);
- if (reference.isPresent()) {
- return Optional.of(context.getType(reference.get()));
- }
- // ... or a function
- ExpressionFunction func = context.getFunction(source);
- if (func != null) {
- return Optional.of(func.getBody().type(context));
- }
- }
- return Optional.empty(); // if this context does not contain this input
- }
-
- private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) {
- return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue());
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type) {
- return typeFrom(type, null);
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) {
- String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- long onnxDimensionSize = onnxDimension.getDimValue();
- if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
- onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
- }
- if (onnxDimensionSize == 0 && symbolicSizes != null) {
- // This is for the case where all symbolic dimensions have
- // different names, but can be resolved to a single dimension size.
- Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
- if (unknownSizes.size() == 1) {
- onnxDimensionSize = unknownSizes.iterator().next();
- }
- }
- if (onnxDimensionSize <= 0) {
- throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
- "ONNX type: " + type + " to Vespa tensor type.");
- }
- builder.indexed(dimensionName, onnxDimensionSize);
- }
- return builder.build();
- }
+ public static class OnnxNameMapping {
+ private String onnxName;
+ private String vespaName;
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT8: return TensorType.Value.FLOAT;
- case INT16: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.FLOAT;
- case UINT8: return TensorType.Value.FLOAT;
- case UINT16: return TensorType.Value.FLOAT;
- case UINT32: return TensorType.Value.FLOAT;
- case UINT64: return TensorType.Value.FLOAT;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
+ private OnnxNameMapping(String onnxName, String vespaName) {
+ this.onnxName = onnxName;
+ this.vespaName = vespaName;
}
+ public String getOnnxName() { return onnxName; }
+ public String getVespaName() { return vespaName; }
+ public void setVespaName(String vespaName) { this.vespaName = vespaName; }
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 96c043bdb34..d309f48d6df 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -18,7 +18,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
@@ -159,10 +158,6 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
- private Map<String, OnnxModel> onnxModels() {
- return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
- }
-
private Stream<ImmutableSDField> allFields() {
if (search == null) return Stream.empty();
if (allFieldsList == null) {
@@ -826,20 +821,6 @@ public class RankProfile implements Cloneable {
}
}
- // Add output types for ONNX models
- for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
- String modelName = entry.getKey();
- OnnxModel model = entry.getValue();
- Arguments args = new Arguments(new ReferenceNode(modelName));
-
- TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
- context.setType(new Reference("onnxModel", args, null), defaultOutputType);
-
- for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
- TensorType type = model.getTensorType(mapping.getKey(), context);
- context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
- }
- }
return context;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 22a32c8fd65..84442fedc48 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
modelBuilder.name(model.getName());
modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
- model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName())));
+ model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName())));
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 56a5d539906..87eaaf0387a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) {
- for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
- String source = mapping.getValue();
+ for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) {
+ String source = mapping.getVespaName();
if (functionNames.contains(source)) {
- onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")");
+ mapping.setVespaName("rankingExpression(" + source + ")");
}
}
}
@@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>();
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
- ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile);
+ ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch());
if (referenceNode != replacedNode) {
replacedSummaryFeatures.add(replacedNode);
i.remove();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index d23a8376e7a..ec517768ea9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature;
+ if ( ! feature.getName().equals("onnx")) return feature;
try {
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index 69cdae10e47..e1ad003e5bd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,36 +1,20 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import com.yahoo.path.Path;
import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.vespa.model.ml.ConvertedModel;
-import com.yahoo.vespa.model.ml.FeatureArguments;
-import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
/**
- * Transforms ONNX model features of the forms:
- *
- * onnxModel(config_name)
- * onnxModel(config_name).output
- * onnxModel("path/to/model")
- * onnxModel("path/to/model").output
- * onnxModel("path/to/model", "path/to/output")
- * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
- *
- * To the format expected by the backend:
- *
- * onnxModel(config_name).output
+ * Transforms instances of the onnxModel ranking feature and generates
+ * ONNX configuration if necessary.
*
* @author lesters
*/
@@ -49,92 +33,85 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
if (context.rankProfile() == null) return feature;
if (context.rankProfile().getSearch() == null) return feature;
- return transformFeature(feature, context.rankProfile());
+ return transformFeature(feature, context.rankProfile().getSearch());
}
- public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
- ImmutableSearch search = rankProfile.getSearch();
- final String featureName = feature.getName();
- if ( ! featureName.equals("onnxModel")) return feature;
+ public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
+ if (!feature.getName().equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " +
- "onnx-model config or an ONNX file.");
- if (arguments.expressions().size() > 3)
- throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
-
- // Check that the model configuration "onnx-model" exists. If not defined, it should have been added
- // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find
- // the actual ONNX file, which can happen if we are restarting or upgrading an application using an
- // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
-
- String modelConfigName = getModelConfigName(feature.reference());
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
+ throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
+ "onnx-model config or a ONNX file.");
+ if (arguments.expressions().size() > 2)
+ throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
+
+ // Validation that the file actually exists is handled when the file is added to file distribution.
+ // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
+
+ String modelConfigName;
+ OnnxModel onnxModel;
+ if (arguments.expressions().get(0) instanceof ReferenceNode) {
+ modelConfigName = arguments.expressions().get(0).toString();
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
+ }
+ } else if (arguments.expressions().get(0) instanceof ConstantNode) {
String path = asString(arguments.expressions().get(0));
- ModelName modelName = new ModelName(null, Path.fromString(path), true);
- ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile);
- FeatureArguments featureArguments = new FeatureArguments(arguments);
- return convertedModel.expression(featureArguments, null);
- }
-
- String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
- String output = getModelOutput(feature.reference(), defaultOutput);
- if (! onnxModel.getOutputMap().containsValue(output)) {
- throw new IllegalArgumentException(featureName + " argument '" + output +
- "' output not found in model '" + onnxModel.getFileName() + "'");
+ modelConfigName = asValidIdentifier(path);
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ } else {
+ throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
}
- return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
- }
- public static String getModelConfigName(Reference reference) {
- if (reference.arguments().size() > 0) {
- ExpressionNode expr = reference.arguments().expressions().get(0);
- if (expr instanceof ReferenceNode) { // refers to onnx-model config
- return expr.toString();
+ String output = null;
+ if (feature.getOutput() != null) {
+ output = feature.getOutput();
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(output, output);
}
- if (expr instanceof ConstantNode) { // refers to an file path
- return asValidIdentifier(expr);
+ } else if (arguments.expressions().size() > 1) {
+ String name = asString(arguments.expressions().get(1));
+ output = asValidIdentifier(name);
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(name, output);
}
}
- return null;
- }
- public static String getModelOutput(Reference reference, String defaultOutput) {
- if (reference.output() != null) {
- return reference.output();
- } else if (reference.arguments().expressions().size() == 2) {
- return asValidIdentifier(reference.arguments().expressions().get(1));
- } else if (reference.arguments().expressions().size() > 2) {
- return asValidIdentifier(reference.arguments().expressions().get(2));
- }
- return defaultOutput;
+ // Replace feature with name of config
+ ExpressionNode argument = new ReferenceNode(modelConfigName);
+ return new ReferenceNode("onnxModel", List.of(argument), output);
+
}
- public static String stripQuotes(String s) {
- if (isNotQuoteSign(s.codePointAt(0))) return s;
- if (isNotQuoteSign(s.codePointAt(s.length() - 1)))
- throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
- return s.substring(1, s.length()-1);
+ private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
+ return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
}
- public static String asValidIdentifier(String str) {
- return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ private static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
- private static String asValidIdentifier(ExpressionNode node) {
- return asValidIdentifier(asString(node));
+ private static String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
}
- private static boolean isNotQuoteSign(int c) {
- return c != '\'' && c != '"';
+ private static boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
}
- public static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ private static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
deleted file mode 100644
index afba88c135d..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchdefinition.processing;
-
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.application.api.DeployLogger;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankProfileRegistry;
-import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
-import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
-
-import java.util.Map;
-
-/**
- * Processes ONNX ranking features of the form:
- *
- * onnx("files/model.onnx", "path/to/output:1")
- *
- * And generates an "onnx-model" configuration as if it was defined in the schema:
- *
- * onnx-model files_model_onnx {
- * file: "files/model.onnx"
- * }
- *
- * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be
- * processed after this.
- *
- * @author lesters
- */
-public class OnnxModelConfigGenerator extends Processor {
-
- public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
- super(search, deployLogger, rankProfileRegistry, queryProfiles);
- }
-
- @Override
- public void process(boolean validate, boolean documentsOnly) {
- if (documentsOnly) return;
- for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) {
- if (profile.getFirstPhaseRanking() != null) {
- process(profile.getFirstPhaseRanking().getRoot());
- }
- if (profile.getSecondPhaseRanking() != null) {
- process(profile.getSecondPhaseRanking().getRoot());
- }
- for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
- process(function.getValue().function().getBody().getRoot());
- }
- for (ReferenceNode feature : profile.getSummaryFeatures()) {
- process(feature);
- }
- }
- }
-
- private void process(ExpressionNode node) {
- if (node instanceof ReferenceNode) {
- process((ReferenceNode)node);
- } else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode) node).children()) {
- process(child);
- }
- }
- }
-
- private void process(ReferenceNode feature) {
- if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
- if (feature.getArguments().size() > 0) {
- if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
- ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0);
- String path = OnnxModelTransformer.stripQuotes(node.sourceString());
- String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
-
- // Only add the configuration if the model can actually be found.
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
- path = ApplicationPackage.MODELS_DIR.append(path).toString();
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
- return;
- }
- }
-
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- search.onnxModels().add(onnxModel);
- }
- }
- }
- }
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
deleted file mode 100644
index bead2e7e7c9..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchdefinition.processing;
-
-import com.yahoo.cloud.config.ConfigserverConfig;
-import com.yahoo.component.Version;
-import com.yahoo.config.FileReference;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.application.api.DeployLogger;
-import com.yahoo.config.application.api.FileRegistry;
-import com.yahoo.path.Path;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfileRegistry;
-import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
-import com.yahoo.vespa.defaults.Defaults;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.file.Paths;
-import java.util.Map;
-import java.util.Optional;
-
-/**
- * Processes every "onnx-model" element in the schema. Parses the model file,
- * adds missing input and output mappings (assigning default names), and
- * adds tensor types to all model inputs and outputs.
- *
- * Must be processed before RankingExpressingTypeResolver.
- *
- * @author lesters
- */
-public class OnnxModelTypeResolver extends Processor {
-
- public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
- super(search, deployLogger, rankProfileRegistry, queryProfiles);
- }
-
- @Override
- public void process(boolean validate, boolean documentsOnly) {
- if (documentsOnly) return;
-
- for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) {
- OnnxModel modelConfig = entry.getValue();
- try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
-
- // Model inputs - if not defined, assumes a function is provided with a valid name
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
- String onnxInputName = valueInfo.getName();
- String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
- modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
- modelConfig.addInputType(onnxInputName, valueInfo.getType());
- }
-
- // Model outputs
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
- String onnxOutputName = valueInfo.getName();
- String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
- modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
- modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
- }
-
- // Set the first output as default
- if ( ! model.getGraph().getOutputList().isEmpty()) {
- modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
- }
-
- } catch (IOException e) {
- throw new IllegalArgumentException("Unable to parse ONNX model", e);
- }
- }
- }
-
- static boolean modelFileExists(String path, ApplicationPackage app) {
- Path pathInApplicationPackage = Path.fromString(path);
- if (getFile(pathInApplicationPackage, app).exists()) {
- return true;
- }
- if (getFileReference(pathInApplicationPackage, app).isPresent()) {
- return true;
- }
- return false;
- }
-
- private InputStream openModelFile(Path path) throws FileNotFoundException {
- ApplicationFile file;
- Optional<FileReference> reference;
- Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
-
- if ((file = getFile(path)).exists()) {
- return file.createInputStream();
- }
- if ((file = getFile(modelsPath)).exists()) {
- return file.createInputStream();
- }
- if ((reference = getFileReference(path)).isPresent()) {
- return openFromFileRepository(path, reference.get());
- }
- if ((reference = getFileReference(modelsPath)).isPresent()) {
- return openFromFileRepository(modelsPath, reference.get());
- }
-
- throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
- "in application package or file repository.");
- }
-
- private ApplicationFile getFile(Path path) {
- return getFile(path, search.applicationPackage());
- }
-
- private static ApplicationFile getFile(Path path, ApplicationPackage app) {
- return app.getFile(path);
- }
-
- private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
- return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
- }
-
- public static String getFileRepositoryPath(Path path, String fileReference) {
- ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
- String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
- return Paths.get(fileRefDir, fileReference, path.getName()).toString();
- }
-
- private Optional<FileReference> getFileReference(Path path) {
- return getFileReference(path, search.applicationPackage());
- }
-
- private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
- Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
- if (fileRegistry.isPresent()) {
- for (FileRegistry.Entry file : fileRegistry.get().export()) {
- if (file.relativePath.equals(path.toString())) {
- return Optional.of(file.reference);
- }
- }
- }
- return Optional.empty();
- }
-
- private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
- if (app == null) return Optional.empty();
- Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo);
- return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get()));
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index 1a3ef9e54b4..e8594c2a87f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -74,8 +74,6 @@ public class Processing {
ReferenceFieldsProcessor::new,
FastAccessValidator::new,
ReservedFunctionNames::new,
- OnnxModelConfigGenerator::new,
- OnnxModelTypeResolver::new,
RankingExpressionTypeResolver::new,
// These should be last:
IndexingValidation::new,
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index d5c5183b01f..c6c7969e466 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -1,19 +1,20 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation;
+import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.log.InvalidLogFormatException;
import java.util.logging.Level;
import com.yahoo.log.LogMessage;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver;
import com.yahoo.yolean.Exceptions;
import com.yahoo.system.ProcessExecuter;
import com.yahoo.text.StringUtilities;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.collections.Pair;
import com.yahoo.config.ConfigInstance;
+import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.config.search.ImportedFieldsConfig;
import com.yahoo.vespa.config.search.IndexschemaConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -30,6 +31,7 @@ import com.yahoo.vespa.model.search.SearchCluster;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
+import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.logging.Logger;
@@ -150,9 +152,12 @@ public class RankSetupValidator extends Validator {
// Assist verify-ranksetup in finding the actual ONNX model files
Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap();
if (models.values().size() > 0) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
List<String> config = new ArrayList<>(models.values().size() * 2);
for (OnnxModel model : models.values()) {
- String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference());
+ String modelFilename = Paths.get(model.getFileName()).getFileName().toString();
+ String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString();
config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference()));
config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 5ee6ed02e61..943fcbf6c1d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -150,7 +150,7 @@ public class ConvertedModel {
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
ExpressionFunction expression = selectExpression(arguments);
- if (sourceModel.isPresent() && context != null) // we should verify
+ if (sourceModel.isPresent()) // we should verify
verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
return expression.getBody().getRoot();
}
diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto
deleted file mode 100644
index dc6542867e0..00000000000
--- a/config-model/src/main/protobuf/onnx.proto
+++ /dev/null
@@ -1,464 +0,0 @@
-//
-// WARNING: This file is automatically generated! Please edit onnx.in.proto.
-//
-
-
-// Copyright (c) Facebook Inc. and Microsoft Corporation.
-// Licensed under the MIT license.
-
-syntax = "proto2";
-
-package onnx;
-
-// Overview
-//
-// ONNX is an open specification that is comprised of the following components:
-//
-// 1) A definition of an extensible computation graph model.
-// 2) Definitions of standard data types.
-// 3) Definitions of built-in operators.
-//
-// This document describes the syntax of models and their computation graphs,
-// as well as the standard data types. Together, they are referred to as the ONNX
-// Intermediate Representation, or 'IR' for short.
-//
-// The normative semantic specification of the ONNX IR is found in docs/IR.md.
-// Definitions of the built-in neural network operators may be found in docs/Operators.md.
-
-// Notes
-//
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
-// Protobuf compatibility
-//
-// To simplify framework compatibility, ONNX is defined using the subset of protobuf
-// that is compatible with both protobuf v2 and v3. This means that we do not use any
-// protobuf features that are only available in one of the two versions.
-//
-// Here are the most notable contortions we have to carry out to work around
-// these limitations:
-//
-// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
-// of key-value pairs, where order does not matter and duplicates
-// are not allowed.
-
-
-// Versioning
-//
-// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
-//
-// To be compatible with both proto2 and proto3, we will use a version number
-// that is not defined by the default value but an explicit enum number.
-enum Version {
- // proto3 requires the first enum value to be zero.
- // We add this just to appease the compiler.
- _START_VERSION = 0;
- // The version field is always serialized and we will use it to store the
- // version that the graph is generated from. This helps us set up version
- // control. We should use version as
- // xx(major) - xx(minor) - xxxx(bugfix)
- // and we are starting with 0x00000001 (0.0.1), which was the
- // version we published on Oct 10, 2017.
- IR_VERSION_2017_10_10 = 0x00000001;
-
- // IR_VERSION 0.0.2 published on Oct 30, 2017
- // - Added type discriminator to AttributeProto to support proto3 users
- IR_VERSION_2017_10_30 = 0x00000002;
-
- // IR VERSION 0.0.3 published on Nov 3, 2017
- // - For operator versioning:
- // - Added new message OperatorSetIdProto
- // - Added opset_import in ModelProto
- // - For vendor extensions, added domain in NodeProto
- IR_VERSION = 0x00000003;
-}
-
-// Attributes
-//
-// A named attribute containing either singular float, integer, string, graph,
-// and tensor values, or repeated float, integer, string, graph, and tensor values.
-// An AttributeProto MUST contain the name field, and *only one* of the
-// following content fields, effectively enforcing a C/C++ union equivalent.
-message AttributeProto {
-
- // Note: this enum is structurally identical to the OpSchema::AttrType
- // enum defined in schema.h. If you rev one, you likely need to rev the other.
- enum AttributeType {
- UNDEFINED = 0;
- FLOAT = 1;
- INT = 2;
- STRING = 3;
- TENSOR = 4;
- GRAPH = 5;
-
- FLOATS = 6;
- INTS = 7;
- STRINGS = 8;
- TENSORS = 9;
- GRAPHS = 10;
- }
-
- // The name field MUST be present for this version of the IR.
- optional string name = 1; // namespace Attribute
-
- // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
- // In this case, this AttributeProto does not contain data, and it's a reference of attribute
- // in parent scope.
- // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
- optional string ref_attr_name = 21;
-
- // A human-readable documentation for this attribute. Markdown is allowed.
- optional string doc_string = 13;
-
- // The type field MUST be present for this version of the IR.
- // For 0.0.1 versions of the IR, this field was not defined, and
- // implementations needed to use has_field hueristics to determine
- // which value field was in use. For IR_VERSION 0.0.2 or later, this
- // field MUST be set and match the f|i|s|t|... field in use. This
- // change was made to accomodate proto3 implementations.
- optional AttributeType type = 20; // discriminator that indicates which field below is in use
-
- // Exactly ONE of the following fields must be present for this version of the IR
- optional float f = 2; // float
- optional int64 i = 3; // int
- optional bytes s = 4; // UTF-8 string
- optional TensorProto t = 5; // tensor value
- optional GraphProto g = 6; // graph
- // Do not use field below, it's deprecated.
- // optional ValueProto v = 12; // value - subsumes everything but graph
-
- repeated float floats = 7; // list of floats
- repeated int64 ints = 8; // list of ints
- repeated bytes strings = 9; // list of UTF-8 strings
- repeated TensorProto tensors = 10; // list of tensors
- repeated GraphProto graphs = 11; // list of graph
-}
-
-// Defines information on value, including the name, the type, and
-// the shape of the value.
-message ValueInfoProto {
- // This field MUST be present in this version of the IR.
- optional string name = 1; // namespace Value
- // This field MUST be present in this version of the IR.
- optional TypeProto type = 2;
- // A human-readable documentation for this value. Markdown is allowed.
- optional string doc_string = 3;
-}
-
-// Nodes
-//
-// Computation graphs are made up of a DAG of nodes, which represent what is
-// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
-//
-// For example, it can be a node of type "Conv" that takes in an image, a filter
-// tensor and a bias tensor, and produces the convolved output.
-message NodeProto {
- repeated string input = 1; // namespace Value
- repeated string output = 2; // namespace Value
-
- // An optional identifier for this node in a graph.
- // This field MAY be absent in ths version of the IR.
- optional string name = 3; // namespace Node
-
- // The symbolic identifier of the Operator to execute.
- optional string op_type = 4; // namespace Operator
- // The domain of the OperatorSet that specifies the operator named by op_type.
- optional string domain = 7; // namespace Domain
-
- // Additional named attributes.
- repeated AttributeProto attribute = 5;
-
- // A human-readable documentation for this node. Markdown is allowed.
- optional string doc_string = 6;
-}
-
-// Models
-//
-// ModelProto is a top-level file/container format for bundling a ML model and
-// associating its computation graph with metadata.
-//
-// The semantics of the model are described by the associated GraphProto.
-message ModelProto {
- // The version of the IR this model targets. See Version enum above.
- // This field MUST be present.
- optional int64 ir_version = 1;
-
- // The OperatorSets this model relies on.
- // All ModelProtos MUST have at least one entry that
- // specifies which version of the ONNX OperatorSet is
- // being imported.
- //
- // All nodes in the ModelProto's graph will bind against the operator
- // with the same-domain/same-op_type operator with the HIGHEST version
- // in the referenced operator sets.
- repeated OperatorSetIdProto opset_import = 8;
-
- // The name of the framework or tool used to generate this model.
- // This field SHOULD be present to indicate which implementation/tool/framework
- // emitted the model.
- optional string producer_name = 2;
-
- // The version of the framework or tool used to generate this model.
- // This field SHOULD be present to indicate which implementation/tool/framework
- // emitted the model.
- optional string producer_version = 3;
-
- // Domain name of the model.
- // We use reverse domain names as name space indicators. For example:
- // `com.facebook.fair` or `com.microsoft.cognitiveservices`
- //
- // Together with `model_version` and GraphProto.name, this forms the unique identity of
- // the graph.
- optional string domain = 4;
-
- // The version of the graph encoded. See Version enum below.
- optional int64 model_version = 5;
-
- // A human-readable documentation for this model. Markdown is allowed.
- optional string doc_string = 6;
-
- // The parameterized graph that is evaluated to execute the model.
- optional GraphProto graph = 7;
-
- // Named metadata values; keys should be distinct.
- repeated StringStringEntryProto metadata_props = 14;
-};
-
-// StringStringEntryProto follows the pattern for cross-proto-version maps.
-// See https://developers.google.com/protocol-buffers/docs/proto3#maps
-message StringStringEntryProto {
- optional string key = 1;
- optional string value= 2;
-};
-
-// Graphs
-//
-// A graph defines the computational logic of a model and is comprised of a parameterized
-// list of nodes that form a directed acyclic graph based on their inputs and outputs.
-// This is the equivalent of the "network" or "graph" in many deep learning
-// frameworks.
-message GraphProto {
- // The nodes in the graph, sorted topologically.
- repeated NodeProto node = 1;
-
- // The name of the graph.
- optional string name = 2; // namespace Graph
-
- // A list of named tensor values, used to specify constant inputs of the graph.
- // Each TensorProto entry must have a distinct name (within the list) that
- // also appears in the input list.
- repeated TensorProto initializer = 5;
-
- // A human-readable documentation for this graph. Markdown is allowed.
- optional string doc_string = 10;
-
- // The inputs and outputs of the graph.
- repeated ValueInfoProto input = 11;
- repeated ValueInfoProto output = 12;
-
- // Information for the values in the graph. The ValueInfoProto.name's
- // must be distinct. It is optional for a value to appear in value_info list.
- repeated ValueInfoProto value_info = 13;
-
- // DO NOT USE the following fields, they were deprecated from earlier versions.
- // repeated string input = 3;
- // repeated string output = 4;
- // optional int64 ir_version = 6;
- // optional int64 producer_version = 7;
- // optional string producer_tag = 8;
- // optional string domain = 9;
-}
-
-// Tensors
-//
-// A serialized tensor value.
-message TensorProto {
- enum DataType {
- UNDEFINED = 0;
- // Basic types.
- FLOAT = 1; // float
- UINT8 = 2; // uint8_t
- INT8 = 3; // int8_t
- UINT16 = 4; // uint16_t
- INT16 = 5; // int16_t
- INT32 = 6; // int32_t
- INT64 = 7; // int64_t
- STRING = 8; // string
- BOOL = 9; // bool
-
- // Advanced types
- FLOAT16 = 10;
- DOUBLE = 11;
- UINT32 = 12;
- UINT64 = 13;
- COMPLEX64 = 14; // complex with float32 real and imaginary components
- COMPLEX128 = 15; // complex with float64 real and imaginary components
- // Future extensions go here.
- }
-
- // The shape of the tensor.
- repeated int64 dims = 1;
-
- // The data type of the tensor.
- optional DataType data_type = 2;
-
- // For very large tensors, we may want to store them in chunks, in which
- // case the following fields will specify the segment that is stored in
- // the current TensorProto.
- message Segment {
- optional int64 begin = 1;
- optional int64 end = 2;
- }
- optional Segment segment = 3;
-
- // Tensor content must be organized in row-major order.
- //
- // Depending on the data_type field, exactly one of the fields below with
- // name ending in _data is used to store the elements of the tensor.
-
- // For float and complex64 values
- // Complex64 tensors are encoded as a single array of floats,
- // with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
- // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
- // is encoded as [1.0, 2.0 ,3.0 ,4.0]
- // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
- repeated float float_data = 4 [packed = true];
-
- // For int32, uint8, int8, uint16, int16, bool, and float16 values
- // float16 values must be bit-wise converted to an uint16_t prior
- // to writing to the buffer.
- // When this field is present, the data_type field MUST be
- // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
- repeated int32 int32_data = 5 [packed = true];
-
- // For strings.
- // Each element of string_data is a UTF-8 encoded Unicode
- // string. No trailing null, no leading BOM. The protobuf "string"
- // scalar type is not used to match ML community conventions.
- // When this field is present, the data_type field MUST be STRING
- repeated bytes string_data = 6;
-
- // For int64.
- // When this field is present, the data_type field MUST be INT64
- repeated int64 int64_data = 7 [packed = true];
-
- // Optionally, a name for the tensor.
- optional string name = 8; // namespace Value
-
- // A human-readable documentation for this tensor. Markdown is allowed.
- optional string doc_string = 12;
-
- // Serializations can either use one of the fields above, or use this
- // raw bytes field. The only exception is the string case, where one is
- // required to store the content in the repeated bytes string_data field.
- //
- // When this raw_data field is used to store tensor value, elements MUST
- // be stored in as fixed-width, little-endian order.
- // Floating-point data types MUST be stored in IEEE 754 format.
- // Complex64 elements must be written as two consecutive FLOAT values, real component first.
- // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
- // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
- //
- // Note: the advantage of specific field rather than the raw_data field is
- // that in some cases (e.g. int data), protobuf does a better packing via
- // variable length storage, and may lead to smaller binary footprint.
- // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
- optional bytes raw_data = 9;
-
- // For double
- // Complex64 tensors are encoded as a single array of doubles,
- // with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
- // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
- // is encoded as [1.0, 2.0 ,3.0 ,4.0]
- // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
- repeated double double_data = 10 [packed = true];
-
- // For uint64 and uint32 values
- // When this field is present, the data_type field MUST be
- // UINT32 or UINT64
- repeated uint64 uint64_data = 11 [packed = true];
-}
-
-// Defines a tensor shape. A dimension can be either an integer value
-// or a symbolic variable. A symbolic variable represents an unknown
-// dimension.
-message TensorShapeProto {
- message Dimension {
- oneof value {
- int64 dim_value = 1;
- string dim_param = 2; // namespace Shape
- };
- // Standard denotation can optionally be used to denote tensor
- // dimensions with standard semantic descriptions to ensure
- // that operations are applied to the correct axis of a tensor.
- optional string denotation = 3;
- };
- repeated Dimension dim = 1;
-}
-
-// A set of pre-defined constants to be used as values for
-// the standard denotation field in TensorShapeProto.Dimension
-// for semantic description of the tensor dimension.
-message DenotationConstProto {
- // Describe a batch number dimension.
- optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
- // Describe a channel dimension.
- optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
- // Describe a time dimension.
- optional string DATA_TIME = 3 [default = "DATA_TIME"];
- // Describe a feature dimension. This is typically a feature
- // dimension in RNN and/or spatial dimension in CNN.
- optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
- // Describe a filter in-channel dimension. This is the dimension
- // that is identical (in size) to the channel dimension of the input
- // image feature maps.
- optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
- // Describe a filter out channel dimension. This is the dimension
- // that is identical (int size) to the channel dimension of the output
- // image feature maps.
- optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
- // Describe a filter spatial dimension.
- optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
-}
-
-// Types
-//
-// The standard ONNX data types.
-message TypeProto {
-
- message Tensor {
- // This field MUST NOT have the value of UNDEFINED
- // This field MUST be present for this version of the IR.
- optional TensorProto.DataType elem_type = 1;
- optional TensorShapeProto shape = 2;
- }
-
-
- oneof value {
- // The type of a tensor.
- Tensor tensor_type = 1;
-
- }
-}
-
-// Operator Sets
-//
-// OperatorSets are uniquely identified by a (domain, opset_version) pair.
-message OperatorSetIdProto {
- // The domain of the operator set being identified.
- // The empty string ("") or absence of this field implies the operator
- // set that is defined as part of the ONNX specification.
- // This field MUST be present in this version of the IR when referring to any other operator set.
- optional string domain = 1;
-
- // The version of the operator set being identified.
- // This field MUST be present in this version of the IR.
- optional int64 version = 2;
-} \ No newline at end of file