author    Lester Solbakken <lesters@oath.com>  2020-10-25 12:26:56 +0100
committer Lester Solbakken <lesters@oath.com>  2020-10-25 12:26:56 +0100
commit    38e2a6a325db457456e04ce8385f23b12a5da54d (patch)
tree      e5e5906f0692831240bd898c9378e948c68a5d02
parent    899f7210569b4f43c1531a4f4c12507b41a7f4f7 (diff)
Revert "Revert "Add type resolving for ONNX models""
This reverts commit 882d574ab53e8d10a2a8765a64487c20661dc63f.
-rw-r--r--  config-model/pom.xml | 9
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java | 27
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 219
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | 19
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java | 4
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java | 8
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java | 2
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java | 141
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java | 97
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java | 154
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java | 2
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java | 9
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 2
-rw-r--r--  config-model/src/main/protobuf/onnx.proto | 464
-rw-r--r--  config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd | 2
-rwxr-xr-x  config-model/src/test/integration/onnx-model/files/create_dynamic_model.py | 12
-rwxr-xr-x  config-model/src/test/integration/onnx-model/files/create_model.py | 37
-rw-r--r--  config-model/src/test/integration/onnx-model/files/dynamic_model.onnx | 13
-rw-r--r--  config-model/src/test/integration/onnx-model/files/model.onnx | 34
-rw-r--r--  config-model/src/test/integration/onnx-model/files/summary_model.onnx | 34
-rw-r--r--  config-model/src/test/integration/onnx-model/searchdefinitions/test.sd | 36
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java | 76
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | 36
-rw-r--r--  fat-model-dependencies/pom.xml | 5
24 files changed, 1293 insertions, 149 deletions
diff --git a/config-model/pom.xml b/config-model/pom.xml
index 95e79fd09fb..c0751431d03 100644
--- a/config-model/pom.xml
+++ b/config-model/pom.xml
@@ -46,6 +46,11 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>provided</scope>
@@ -498,6 +503,10 @@
<updateReleaseInfo>true</updateReleaseInfo>
</configuration>
</plugin>
+ <plugin>
+ <groupId>com.github.os72</groupId>
+ <artifactId>protoc-jar-maven-plugin</artifactId>
+ </plugin>
</plugins>
</build>
</project>
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 4011ce43841..b153ff62e7d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
}
+ // A reference to an ONNX model?
+ Optional<TensorType> onnxFeatureType = onnxFeatureType(reference);
+ if (onnxFeatureType.isPresent()) {
+ return onnxFeatureType.get();
+ }
+
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(function);
}
+ private Optional<TensorType> onnxFeatureType(Reference reference) {
+ if ( ! reference.name().equals("onnxModel"))
+ return Optional.empty();
+
+ if ( ! featureTypes.containsKey(reference)) {
+ String configOrFileName = reference.arguments().expressions().get(0).toString();
+
+ // Look up standardized format as added in RankProfile
+ String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
+ String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
+
+ reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
+ if ( ! featureTypes.containsKey(reference)) {
+ throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
+ }
+ }
+
+ return Optional.of(featureTypes.get(reference));
+ }
+
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
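The lookup added above only succeeds once the reference uses the standardized key that RankProfile registers, onnxModel(<config-name>).<output>; otherwise it normalizes the reference via OnnxModelTransformer and retries. A minimal standalone sketch of that normalize-then-look-up step, using plain strings instead of the Reference/Arguments classes and a hypothetical featureTypes entry:

    import java.util.HashMap;
    import java.util.Map;

    // Minimal sketch: normalize an onnxModel feature written with a file path into the
    // standardized key used in the featureTypes map, then look up its tensor type.
    // The map contents and argument values are illustrative only.
    public class OnnxFeatureLookupSketch {

        // Mirrors OnnxModelTransformer.asValidIdentifier in this patch.
        static String asValidIdentifier(String s) {
            return s.replaceAll("[^\\w\\d\\$@_]", "_");
        }

        public static void main(String[] args) {
            Map<String, String> featureTypes = new HashMap<>();
            featureTypes.put("onnxModel(files_model_onnx).path_to_output",
                             "tensor<float>(d0[1],d1[2])");

            // Feature as written in the schema: onnxModel("files/model.onnx", "path/to/output")
            String modelArgument = "files/model.onnx";
            String outputArgument = "path/to/output";

            // Standardized form: config name derived from the path, output as a valid identifier.
            String key = "onnxModel(" + asValidIdentifier(modelArgument) + ")."
                         + asValidIdentifier(outputArgument);

            System.out.println(key + " -> " + featureTypes.get(key));
            // onnxModel(files_model_onnx).path_to_output -> tensor<float>(d0[1],d1[2])
        }
    }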
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index c2fb2107604..58213186f78 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -2,14 +2,22 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
+import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
+import onnx.Onnx;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.List;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -21,16 +29,16 @@ public class OnnxModel {
public enum PathType {FILE, URI};
private final String name;
+ private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private List<OnnxNameMapping> inputMap = new ArrayList<>();
- private List<OnnxNameMapping> outputMap = new ArrayList<>();
+ private String defaultOutput = null;
+ private Map<String, String> inputMap = new HashMap<>();
+ private Map<String, String> outputMap = new HashMap<>();
- public PathType getPathType() {
- return pathType;
- }
-
- private PathType pathType = PathType.FILE;
+ private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
+ private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
+ private Map<String, TensorType> vespaTypes = new HashMap<>();
public OnnxModel(String name) {
this.name = name;
@@ -49,21 +57,52 @@ public class OnnxModel {
}
public void setUri(String uri) {
- Objects.requireNonNull(uri, "uri cannot be null");
- this.path = uri;
- this.pathType = PathType.URI;
+ throw new IllegalArgumentException("URI for ONNX models are not currently supported");
+ }
+
+ public PathType getPathType() {
+ return pathType;
+ }
+
+ public void setDefaultOutput(String onnxName) {
+ Objects.requireNonNull(onnxName, "Name cannot be null");
+ this.defaultOutput = onnxName;
}
public void addInputNameMapping(String onnxName, String vespaName) {
+ addInputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ if (overwrite || ! inputMap.containsKey(onnxName)) {
+ inputMap.put(onnxName, vespaName);
+ }
}
public void addOutputNameMapping(String onnxName, String vespaName) {
+ addOutputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ if (overwrite || ! outputMap.containsKey(onnxName)) {
+ outputMap.put(onnxName, vespaName);
+ }
+ }
+
+ public void addInputType(String onnxName, Onnx.TypeProto type) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(type, "Tensor type cannot be null");
+ inputTypes.put(onnxName, type);
+ }
+
+ public void addOutputType(String onnxName, Onnx.TypeProto type) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(type, "Tensor type cannot be null");
+ outputTypes.put(onnxName, type);
}
/** Initiate sending of this constant to some services over file distribution */
@@ -76,11 +115,16 @@ public class OnnxModel {
public String getName() { return name; }
public String getFileName() { return path; }
+ public Path getFilePath() { return Path.fromString(path); }
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
- public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
- public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
+ public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
+ public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
+
+ public String getDefaultOutput() {
+ return defaultOutput;
+ }
public void validate() {
if (path == null || path.isEmpty())
@@ -90,23 +134,142 @@ public class OnnxModel {
public String toString() {
StringBuilder b = new StringBuilder();
b.append("onnx-model '").append(name)
- .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
- .append("' with ref '").append(fileReference)
- .append("'");
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
return b.toString();
}
- public static class OnnxNameMapping {
- private String onnxName;
- private String vespaName;
+ /**
+ * Return the tensor type for an ONNX model for the given context.
+ * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
+ * type depends on the input types for the given context (rank profile).
+ */
+ public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
+ Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
+ if (onnxOutputType == null) {
+ throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
+ }
+ if (containsSymbolicDimensionSizes(onnxOutputType)) {
+ return getTensorTypeWithSymbolicDimensions(onnxOutputType, context);
+ }
+ return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
+ }
+
+ private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
+ Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context);
+ if (symbolicSizes.isEmpty()) {
+ return TensorType.empty; // Context is probably a rank profile not using this ONNX model
+ }
+ return typeFrom(onnxOutputType, symbolicSizes);
+ }
+
+ private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) {
+ Map<String, Long> symbolicSizes = new HashMap<>();
+ for (String onnxInputName : inputTypes.keySet()) {
+
+ Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
+ if ( ! containsSymbolicDimensionSizes(onnxType)) {
+ continue;
+ }
+
+ Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
+ if (vespaType.isEmpty()) {
+ return Collections.emptyMap();
+ }
+
+ var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
+ var vespaDimensions = vespaType.get().dimensions();
+ if (vespaDimensions.size() != onnxDimensions.size()) {
+ return Collections.emptyMap();
+ }
+
+ for (int i = 0; i < vespaDimensions.size(); ++i) {
+ if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) {
+ continue;
+ }
+ String symbolicName = onnxDimensions.get(i).getDimParam();
+ Long size = vespaDimensions.get(i).size().get();
+ if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
+ throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " +
+ "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
+ }
+ symbolicSizes.put(symbolicName, size);
+ }
+ }
+ return symbolicSizes;
+ }
+
+ private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
+ String source = inputMap.get(onnxInputName);
+ if (source != null) {
+ // Source is either a simple reference (query/attribute/constant)...
+ Optional<Reference> reference = Reference.simple(source);
+ if (reference.isPresent()) {
+ return Optional.of(context.getType(reference.get()));
+ }
+ // ... or a function
+ ExpressionFunction func = context.getFunction(source);
+ if (func != null) {
+ return Optional.of(func.getBody().type(context));
+ }
+ }
+ return Optional.empty(); // if this context does not contain this input
+ }
+
+ private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) {
+ return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue());
+ }
+
+ private static TensorType typeFrom(Onnx.TypeProto type) {
+ return typeFrom(type, null);
+ }
+
+ private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ long onnxDimensionSize = onnxDimension.getDimValue();
+ if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
+ onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
+ }
+ if (onnxDimensionSize == 0 && symbolicSizes != null) {
+ // This is for the case where all symbolic dimensions have
+ // different names, but can be resolved to a single dimension size.
+ Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
+ if (unknownSizes.size() == 1) {
+ onnxDimensionSize = unknownSizes.iterator().next();
+ }
+ }
+ if (onnxDimensionSize <= 0) {
+ throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
+ "ONNX type: " + type + " to Vespa tensor type.");
+ }
+ builder.indexed(dimensionName, onnxDimensionSize);
+ }
+ return builder.build();
+ }
- private OnnxNameMapping(String onnxName, String vespaName) {
- this.onnxName = onnxName;
- this.vespaName = vespaName;
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.FLOAT;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
}
- public String getOnnxName() { return onnxName; }
- public String getVespaName() { return vespaName; }
- public void setVespaName(String vespaName) { this.vespaName = vespaName; }
}
}
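The core of the new OnnxModel code is the conversion from an ONNX shape, whose dimensions are either a fixed dim_value or a symbolic dim_param, to a Vespa tensor type with dimensions named d0, d1, and so on. Below is a standalone sketch of that rule; the shape is modelled as plain Java data and the symbolicSizes map stands in for what resolveSymbolicDimensionSizes derives from the rank profile's inputs (all names and sizes are illustrative):

    import java.util.List;
    import java.util.Map;

    // Sketch of OnnxModel.typeFrom: build a Vespa-style tensor type string from an
    // ONNX shape whose dimensions are either a fixed size or a symbolic name.
    public class OnnxTypeSketch {

        // A dimension is either fixed (size > 0, name == null) or symbolic (name != null).
        record Dim(long size, String name) {}

        static String typeFrom(String valueType, List<Dim> shape, Map<String, Long> symbolicSizes) {
            StringBuilder b = new StringBuilder("tensor<").append(valueType).append(">(");
            for (int i = 0; i < shape.size(); ++i) {
                Dim dim = shape.get(i);
                long size = dim.name() == null ? dim.size() : symbolicSizes.getOrDefault(dim.name(), -1L);
                if (size <= 0)
                    throw new IllegalArgumentException("Unable to determine fixed size for dimension " + i);
                b.append(i > 0 ? "," : "").append("d").append(i).append("[").append(size).append("]");
            }
            return b.append(")").toString();
        }

        public static void main(String[] args) {
            // ONNX output shape [batch, 128], where "batch" is symbolic and resolved to 1 from the inputs.
            List<Dim> shape = List.of(new Dim(0, "batch"), new Dim(128, null));
            Map<String, Long> symbolicSizes = Map.of("batch", 1L);
            System.out.println(typeFrom("float", shape, symbolicSizes));  // tensor<float>(d0[1],d1[128])
        }
    }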
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index d309f48d6df..96c043bdb34 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
@@ -158,6 +159,10 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
+ private Map<String, OnnxModel> onnxModels() {
+ return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
+ }
+
private Stream<ImmutableSDField> allFields() {
if (search == null) return Stream.empty();
if (allFieldsList == null) {
@@ -821,6 +826,20 @@ public class RankProfile implements Cloneable {
}
}
+ // Add output types for ONNX models
+ for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
+ String modelName = entry.getKey();
+ OnnxModel model = entry.getValue();
+ Arguments args = new Arguments(new ReferenceNode(modelName));
+
+ TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
+ context.setType(new Reference("onnxModel", args, null), defaultOutputType);
+
+ for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
+ TensorType type = model.getTensorType(mapping.getKey(), context);
+ context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
+ }
+ }
return context;
}
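The block above leaves the type context with one entry for the model's default output, keyed onnxModel(<model-name>), and one entry per mapped output, keyed onnxModel(<model-name>).<vespa-output-name>. A tiny sketch of the resulting keys, with a hypothetical model name and plain type strings standing in for TensorType:

    import java.util.LinkedHashMap;
    import java.util.Map;

    // Sketch of the keys RankProfile registers for an onnx-model named "my_model"
    // whose default output is mapped to the Vespa name "score".
    public class OnnxTypeRegistrationSketch {
        public static void main(String[] args) {
            Map<String, String> context = new LinkedHashMap<>();
            String model = "my_model";
            String outputType = "tensor<float>(d0[1],d1[1])";

            context.put("onnxModel(" + model + ")", outputType);        // default output
            context.put("onnxModel(" + model + ").score", outputType);  // mapped output "score"

            context.forEach((key, type) -> System.out.println(key + " : " + type));
        }
    }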
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 84442fedc48..22a32c8fd65 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
modelBuilder.name(model.getName());
modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName())));
- model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName())));
+ model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
+ model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 87eaaf0387a..56a5d539906 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) {
- for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) {
- String source = mapping.getVespaName();
+ for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
+ String source = mapping.getValue();
if (functionNames.contains(source)) {
- mapping.setVespaName("rankingExpression(" + source + ")");
+ onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")");
}
}
}
@@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>();
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
- ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch());
+ ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile);
if (referenceNode != replacedNode) {
replacedSummaryFeatures.add(replacedNode);
i.remove();
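The input-mapping loop above only rewrites sources that name a rank-profile function, wrapping them as rankingExpression(<function>) so the backend resolves them; other sources such as query or attribute features are left untouched. A plain sketch of that rewrite with made-up input names:

    import java.util.HashMap;
    import java.util.Map;
    import java.util.Set;

    // Sketch: wrap input sources that refer to rank-profile functions in rankingExpression(...).
    public class OnnxInputRewriteSketch {
        public static void main(String[] args) {
            Set<String> functionNames = Set.of("my_input_function");
            Map<String, String> inputMap = new HashMap<>(Map.of(
                    "input_ids", "my_input_function",   // a function: gets wrapped
                    "attention_mask", "query(mask)"));  // not a function: unchanged

            inputMap.replaceAll((onnxName, source) ->
                    functionNames.contains(source) ? "rankingExpression(" + source + ")" : source);

            inputMap.forEach((onnxName, source) -> System.out.println(onnxName + " -> " + source));
        }
    }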
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index ec517768ea9..d23a8376e7a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if ( ! feature.getName().equals("onnx")) return feature;
+ if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature;
try {
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index e1ad003e5bd..69cdae10e47 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,20 +1,36 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.yahoo.path.Path;
import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.vespa.model.ml.ConvertedModel;
+import com.yahoo.vespa.model.ml.FeatureArguments;
+import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
/**
- * Transforms instances of the onnxModel ranking feature and generates
- * ONNX configuration if necessary.
+ * Transforms ONNX model features of the forms:
+ *
+ * onnxModel(config_name)
+ * onnxModel(config_name).output
+ * onnxModel("path/to/model")
+ * onnxModel("path/to/model").output
+ * onnxModel("path/to/model", "path/to/output")
+ * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
+ *
+ * To the format expected by the backend:
+ *
+ * onnxModel(config_name).output
*
* @author lesters
*/
@@ -33,85 +49,92 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
if (context.rankProfile() == null) return feature;
if (context.rankProfile().getSearch() == null) return feature;
- return transformFeature(feature, context.rankProfile().getSearch());
+ return transformFeature(feature, context.rankProfile());
}
- public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
- if (!feature.getName().equals("onnxModel")) return feature;
+ public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
+ ImmutableSearch search = rankProfile.getSearch();
+ final String featureName = feature.getName();
+ if ( ! featureName.equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
- "onnx-model config or a ONNX file.");
- if (arguments.expressions().size() > 2)
- throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
-
- // Validation that the file actually exists is handled when the file is added to file distribution.
- // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
-
- String modelConfigName;
- OnnxModel onnxModel;
- if (arguments.expressions().get(0) instanceof ReferenceNode) {
- modelConfigName = arguments.expressions().get(0).toString();
- onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
- }
- } else if (arguments.expressions().get(0) instanceof ConstantNode) {
+ throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " +
+ "onnx-model config or an ONNX file.");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
+
+ // Check that the model configuration "onnx-model" exists. If not defined, it should have been added
+ // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find
+ // the actual ONNX file, which can happen if we are restarting or upgrading an application using an
+ // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
+
+ String modelConfigName = getModelConfigName(feature.reference());
+ OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
String path = asString(arguments.expressions().get(0));
- modelConfigName = asValidIdentifier(path);
- onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- search.onnxModels().add(onnxModel);
- }
- } else {
- throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
+ ModelName modelName = new ModelName(null, Path.fromString(path), true);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile);
+ FeatureArguments featureArguments = new FeatureArguments(arguments);
+ return convertedModel.expression(featureArguments, null);
}
- String output = null;
- if (feature.getOutput() != null) {
- output = feature.getOutput();
- if ( ! hasOutputMapping(onnxModel, output)) {
- onnxModel.addOutputNameMapping(output, output);
+ String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
+ String output = getModelOutput(feature.reference(), defaultOutput);
+ if (! onnxModel.getOutputMap().containsValue(output)) {
+ throw new IllegalArgumentException(featureName + " argument '" + output +
+ "' output not found in model '" + onnxModel.getFileName() + "'");
+ }
+ return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
+ }
+
+ public static String getModelConfigName(Reference reference) {
+ if (reference.arguments().size() > 0) {
+ ExpressionNode expr = reference.arguments().expressions().get(0);
+ if (expr instanceof ReferenceNode) { // refers to onnx-model config
+ return expr.toString();
}
- } else if (arguments.expressions().size() > 1) {
- String name = asString(arguments.expressions().get(1));
- output = asValidIdentifier(name);
- if ( ! hasOutputMapping(onnxModel, output)) {
- onnxModel.addOutputNameMapping(name, output);
+ if (expr instanceof ConstantNode) { // refers to a file path
+ return asValidIdentifier(expr);
}
}
+ return null;
+ }
- // Replace feature with name of config
- ExpressionNode argument = new ReferenceNode(modelConfigName);
- return new ReferenceNode("onnxModel", List.of(argument), output);
-
+ public static String getModelOutput(Reference reference, String defaultOutput) {
+ if (reference.output() != null) {
+ return reference.output();
+ } else if (reference.arguments().expressions().size() == 2) {
+ return asValidIdentifier(reference.arguments().expressions().get(1));
+ } else if (reference.arguments().expressions().size() > 2) {
+ return asValidIdentifier(reference.arguments().expressions().get(2));
+ }
+ return defaultOutput;
}
- private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
- return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
+ public static String stripQuotes(String s) {
+ if (isNotQuoteSign(s.codePointAt(0))) return s;
+ if (isNotQuoteSign(s.codePointAt(s.length() - 1)))
+ throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
+ return s.substring(1, s.length()-1);
}
- private static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ public static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
- private static String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
+ private static String asValidIdentifier(ExpressionNode node) {
+ return asValidIdentifier(asString(node));
}
- private static boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
+ private static boolean isNotQuoteSign(int c) {
+ return c != '\'' && c != '"';
}
- private static String asValidIdentifier(String str) {
- return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ public static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
}
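Taken together, the helpers above reduce every accepted onnxModel(...) spelling to a pair of (config name, output): the config name comes from the first argument (a reference is used verbatim, a quoted path becomes a valid identifier), and the output comes from the .output suffix, then the second or third argument, then the model's default. A standalone sketch of that argument handling using plain strings, with illustrative inputs:

    import java.util.List;

    // Sketch of getModelConfigName/getModelOutput: map the supported argument forms
    // onto the canonical onnxModel(<config>).<output> spelling used by the backend.
    public class OnnxModelFeatureSketch {

        static String asValidIdentifier(String s) { return s.replaceAll("[^\\w\\d\\$@_]", "_"); }

        static String stripQuotes(String s) {
            return (s.startsWith("\"") || s.startsWith("'")) ? s.substring(1, s.length() - 1) : s;
        }

        // args are the textual feature arguments; output is the ".output" suffix or null.
        static String canonical(List<String> args, String output, String defaultOutput) {
            String first = args.get(0);
            String config = first.startsWith("\"") || first.startsWith("'")
                            ? asValidIdentifier(stripQuotes(first))  // quoted file path
                            : first;                                 // reference to onnx-model config
            if (output == null)
                output = args.size() == 2 ? asValidIdentifier(stripQuotes(args.get(1)))
                       : args.size() > 2  ? asValidIdentifier(stripQuotes(args.get(2)))
                       : defaultOutput;
            return "onnxModel(" + config + ")." + output;
        }

        public static void main(String[] args) {
            System.out.println(canonical(List.of("my_model"), "score", null));
            // onnxModel(my_model).score
            System.out.println(canonical(List.of("\"files/model.onnx\"", "\"path/to/output\""), null, null));
            // onnxModel(files_model_onnx).path_to_output
            System.out.println(canonical(List.of("\"files/model.onnx\""), null, "out"));
            // onnxModel(files_model_onnx).out
        }
    }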
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
new file mode 100644
index 00000000000..afba88c135d
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -0,0 +1,97 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+
+import java.util.Map;
+
+/**
+ * Processes ONNX ranking features of the form:
+ *
+ * onnx("files/model.onnx", "path/to/output:1")
+ *
+ * And generates an "onnx-model" configuration as if it was defined in the schema:
+ *
+ * onnx-model files_model_onnx {
+ * file: "files/model.onnx"
+ * }
+ *
+ * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be
+ * processed after this.
+ *
+ * @author lesters
+ */
+public class OnnxModelConfigGenerator extends Processor {
+
+ public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
+ super(search, deployLogger, rankProfileRegistry, queryProfiles);
+ }
+
+ @Override
+ public void process(boolean validate, boolean documentsOnly) {
+ if (documentsOnly) return;
+ for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) {
+ if (profile.getFirstPhaseRanking() != null) {
+ process(profile.getFirstPhaseRanking().getRoot());
+ }
+ if (profile.getSecondPhaseRanking() != null) {
+ process(profile.getSecondPhaseRanking().getRoot());
+ }
+ for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
+ process(function.getValue().function().getBody().getRoot());
+ }
+ for (ReferenceNode feature : profile.getSummaryFeatures()) {
+ process(feature);
+ }
+ }
+ }
+
+ private void process(ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ process((ReferenceNode)node);
+ } else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode) node).children()) {
+ process(child);
+ }
+ }
+ }
+
+ private void process(ReferenceNode feature) {
+ if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
+ if (feature.getArguments().size() > 0) {
+ if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
+ ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0);
+ String path = OnnxModelTransformer.stripQuotes(node.sourceString());
+ String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
+
+ // Only add the configuration if the model can actually be found.
+ if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
+ path = ApplicationPackage.MODELS_DIR.append(path).toString();
+ if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
+ return;
+ }
+ }
+
+ OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ }
+ }
+ }
+ }
+
+}
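The generator walks first-phase, second-phase, function and summary-feature expressions and registers an onnx-model entry whenever it meets an onnx()/onnxModel() feature whose first argument is a quoted file path. A standalone sketch of that recursive walk over a toy expression tree (the Node type and the example expression are invented for illustration):

    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    // Sketch of the traversal in OnnxModelConfigGenerator: visit a (toy) expression
    // tree and collect one model-config entry per quoted onnxModel/onnx path argument.
    public class OnnxConfigGeneratorSketch {

        record Node(String name, List<String> arguments, List<Node> children) {}

        static String asValidIdentifier(String s) { return s.replaceAll("[^\\w\\d\\$@_]", "_"); }

        static void process(Node node, Map<String, String> onnxModels) {
            if ((node.name().equals("onnxModel") || node.name().equals("onnx"))
                    && !node.arguments().isEmpty()
                    && node.arguments().get(0).startsWith("\"")) {
                String path = node.arguments().get(0).replace("\"", "");
                onnxModels.putIfAbsent(asValidIdentifier(path), path);  // generated onnx-model config
            }
            for (Node child : node.children())
                process(child, onnxModels);
        }

        public static void main(String[] args) {
            // first-phase expression: sum(onnxModel("files/model.onnx", "score")) as a toy tree
            Node feature = new Node("onnxModel", List.of("\"files/model.onnx\"", "\"score\""), List.of());
            Node firstPhase = new Node("sum", List.of(), List.of(feature));

            Map<String, String> onnxModels = new LinkedHashMap<>();
            process(firstPhase, onnxModels);
            onnxModels.forEach((name, file) ->
                    System.out.println("onnx-model " + name + " { file: " + file + " }"));
            // onnx-model files_model_onnx { file: files/model.onnx }
        }
    }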
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
new file mode 100644
index 00000000000..bead2e7e7c9
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -0,0 +1,154 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.component.Version;
+import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.vespa.defaults.Defaults;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+import onnx.Onnx;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Paths;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Processes every "onnx-model" element in the schema. Parses the model file,
+ * adds missing input and output mappings (assigning default names), and
+ * adds tensor types to all model inputs and outputs.
+ *
+ * Must be processed before RankingExpressionTypeResolver.
+ *
+ * @author lesters
+ */
+public class OnnxModelTypeResolver extends Processor {
+
+ public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
+ super(search, deployLogger, rankProfileRegistry, queryProfiles);
+ }
+
+ @Override
+ public void process(boolean validate, boolean documentsOnly) {
+ if (documentsOnly) return;
+
+ for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) {
+ OnnxModel modelConfig = entry.getValue();
+ try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+
+ // Model inputs - if not defined, assumes a function is provided with a valid name
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
+ String onnxInputName = valueInfo.getName();
+ String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
+ modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
+ modelConfig.addInputType(onnxInputName, valueInfo.getType());
+ }
+
+ // Model outputs
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
+ String onnxOutputName = valueInfo.getName();
+ String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
+ modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
+ modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
+ }
+
+ // Set the first output as default
+ if ( ! model.getGraph().getOutputList().isEmpty()) {
+ modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
+ }
+
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ }
+ }
+ }
+
+ static boolean modelFileExists(String path, ApplicationPackage app) {
+ Path pathInApplicationPackage = Path.fromString(path);
+ if (getFile(pathInApplicationPackage, app).exists()) {
+ return true;
+ }
+ if (getFileReference(pathInApplicationPackage, app).isPresent()) {
+ return true;
+ }
+ return false;
+ }
+
+ private InputStream openModelFile(Path path) throws FileNotFoundException {
+ ApplicationFile file;
+ Optional<FileReference> reference;
+ Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
+
+ if ((file = getFile(path)).exists()) {
+ return file.createInputStream();
+ }
+ if ((file = getFile(modelsPath)).exists()) {
+ return file.createInputStream();
+ }
+ if ((reference = getFileReference(path)).isPresent()) {
+ return openFromFileRepository(path, reference.get());
+ }
+ if ((reference = getFileReference(modelsPath)).isPresent()) {
+ return openFromFileRepository(modelsPath, reference.get());
+ }
+
+ throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
+ "in application package or file repository.");
+ }
+
+ private ApplicationFile getFile(Path path) {
+ return getFile(path, search.applicationPackage());
+ }
+
+ private static ApplicationFile getFile(Path path, ApplicationPackage app) {
+ return app.getFile(path);
+ }
+
+ private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
+ return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
+ }
+
+ public static String getFileRepositoryPath(Path path, String fileReference) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
+ return Paths.get(fileRefDir, fileReference, path.getName()).toString();
+ }
+
+ private Optional<FileReference> getFileReference(Path path) {
+ return getFileReference(path, search.applicationPackage());
+ }
+
+ private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
+ Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
+ if (fileRegistry.isPresent()) {
+ for (FileRegistry.Entry file : fileRegistry.get().export()) {
+ if (file.relativePath.equals(path.toString())) {
+ return Optional.of(file.reference);
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+ private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
+ if (app == null) return Optional.empty();
+ Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo);
+ return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get()));
+ }
+
+}
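With the protobuf-java dependency and the bundled onnx.proto, the resolver can read model metadata directly. A hedged sketch of the same inspection step outside the config model, assuming the generated onnx.Onnx classes from this patch are on the classpath; the model path is hypothetical:

    import onnx.Onnx;

    import java.io.FileInputStream;
    import java.io.InputStream;

    // Sketch: parse an ONNX file the same way OnnxModelTypeResolver does and print
    // each graph input/output with its declared shape (dim_value or symbolic dim_param).
    public class OnnxModelInspectSketch {

        static String describe(Onnx.ValueInfoProto valueInfo) {
            StringBuilder b = new StringBuilder(valueInfo.getName()).append(": [");
            var dims = valueInfo.getType().getTensorType().getShape().getDimList();
            for (int i = 0; i < dims.size(); ++i) {
                b.append(i > 0 ? ", " : "");
                b.append(dims.get(i).hasDimParam() ? dims.get(i).getDimParam()          // symbolic size
                                                   : String.valueOf(dims.get(i).getDimValue()));
            }
            return b.append("]").toString();
        }

        public static void main(String[] args) throws Exception {
            try (InputStream in = new FileInputStream("files/model.onnx")) {  // hypothetical path
                Onnx.ModelProto model = Onnx.ModelProto.parseFrom(in);
                model.getGraph().getInputList().forEach(i -> System.out.println("input  " + describe(i)));
                model.getGraph().getOutputList().forEach(o -> System.out.println("output " + describe(o)));
            }
        }
    }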
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index e8594c2a87f..1a3ef9e54b4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -74,6 +74,8 @@ public class Processing {
ReferenceFieldsProcessor::new,
FastAccessValidator::new,
ReservedFunctionNames::new,
+ OnnxModelConfigGenerator::new,
+ OnnxModelTypeResolver::new,
RankingExpressionTypeResolver::new,
// These should be last:
IndexingValidation::new,
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index c6c7969e466..d5c5183b01f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -1,20 +1,19 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation;
-import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.log.InvalidLogFormatException;
import java.util.logging.Level;
import com.yahoo.log.LogMessage;
import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver;
import com.yahoo.yolean.Exceptions;
import com.yahoo.system.ProcessExecuter;
import com.yahoo.text.StringUtilities;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.collections.Pair;
import com.yahoo.config.ConfigInstance;
-import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.config.search.ImportedFieldsConfig;
import com.yahoo.vespa.config.search.IndexschemaConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -31,7 +30,6 @@ import com.yahoo.vespa.model.search.SearchCluster;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
-import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.logging.Logger;
@@ -152,12 +150,9 @@ public class RankSetupValidator extends Validator {
// Assist verify-ranksetup in finding the actual ONNX model files
Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap();
if (models.values().size() > 0) {
- ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
- String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
List<String> config = new ArrayList<>(models.values().size() * 2);
for (OnnxModel model : models.values()) {
- String modelFilename = Paths.get(model.getFileName()).getFileName().toString();
- String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString();
+ String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference());
config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference()));
config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
}
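For each distributed model the validator now emits two config lines for verify-ranksetup: the file reference and the resolved path inside the file-reference store. A small sketch of the strings produced, with a made-up reference hash and an assumed default file-references directory:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;

    // Sketch of the config lines RankSetupValidator emits per ONNX model:
    // file[N].ref and file[N].path, where the path points into the file-reference store.
    public class VerifyRanksetupConfigSketch {
        public static void main(String[] args) {
            // model file name -> file reference (hypothetical values)
            Map<String, String> models = Map.of("model.onnx", "abcd1234");
            String fileRefDir = "/opt/vespa/var/db/vespa/filedistribution";  // assumed default location

            List<String> config = new ArrayList<>();
            for (Map.Entry<String, String> model : models.entrySet()) {
                String modelPath = fileRefDir + "/" + model.getValue() + "/" + model.getKey();
                config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getValue()));
                config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
            }
            config.forEach(System.out::println);
            // file[0].ref "abcd1234"
            // file[0].path "/opt/vespa/var/db/vespa/filedistribution/abcd1234/model.onnx"
        }
    }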
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 943fcbf6c1d..5ee6ed02e61 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -150,7 +150,7 @@ public class ConvertedModel {
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
ExpressionFunction expression = selectExpression(arguments);
- if (sourceModel.isPresent()) // we should verify
+ if (sourceModel.isPresent() && context != null) // we should verify
verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
return expression.getBody().getRoot();
}
diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto
new file mode 100644
index 00000000000..dc6542867e0
--- /dev/null
+++ b/config-model/src/main/protobuf/onnx.proto
@@ -0,0 +1,464 @@
+//
+// WARNING: This file is automatically generated! Please edit onnx.in.proto.
+//
+
+
+// Copyright (c) Facebook Inc. and Microsoft Corporation.
+// Licensed under the MIT license.
+
+syntax = "proto2";
+
+package onnx;
+
+// Overview
+//
+// ONNX is an open specification that is comprised of the following components:
+//
+// 1) A definition of an extensible computation graph model.
+// 2) Definitions of standard data types.
+// 3) Definitions of built-in operators.
+//
+// This document describes the syntax of models and their computation graphs,
+// as well as the standard data types. Together, they are referred to as the ONNX
+// Intermediate Representation, or 'IR' for short.
+//
+// The normative semantic specification of the ONNX IR is found in docs/IR.md.
+// Definitions of the built-in neural network operators may be found in docs/Operators.md.
+
+// Notes
+//
+// Release
+//
+// We are still in the very early stage of defining ONNX. The current
+// version of ONNX is a starting point. While we are actively working
+// towards a complete spec, we would like to get the community involved
+// by sharing our working version of ONNX.
+//
+// Protobuf compatibility
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+// that is compatible with both protobuf v2 and v3. This means that we do not use any
+// protobuf features that are only available in one of the two versions.
+//
+// Here are the most notable contortions we have to carry out to work around
+// these limitations:
+//
+// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
+// of key-value pairs, where order does not matter and duplicates
+// are not allowed.
+
+
+// Versioning
+//
+// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
+//
+// To be compatible with both proto2 and proto3, we will use a version number
+// that is not defined by the default value but an explicit enum number.
+enum Version {
+ // proto3 requires the first enum value to be zero.
+ // We add this just to appease the compiler.
+ _START_VERSION = 0;
+ // The version field is always serialized and we will use it to store the
+ // version that the graph is generated from. This helps us set up version
+ // control. We should use version as
+ // xx(major) - xx(minor) - xxxx(bugfix)
+ // and we are starting with 0x00000001 (0.0.1), which was the
+ // version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x00000001;
+
+ // IR_VERSION 0.0.2 published on Oct 30, 2017
+ // - Added type discriminator to AttributeProto to support proto3 users
+ IR_VERSION_2017_10_30 = 0x00000002;
+
+ // IR VERSION 0.0.3 published on Nov 3, 2017
+ // - For operator versioning:
+ // - Added new message OperatorSetIdProto
+ // - Added opset_import in ModelProto
+ // - For vendor extensions, added domain in NodeProto
+ IR_VERSION = 0x00000003;
+}
+
+// Attributes
+//
+// A named attribute containing either singular float, integer, string, graph,
+// and tensor values, or repeated float, integer, string, graph, and tensor values.
+// An AttributeProto MUST contain the name field, and *only one* of the
+// following content fields, effectively enforcing a C/C++ union equivalent.
+message AttributeProto {
+
+ // Note: this enum is structurally identical to the OpSchema::AttrType
+ // enum defined in schema.h. If you rev one, you likely need to rev the other.
+ enum AttributeType {
+ UNDEFINED = 0;
+ FLOAT = 1;
+ INT = 2;
+ STRING = 3;
+ TENSOR = 4;
+ GRAPH = 5;
+
+ FLOATS = 6;
+ INTS = 7;
+ STRINGS = 8;
+ TENSORS = 9;
+ GRAPHS = 10;
+ }
+
+ // The name field MUST be present for this version of the IR.
+ optional string name = 1; // namespace Attribute
+
+ // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
+ // In this case, this AttributeProto does not contain data, and it's a reference of attribute
+ // in parent scope.
+ // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
+ optional string ref_attr_name = 21;
+
+ // A human-readable documentation for this attribute. Markdown is allowed.
+ optional string doc_string = 13;
+
+ // The type field MUST be present for this version of the IR.
+ // For 0.0.1 versions of the IR, this field was not defined, and
+ // implementations needed to use has_field heuristics to determine
+ // which value field was in use. For IR_VERSION 0.0.2 or later, this
+ // field MUST be set and match the f|i|s|t|... field in use. This
+ // change was made to accommodate proto3 implementations.
+ optional AttributeType type = 20; // discriminator that indicates which field below is in use
+
+ // Exactly ONE of the following fields must be present for this version of the IR
+ optional float f = 2; // float
+ optional int64 i = 3; // int
+ optional bytes s = 4; // UTF-8 string
+ optional TensorProto t = 5; // tensor value
+ optional GraphProto g = 6; // graph
+ // Do not use field below, it's deprecated.
+ // optional ValueProto v = 12; // value - subsumes everything but graph
+
+ repeated float floats = 7; // list of floats
+ repeated int64 ints = 8; // list of ints
+ repeated bytes strings = 9; // list of UTF-8 strings
+ repeated TensorProto tensors = 10; // list of tensors
+ repeated GraphProto graphs = 11; // list of graph
+}
+
+// Defines information on value, including the name, the type, and
+// the shape of the value.
+message ValueInfoProto {
+ // This field MUST be present in this version of the IR.
+ optional string name = 1; // namespace Value
+ // This field MUST be present in this version of the IR.
+ optional TypeProto type = 2;
+ // A human-readable documentation for this value. Markdown is allowed.
+ optional string doc_string = 3;
+}
+
+// Nodes
+//
+// Computation graphs are made up of a DAG of nodes, which represent what is
+// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+//
+// For example, it can be a node of type "Conv" that takes in an image, a filter
+// tensor and a bias tensor, and produces the convolved output.
+message NodeProto {
+ repeated string input = 1; // namespace Value
+ repeated string output = 2; // namespace Value
+
+ // An optional identifier for this node in a graph.
+ // This field MAY be absent in this version of the IR.
+ optional string name = 3; // namespace Node
+
+ // The symbolic identifier of the Operator to execute.
+ optional string op_type = 4; // namespace Operator
+ // The domain of the OperatorSet that specifies the operator named by op_type.
+ optional string domain = 7; // namespace Domain
+
+ // Additional named attributes.
+ repeated AttributeProto attribute = 5;
+
+ // A human-readable documentation for this node. Markdown is allowed.
+ optional string doc_string = 6;
+}
+
+// Models
+//
+// ModelProto is a top-level file/container format for bundling a ML model and
+// associating its computation graph with metadata.
+//
+// The semantics of the model are described by the associated GraphProto.
+message ModelProto {
+ // The version of the IR this model targets. See Version enum above.
+ // This field MUST be present.
+ optional int64 ir_version = 1;
+
+ // The OperatorSets this model relies on.
+ // All ModelProtos MUST have at least one entry that
+ // specifies which version of the ONNX OperatorSet is
+ // being imported.
+ //
+ // All nodes in the ModelProto's graph will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets.
+ repeated OperatorSetIdProto opset_import = 8;
+
+ // The name of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_name = 2;
+
+ // The version of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_version = 3;
+
+ // Domain name of the model.
+ // We use reverse domain names as name space indicators. For example:
+ // `com.facebook.fair` or `com.microsoft.cognitiveservices`
+ //
+ // Together with `model_version` and GraphProto.name, this forms the unique identity of
+ // the graph.
+ optional string domain = 4;
+
+ // The version of the graph encoded. See Version enum below.
+ optional int64 model_version = 5;
+
+ // A human-readable documentation for this model. Markdown is allowed.
+ optional string doc_string = 6;
+
+ // The parameterized graph that is evaluated to execute the model.
+ optional GraphProto graph = 7;
+
+ // Named metadata values; keys should be distinct.
+ repeated StringStringEntryProto metadata_props = 14;
+};
+
+// StringStringEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message StringStringEntryProto {
+ optional string key = 1;
+ optional string value= 2;
+};
+
+// Graphs
+//
+// A graph defines the computational logic of a model and is comprised of a parameterized
+// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+// This is the equivalent of the "network" or "graph" in many deep learning
+// frameworks.
+message GraphProto {
+ // The nodes in the graph, sorted topologically.
+ repeated NodeProto node = 1;
+
+ // The name of the graph.
+ optional string name = 2; // namespace Graph
+
+ // A list of named tensor values, used to specify constant inputs of the graph.
+ // Each TensorProto entry must have a distinct name (within the list) that
+ // also appears in the input list.
+ repeated TensorProto initializer = 5;
+
+ // A human-readable documentation for this graph. Markdown is allowed.
+ optional string doc_string = 10;
+
+ // The inputs and outputs of the graph.
+ repeated ValueInfoProto input = 11;
+ repeated ValueInfoProto output = 12;
+
+ // Information for the values in the graph. The ValueInfoProto.name's
+ // must be distinct. It is optional for a value to appear in value_info list.
+ repeated ValueInfoProto value_info = 13;
+
+ // DO NOT USE the following fields, they were deprecated from earlier versions.
+ // repeated string input = 3;
+ // repeated string output = 4;
+ // optional int64 ir_version = 6;
+ // optional int64 producer_version = 7;
+ // optional string producer_tag = 8;
+ // optional string domain = 9;
+}
+
+// Tensors
+//
+// A serialized tensor value.
+message TensorProto {
+ enum DataType {
+ UNDEFINED = 0;
+ // Basic types.
+ FLOAT = 1; // float
+ UINT8 = 2; // uint8_t
+ INT8 = 3; // int8_t
+ UINT16 = 4; // uint16_t
+ INT16 = 5; // int16_t
+ INT32 = 6; // int32_t
+ INT64 = 7; // int64_t
+ STRING = 8; // string
+ BOOL = 9; // bool
+
+ // Advanced types
+ FLOAT16 = 10;
+ DOUBLE = 11;
+ UINT32 = 12;
+ UINT64 = 13;
+ COMPLEX64 = 14; // complex with float32 real and imaginary components
+ COMPLEX128 = 15; // complex with float64 real and imaginary components
+ // Future extensions go here.
+ }
+
+ // The shape of the tensor.
+ repeated int64 dims = 1;
+
+ // The data type of the tensor.
+ optional DataType data_type = 2;
+
+ // For very large tensors, we may want to store them in chunks, in which
+ // case the following fields will specify the segment that is stored in
+ // the current TensorProto.
+ message Segment {
+ optional int64 begin = 1;
+ optional int64 end = 2;
+ }
+ optional Segment segment = 3;
+
+ // Tensor content must be organized in row-major order.
+ //
+ // Depending on the data_type field, exactly one of the fields below with
+ // name ending in _data is used to store the elements of the tensor.
+
+ // For float and complex64 values
+ // Complex64 tensors are encoded as a single array of floats,
+ // with the real components appearing in odd numbered positions,
+  // and the corresponding imaginary component appearing in the
+  // subsequent even numbered position (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+  // is encoded as [1.0, 2.0, 3.0, 4.0]).
+ // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+ repeated float float_data = 4 [packed = true];
+
+ // For int32, uint8, int8, uint16, int16, bool, and float16 values
+  // float16 values must be bit-wise converted to a uint16_t prior
+  // to writing to the buffer.
+  // When this field is present, the data_type field MUST be
+  // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
+ repeated int32 int32_data = 5 [packed = true];
+
+ // For strings.
+ // Each element of string_data is a UTF-8 encoded Unicode
+  // string. No trailing null, no leading BOM. The protobuf "string"
+  // scalar type is not used, to match ML community conventions.
+ // When this field is present, the data_type field MUST be STRING
+ repeated bytes string_data = 6;
+
+ // For int64.
+ // When this field is present, the data_type field MUST be INT64
+ repeated int64 int64_data = 7 [packed = true];
+
+ // Optionally, a name for the tensor.
+ optional string name = 8; // namespace Value
+
+ // A human-readable documentation for this tensor. Markdown is allowed.
+ optional string doc_string = 12;
+
+ // Serializations can either use one of the fields above, or use this
+ // raw bytes field. The only exception is the string case, where one is
+ // required to store the content in the repeated bytes string_data field.
+ //
+  // When this raw_data field is used to store the tensor value, elements MUST
+  // be stored as fixed-width values in little-endian byte order.
+ // Floating-point data types MUST be stored in IEEE 754 format.
+ // Complex64 elements must be written as two consecutive FLOAT values, real component first.
+ // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+ // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
+ //
+  // Note: the advantage of the type-specific fields over the raw_data field is
+  // that in some cases (e.g. int data), protobuf does a better packing via
+  // variable-length storage, which may lead to a smaller binary footprint.
+ // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+ optional bytes raw_data = 9;
+
+  // For double and complex128 values
+  // Complex128 tensors are encoded as a single array of doubles,
+  // with the real components appearing in odd numbered positions,
+  // and the corresponding imaginary component appearing in the
+  // subsequent even numbered position (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+  // is encoded as [1.0, 2.0, 3.0, 4.0]).
+ // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+ repeated double double_data = 10 [packed = true];
+
+ // For uint64 and uint32 values
+ // When this field is present, the data_type field MUST be
+ // UINT32 or UINT64
+ repeated uint64 uint64_data = 11 [packed = true];
+}
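[Editorial note, not part of the patch] The TensorProto comments above describe two ways of storing elements: a type-specific field such as float_data, or raw_data as fixed-width little-endian bytes. A minimal Python sketch of both, assuming the onnx and numpy packages are available:

    # Sketch: one 2-element float tensor stored via the typed float_data field
    # and via raw_data (fixed-width little-endian bytes), per the comments above.
    import numpy as np
    from onnx import TensorProto, helper, numpy_helper

    typed = helper.make_tensor("t", TensorProto.FLOAT, dims=[2], vals=[1.0, 2.0])
    raw = numpy_helper.from_array(np.array([1.0, 2.0], dtype=np.float32), name="t")
    print(list(typed.float_data))   # [1.0, 2.0]
    print(raw.raw_data)             # b'\x00\x00\x80?\x00\x00\x00@'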
+
+// Defines a tensor shape. A dimension can be either an integer value
+// or a symbolic variable. A symbolic variable represents an unknown
+// dimension.
+message TensorShapeProto {
+ message Dimension {
+ oneof value {
+ int64 dim_value = 1;
+ string dim_param = 2; // namespace Shape
+ };
+ // Standard denotation can optionally be used to denote tensor
+ // dimensions with standard semantic descriptions to ensure
+ // that operations are applied to the correct axis of a tensor.
+ optional string denotation = 3;
+ };
+ repeated Dimension dim = 1;
+}
+
+// A set of pre-defined constants to be used as values for
+// the standard denotation field in TensorShapeProto.Dimension
+// for semantic description of the tensor dimension.
+message DenotationConstProto {
+ // Describe a batch number dimension.
+ optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
+ // Describe a channel dimension.
+ optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
+ // Describe a time dimension.
+ optional string DATA_TIME = 3 [default = "DATA_TIME"];
+ // Describe a feature dimension. This is typically a feature
+ // dimension in RNN and/or spatial dimension in CNN.
+ optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
+ // Describe a filter in-channel dimension. This is the dimension
+ // that is identical (in size) to the channel dimension of the input
+ // image feature maps.
+ optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
+  // Describe a filter out-channel dimension. This is the dimension
+  // that is identical (in size) to the channel dimension of the output
+ // image feature maps.
+ optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
+ // Describe a filter spatial dimension.
+ optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
+}
+
+// Types
+//
+// The standard ONNX data types.
+message TypeProto {
+
+ message Tensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST be present for this version of the IR.
+ optional TensorProto.DataType elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+  oneof value {
+    // The type of a tensor.
+    Tensor tensor_type = 1;
+  }
+}
+
+// Operator Sets
+//
+// OperatorSets are uniquely identified by a (domain, opset_version) pair.
+message OperatorSetIdProto {
+ // The domain of the operator set being identified.
+ // The empty string ("") or absence of this field implies the operator
+ // set that is defined as part of the ONNX specification.
+ // This field MUST be present in this version of the IR when referring to any other operator set.
+ optional string domain = 1;
+
+ // The version of the operator set being identified.
+ // This field MUST be present in this version of the IR.
+ optional int64 version = 2;
+} \ No newline at end of file
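[Editorial note, not part of the patch] The messages above are what the new type-resolving code walks to find input and output names, element types and shapes. For reference, a rough Python sketch of the same traversal over ModelProto/GraphProto/ValueInfoProto/TensorShapeProto, assuming the onnx package and one of the test models added further down:

    # Sketch: list input/output names, element types and (possibly symbolic) shapes.
    import onnx

    model = onnx.load("model.onnx")                      # ModelProto
    for value_info in list(model.graph.input) + list(model.graph.output):
        tensor_type = value_info.type.tensor_type        # TypeProto.Tensor
        dims = [d.dim_param if d.HasField("dim_param") else d.dim_value
                for d in tensor_type.shape.dim]          # TensorShapeProto.Dimension
        print(value_info.name, tensor_type.elem_type, dims)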
diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
index e9575af6010..cc73f2daff5 100644
--- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
+++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
@@ -18,7 +18,7 @@ search test {
}
function mnist_softmax_onnx() {
- expression: onnx("mnist_softmax")
+ expression: onnx_vespa("mnist_softmax")
}
function my_xgboost() {
diff --git a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py
new file mode 100755
index 00000000000..55df3a557e9
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py
@@ -0,0 +1,12 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "sequence"])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "sequence"])
+
+nodes = [helper.make_node('Identity', ['input'], ['output'])]
+graph_def = helper.make_graph(nodes, 'simple_scoring', [INPUT], [OUTPUT])
+model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py')
+onnx.save(model_def, 'dynamic_model.onnx')
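[Editorial note, not part of the patch] The model produced here declares symbolic "batch" and "sequence" dimensions, so it accepts inputs of any size along those axes. A quick way to see that, assuming the onnxruntime package (which this patch does not use):

    # Sketch: the Identity graph above runs for any ("batch", "sequence") shape.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession("dynamic_model.onnx")
    for shape in [(1, 2), (1, 3)]:
        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
        (y,) = session.run(None, {"input": x})
        print(shape, y.shape)    # the output shape follows the input shape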
diff --git a/config-model/src/test/integration/onnx-model/files/create_model.py b/config-model/src/test/integration/onnx-model/files/create_model.py
new file mode 100755
index 00000000000..10ff92c2eda
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/create_model.py
@@ -0,0 +1,37 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('first_input', TensorProto.FLOAT, [2])
+INPUT_2 = helper.make_tensor_value_info('second/input:0', TensorProto.FLOAT, [2])
+INPUT_3 = helper.make_tensor_value_info('third_input', TensorProto.FLOAT, [2])
+OUTPUT_1 = helper.make_tensor_value_info('path/to/output:0', TensorProto.FLOAT, [2])
+OUTPUT_2 = helper.make_tensor_value_info('path/to/output:1', TensorProto.FLOAT, [2])
+OUTPUT_3 = helper.make_tensor_value_info('path/to/output:2', TensorProto.FLOAT, [2])
+
+nodes = [
+ helper.make_node(
+ 'Add',
+ ['first_input', 'second/input:0'],
+ ['path/to/output:0'],
+ ),
+ helper.make_node(
+ 'Add',
+ ['third_input', 'second/input:0'],
+ ['path/to/output:1']
+ ),
+ helper.make_node(
+ 'Add',
+ ['path/to/output:0', 'path/to/output:1'],
+ ['path/to/output:2']
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'simple_scoring',
+ [INPUT_1, INPUT_2, INPUT_3],
+ [OUTPUT_1, OUTPUT_2, OUTPUT_3]
+)
+model_def = helper.make_model(graph_def, producer_name='create_model.py')
+onnx.save(model_def, 'model.onnx')
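[Editorial note, not part of the patch] The model.onnx and summary_model.onnx files below are this graph serialized, so they show up as unreadable binary in the diff. If needed, the graph can be inspected from the saved file; a sketch assuming the onnx package:

    # Sketch: validate the saved model and print a readable view of the graph.
    import onnx
    from onnx import checker, helper

    model = onnx.load("model.onnx")
    checker.check_model(model)                   # structural validation
    print(helper.printable_graph(model.graph))   # three Add nodes, three inputs, three outputs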
diff --git a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx
new file mode 100644
index 00000000000..6bbdad2d76e
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx
@@ -0,0 +1,13 @@
+create_dynamic_model.py:x
+
+inputoutput"Identitysimple_scoringZ$
+input
+
+batch
+
+sequenceb%
+output
+
+batch
+
+sequenceB \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/files/model.onnx b/config-model/src/test/integration/onnx-model/files/model.onnx
new file mode 100644
index 00000000000..f3898205c6a
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/model.onnx
@@ -0,0 +1,34 @@
+create_model.py:í
+4
+ first_input
+second/input:0path/to/output:0"Add
+4
+ third_input
+second/input:0path/to/output:1"Add
+;
+path/to/output:0
+path/to/output:1path/to/output:2"Addsimple_scoringZ
+ first_input
+
+
+Z
+second/input:0
+
+
+Z
+ third_input
+
+
+b
+path/to/output:0
+
+
+b
+path/to/output:1
+
+
+b
+path/to/output:2
+
+
+B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/files/summary_model.onnx b/config-model/src/test/integration/onnx-model/files/summary_model.onnx
new file mode 100644
index 00000000000..f3898205c6a
--- /dev/null
+++ b/config-model/src/test/integration/onnx-model/files/summary_model.onnx
@@ -0,0 +1,34 @@
+create_model.py:í
+4
+ first_input
+second/input:0path/to/output:0"Add
+4
+ third_input
+second/input:0path/to/output:1"Add
+;
+path/to/output:0
+path/to/output:1path/to/output:2"Addsimple_scoringZ
+ first_input
+
+
+Z
+second/input:0
+
+
+Z
+ third_input
+
+
+b
+path/to/output:0
+
+
+b
+path/to/output:1
+
+
+b
+path/to/output:2
+
+
+B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
index 0f0fa694e6f..6e9ba356293 100644
--- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
+++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
@@ -14,7 +14,7 @@ search test {
}
onnx-model my_model {
- file: files/ranking_model.onnx
+ file: files/model.onnx
input first_input: attribute(document_field)
input "second/input:0": constant(my_constant)
input "third_input": my_function
@@ -22,19 +22,25 @@ search test {
}
onnx-model another_model {
- file: files/ranking_model.onnx
+ file: files/model.onnx
input first_input: attribute(document_field)
input "second/input:0": constant(my_constant)
input "third_input": another_function
output "path/to/output:2": out
}
+ onnx-model dynamic_model {
+ file: files/dynamic_model.onnx
+ input input: my_function
+ output output: my_output
+ }
+
rank-profile test_model_config {
function my_function() {
expression: tensor(d0[2])(1)
}
first-phase {
- expression: onnxModel(my_model).out
+ expression: onnxModel(my_model).out{d0:1}
}
}
@@ -49,7 +55,7 @@ search test {
expression: my_function()
}
first-phase {
- expression: onnxModel("files/ranking_model.onnx", "path/to/output:1")
+ expression: onnxModel("files/model.onnx", "path/to/output:1"){d0:1}
}
}
@@ -62,9 +68,29 @@ search test {
}
summary-features {
onnxModel(another_model).out
- onnxModel("files/ranking_model.onnx", "path/to/output:2")
+ onnxModel("files/summary_model.onnx", "path/to/output:2")
+ }
+ }
+
+ rank-profile test_dynamic_model {
+ function my_function() {
+ expression: tensor(d0[1],d1[2])(d1)
+ }
+ first-phase {
+ expression: onnxModel(dynamic_model){d0:0,d1:1}
}
+ }
+ rank-profile test_dynamic_model_2 {
+ function my_function_2() {
+ expression: tensor(d0[1],d1[3])(d1)
+ }
+ function my_function() {
+ expression: my_function_2()
+ }
+ first-phase {
+ expression: onnxModel(dynamic_model){d0:0,d1:2}
+ }
}
}
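[Editorial note, not part of the patch] The new test_dynamic_model profile feeds my_function, tensor(d0[1],d1[2])(d1), through the Identity model and then subscripts the result with {d0:0,d1:1}. A rough numpy illustration of what that selects (not Vespa code; just to make the expected value explicit):

    # Rough illustration: my_function yields [[0., 1.]], the Identity model
    # returns it unchanged, and the {d0:0,d1:1} subscript picks out 1.0.
    import numpy as np

    my_function = np.fromfunction(lambda d0, d1: d1, (1, 2))   # [[0., 1.]]
    model_output = my_function                                 # Identity model
    print(model_output[0, 1])                                  # 1.0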
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index d9b0c70dfdd..5060aafb55f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -25,43 +25,61 @@ public class RankingExpressionWithOnnxModelTestCase {
OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
((OnnxModelsConfig.Producer) db).getConfig(builder);
OnnxModelsConfig config = new OnnxModelsConfig(builder);
- assertEquals(3, config.model().size());
+ assertEquals(5, config.model().size());
- assertEquals("my_model", config.model(1).name());
+ assertEquals("my_model", config.model(0).name());
+ assertEquals(3, config.model(0).input().size());
+ assertEquals("second/input:0", config.model(0).input(0).name());
+ assertEquals("constant(my_constant)", config.model(0).input(0).source());
+ assertEquals("first_input", config.model(0).input(1).name());
+ assertEquals("attribute(document_field)", config.model(0).input(1).source());
+ assertEquals("third_input", config.model(0).input(2).name());
+ assertEquals("rankingExpression(my_function)", config.model(0).input(2).source());
+ assertEquals(3, config.model(0).output().size());
+ assertEquals("path/to/output:0", config.model(0).output(0).name());
+ assertEquals("out", config.model(0).output(0).as());
+ assertEquals("path/to/output:1", config.model(0).output(1).name());
+ assertEquals("path_to_output_1", config.model(0).output(1).as());
+ assertEquals("path/to/output:2", config.model(0).output(2).name());
+ assertEquals("path_to_output_2", config.model(0).output(2).as());
+
+ assertEquals("files_model_onnx", config.model(1).name());
assertEquals(3, config.model(1).input().size());
- assertEquals("first_input", config.model(1).input(0).name());
- assertEquals("attribute(document_field)", config.model(1).input(0).source());
- assertEquals("second/input:0", config.model(1).input(1).name());
- assertEquals("constant(my_constant)", config.model(1).input(1).source());
- assertEquals("third_input", config.model(1).input(2).name());
- assertEquals("rankingExpression(my_function)", config.model(1).input(2).source());
- assertEquals(1, config.model(1).output().size());
+ assertEquals(3, config.model(1).output().size());
assertEquals("path/to/output:0", config.model(1).output(0).name());
- assertEquals("out", config.model(1).output(0).as());
-
- assertEquals("files_ranking_model_onnx", config.model(0).name());
- assertEquals(0, config.model(0).input().size());
- assertEquals(2, config.model(0).output().size());
- assertEquals("path/to/output:1", config.model(0).output(0).name());
- assertEquals("path_to_output_1", config.model(0).output(0).as());
- assertEquals("path/to/output:2", config.model(0).output(1).name());
- assertEquals("path_to_output_2", config.model(0).output(1).as());
+ assertEquals("path_to_output_0", config.model(1).output(0).as());
+ assertEquals("path/to/output:1", config.model(1).output(1).name());
+ assertEquals("path_to_output_1", config.model(1).output(1).as());
+ assertEquals("path/to/output:2", config.model(1).output(2).name());
+ assertEquals("path_to_output_2", config.model(1).output(2).as());
+ assertEquals("files_model_onnx", config.model(1).name());
assertEquals("another_model", config.model(2).name());
assertEquals("third_input", config.model(2).input(2).name());
assertEquals("rankingExpression(another_function)", config.model(2).input(2).source());
+
+ assertEquals("files_summary_model_onnx", config.model(3).name());
+ assertEquals(3, config.model(3).input().size());
+ assertEquals(3, config.model(3).output().size());
+
+ assertEquals("dynamic_model", config.model(4).name());
+ assertEquals(1, config.model(4).input().size());
+ assertEquals(1, config.model(4).output().size());
+ assertEquals("rankingExpression(my_function)", config.model(4).input(0).source());
}
private void assertTransformedFeature(DocumentDatabase db) {
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);
- assertEquals(5, config.rankprofile().size());
+ assertEquals(7, config.rankprofile().size());
assertEquals("test_model_config", config.rankprofile(2).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name());
assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name());
- assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value());
+ assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name());
+ assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value());
assertEquals("test_generated_model_config", config.rankprofile(3).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name());
@@ -69,16 +87,28 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name());
assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name());
assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name());
- assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value());
+ assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name());
+ assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value());
assertEquals("test_summary_features", config.rankprofile(4).name());
assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name());
assertEquals("1", config.rankprofile(4).fef().property(3).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name());
- assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value());
+ assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name());
- assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value());
+ assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value());
+
+ assertEquals("test_dynamic_model", config.rankprofile(5).name());
+ assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name());
+ assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value());
+
+ assertEquals("test_dynamic_model_2", config.rankprofile(6).name());
+ assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name());
+ assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value());
+
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 6bf69907609..40bf970a313 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -70,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -87,7 +87,7 @@ public class RankingExpressionWithOnnxTestCase {
queryProfile,
queryProfileType);
RankProfileSearchFixture search = fixtureWith("query(mytensor)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
null,
null,
"Placeholder",
@@ -99,7 +99,7 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithDocumentFeature() {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
null,
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
@@ -117,7 +117,7 @@ public class RankingExpressionWithOnnxTestCase {
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType);
RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
@@ -129,21 +129,21 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testNestedOnnxReference() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "5 + sum(onnx('mnist_softmax.onnx'))");
+ "5 + sum(onnx_vespa('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutput() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'layer_add')");
+ "onnx_vespa('mnist_softmax.onnx', 'layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutputAndSignature() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'default.layer_add')");
+ "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -155,7 +155,7 @@ public class RankingExpressionWithOnnxTestCase {
new QueryProfileRegistry(),
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: onnx('mnist_softmax.onnx')" +
+ " expression: onnx_vespa('mnist_softmax.onnx')" +
" }\n" +
" }");
search.compileRankProfile("my_profile", applicationDir.append("models"));
@@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase {
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx('mnist_softmax.onnx'): " +
+ "onnx_vespa('mnist_softmax.onnx'): " +
"Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
@@ -175,13 +175,13 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithWrongFunctionType() {
try {
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)",
- "onnx('mnist_softmax.onnx')");
+ "onnx_vespa('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx('mnist_softmax.onnx'): " +
+ "onnx_vespa('mnist_softmax.onnx'): " +
"Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " +
"but this function returns tensor(d0[1],d5[10])",
Exceptions.toMessageString(expected));
@@ -192,13 +192,13 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceSpecifyingNonExistingOutput() {
try {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'y')");
+ "onnx_vespa('mnist_softmax.onnx', 'y')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx('mnist_softmax.onnx','y'): " +
+ "onnx_vespa('mnist_softmax.onnx','y'): " +
"No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
Exceptions.toMessageString(expected));
}
@@ -207,7 +207,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testImportingFromStoredExpressions() throws IOException {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')");
+ "onnx_vespa('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
@@ -218,7 +218,7 @@ public class RankingExpressionWithOnnxTestCase {
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')",
+ "onnx_vespa('mnist_softmax.onnx')",
null,
null,
"Placeholder",
@@ -243,7 +243,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx('mnist_softmax.onnx')" +
+ " expression: onnx_vespa('mnist_softmax.onnx')" +
" }\n" +
" }" +
" rank-profile my_profile_child inherits my_profile {\n" +
@@ -288,7 +288,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d0[3])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx('" + name + ".onnx')" +
+ " expression: onnx_vespa('" + name + ".onnx')" +
" }\n" +
" }";
final String functionName = "imported_ml_function_" + name + "_exp_output";
@@ -310,7 +310,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d0[3])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx('" + name + ".onnx')" +
+ " expression: onnx_vespa('" + name + ".onnx')" +
" }\n" +
" }" +
" rank-profile my_profile_child inherits my_profile {\n" +
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index 4beaf6086a6..181ef6dffbd 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -221,5 +221,10 @@
<artifactId>jdisc_http_service</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
</dependencies>
</project>