summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2020-10-24 11:32:23 +0200
committerGitHub <noreply@github.com>2020-10-24 11:32:23 +0200
commit899f7210569b4f43c1531a4f4c12507b41a7f4f7 (patch)
tree4037006eb22b357e4d3e25d432bba1b062e2fd35
parent75f1af6d8f11883d2f5359efb0619896bb95cbb1 (diff)
parent882d574ab53e8d10a2a8765a64487c20661dc63f (diff)
Merge pull request #15025 from vespa-engine/revert-15010-lesters/resolve-onnx-model-types
Revert "Add type resolving for ONNX models"
-rw-r--r--config-model/pom.xml9
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java27
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java219
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java19
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java141
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java97
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java154
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java2
-rw-r--r--config-model/src/main/protobuf/onnx.proto464
-rw-r--r--config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd2
-rwxr-xr-xconfig-model/src/test/integration/onnx-model/files/create_dynamic_model.py12
-rwxr-xr-xconfig-model/src/test/integration/onnx-model/files/create_model.py37
-rw-r--r--config-model/src/test/integration/onnx-model/files/dynamic_model.onnx13
-rw-r--r--config-model/src/test/integration/onnx-model/files/model.onnx34
-rw-r--r--config-model/src/test/integration/onnx-model/files/summary_model.onnx34
-rw-r--r--config-model/src/test/integration/onnx-model/searchdefinitions/test.sd36
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java76
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java36
-rw-r--r--fat-model-dependencies/pom.xml5
24 files changed, 149 insertions, 1293 deletions
diff --git a/config-model/pom.xml b/config-model/pom.xml
index c0751431d03..95e79fd09fb 100644
--- a/config-model/pom.xml
+++ b/config-model/pom.xml
@@ -46,11 +46,6 @@
<scope>test</scope>
</dependency>
<dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- <version>${protobuf.version}</version>
- </dependency>
- <dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<scope>provided</scope>
@@ -503,10 +498,6 @@
<updateReleaseInfo>true</updateReleaseInfo>
</configuration>
</plugin>
- <plugin>
- <groupId>com.github.os72</groupId>
- <artifactId>protoc-jar-maven-plugin</artifactId>
- </plugin>
</plugins>
</build>
</project>
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index b153ff62e7d..4011ce43841 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -159,12 +158,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
}
- // A reference to an ONNX model?
- Optional<TensorType> onnxFeatureType = onnxFeatureType(reference);
- if (onnxFeatureType.isPresent()) {
- return onnxFeatureType.get();
- }
-
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -217,26 +210,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(function);
}
- private Optional<TensorType> onnxFeatureType(Reference reference) {
- if ( ! reference.name().equals("onnxModel"))
- return Optional.empty();
-
- if ( ! featureTypes.containsKey(reference)) {
- String configOrFileName = reference.arguments().expressions().get(0).toString();
-
- // Look up standardized format as added in RankProfile
- String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
- String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
-
- reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
- if ( ! featureTypes.containsKey(reference)) {
- throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
- }
- }
-
- return Optional.of(featureTypes.get(reference));
- }
-
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 58213186f78..c2fb2107604 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -2,22 +2,14 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
-import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
-import onnx.Onnx;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
+import java.util.List;
import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -29,16 +21,16 @@ public class OnnxModel {
public enum PathType {FILE, URI};
private final String name;
- private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private String defaultOutput = null;
- private Map<String, String> inputMap = new HashMap<>();
- private Map<String, String> outputMap = new HashMap<>();
+ private List<OnnxNameMapping> inputMap = new ArrayList<>();
+ private List<OnnxNameMapping> outputMap = new ArrayList<>();
- private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
- private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
- private Map<String, TensorType> vespaTypes = new HashMap<>();
+ public PathType getPathType() {
+ return pathType;
+ }
+
+ private PathType pathType = PathType.FILE;
public OnnxModel(String name) {
this.name = name;
@@ -57,52 +49,21 @@ public class OnnxModel {
}
public void setUri(String uri) {
- throw new IllegalArgumentException("URI for ONNX models are not currently supported");
- }
-
- public PathType getPathType() {
- return pathType;
- }
-
- public void setDefaultOutput(String onnxName) {
- Objects.requireNonNull(onnxName, "Name cannot be null");
- this.defaultOutput = onnxName;
+ Objects.requireNonNull(uri, "uri cannot be null");
+ this.path = uri;
+ this.pathType = PathType.URI;
}
public void addInputNameMapping(String onnxName, String vespaName) {
- addInputNameMapping(onnxName, vespaName, true);
- }
-
- public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! inputMap.containsKey(onnxName)) {
- inputMap.put(onnxName, vespaName);
- }
+ this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
public void addOutputNameMapping(String onnxName, String vespaName) {
- addOutputNameMapping(onnxName, vespaName, true);
- }
-
- public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! outputMap.containsKey(onnxName)) {
- outputMap.put(onnxName, vespaName);
- }
- }
-
- public void addInputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- inputTypes.put(onnxName, type);
- }
-
- public void addOutputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- outputTypes.put(onnxName, type);
+ this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
/** Initiate sending of this constant to some services over file distribution */
@@ -115,16 +76,11 @@ public class OnnxModel {
public String getName() { return name; }
public String getFileName() { return path; }
- public Path getFilePath() { return Path.fromString(path); }
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
- public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
- public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
-
- public String getDefaultOutput() {
- return defaultOutput;
- }
+ public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
+ public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
public void validate() {
if (path == null || path.isEmpty())
@@ -134,142 +90,23 @@ public class OnnxModel {
public String toString() {
StringBuilder b = new StringBuilder();
b.append("onnx-model '").append(name)
- .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
- .append("' with ref '").append(fileReference)
- .append("'");
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
return b.toString();
}
- /**
- * Return the tensor type for an ONNX model for the given context.
- * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
- * type depends on the input types for the given context (rank profile).
- */
- public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
- Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
- if (onnxOutputType == null) {
- throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
- }
- if (containsSymbolicDimensionSizes(onnxOutputType)) {
- return getTensorTypeWithSymbolicDimensions(onnxOutputType, context);
- }
- return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
- }
-
- private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
- Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context);
- if (symbolicSizes.isEmpty()) {
- return TensorType.empty; // Context is probably a rank profile not using this ONNX model
- }
- return typeFrom(onnxOutputType, symbolicSizes);
- }
-
- private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) {
- Map<String, Long> symbolicSizes = new HashMap<>();
- for (String onnxInputName : inputTypes.keySet()) {
-
- Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
- if ( ! containsSymbolicDimensionSizes(onnxType)) {
- continue;
- }
-
- Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
- if (vespaType.isEmpty()) {
- return Collections.emptyMap();
- }
-
- var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
- var vespaDimensions = vespaType.get().dimensions();
- if (vespaDimensions.size() != onnxDimensions.size()) {
- return Collections.emptyMap();
- }
-
- for (int i = 0; i < vespaDimensions.size(); ++i) {
- if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) {
- continue;
- }
- String symbolicName = onnxDimensions.get(i).getDimParam();
- Long size = vespaDimensions.get(i).size().get();
- if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
- throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " +
- "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
- }
- symbolicSizes.put(symbolicName, size);
- }
- }
- return symbolicSizes;
- }
-
- private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
- String source = inputMap.get(onnxInputName);
- if (source != null) {
- // Source is either a simple reference (query/attribute/constant)...
- Optional<Reference> reference = Reference.simple(source);
- if (reference.isPresent()) {
- return Optional.of(context.getType(reference.get()));
- }
- // ... or a function
- ExpressionFunction func = context.getFunction(source);
- if (func != null) {
- return Optional.of(func.getBody().type(context));
- }
- }
- return Optional.empty(); // if this context does not contain this input
- }
-
- private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) {
- return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue());
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type) {
- return typeFrom(type, null);
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) {
- String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- long onnxDimensionSize = onnxDimension.getDimValue();
- if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
- onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
- }
- if (onnxDimensionSize == 0 && symbolicSizes != null) {
- // This is for the case where all symbolic dimensions have
- // different names, but can be resolved to a single dimension size.
- Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
- if (unknownSizes.size() == 1) {
- onnxDimensionSize = unknownSizes.iterator().next();
- }
- }
- if (onnxDimensionSize <= 0) {
- throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
- "ONNX type: " + type + " to Vespa tensor type.");
- }
- builder.indexed(dimensionName, onnxDimensionSize);
- }
- return builder.build();
- }
+ public static class OnnxNameMapping {
+ private String onnxName;
+ private String vespaName;
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT8: return TensorType.Value.FLOAT;
- case INT16: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.FLOAT;
- case UINT8: return TensorType.Value.FLOAT;
- case UINT16: return TensorType.Value.FLOAT;
- case UINT32: return TensorType.Value.FLOAT;
- case UINT64: return TensorType.Value.FLOAT;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
+ private OnnxNameMapping(String onnxName, String vespaName) {
+ this.onnxName = onnxName;
+ this.vespaName = vespaName;
}
+ public String getOnnxName() { return onnxName; }
+ public String getVespaName() { return vespaName; }
+ public void setVespaName(String vespaName) { this.vespaName = vespaName; }
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 96c043bdb34..d309f48d6df 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -18,7 +18,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
@@ -159,10 +158,6 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
- private Map<String, OnnxModel> onnxModels() {
- return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
- }
-
private Stream<ImmutableSDField> allFields() {
if (search == null) return Stream.empty();
if (allFieldsList == null) {
@@ -826,20 +821,6 @@ public class RankProfile implements Cloneable {
}
}
- // Add output types for ONNX models
- for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
- String modelName = entry.getKey();
- OnnxModel model = entry.getValue();
- Arguments args = new Arguments(new ReferenceNode(modelName));
-
- TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
- context.setType(new Reference("onnxModel", args, null), defaultOutputType);
-
- for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
- TensorType type = model.getTensorType(mapping.getKey(), context);
- context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
- }
- }
return context;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 22a32c8fd65..84442fedc48 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
modelBuilder.name(model.getName());
modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
- model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName())));
+ model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName())));
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 56a5d539906..87eaaf0387a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) {
- for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
- String source = mapping.getValue();
+ for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) {
+ String source = mapping.getVespaName();
if (functionNames.contains(source)) {
- onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")");
+ mapping.setVespaName("rankingExpression(" + source + ")");
}
}
}
@@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>();
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
- ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile);
+ ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch());
if (referenceNode != replacedNode) {
replacedSummaryFeatures.add(replacedNode);
i.remove();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index d23a8376e7a..ec517768ea9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature;
+ if ( ! feature.getName().equals("onnx")) return feature;
try {
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index 69cdae10e47..e1ad003e5bd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,36 +1,20 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import com.yahoo.path.Path;
import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.vespa.model.ml.ConvertedModel;
-import com.yahoo.vespa.model.ml.FeatureArguments;
-import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
/**
- * Transforms ONNX model features of the forms:
- *
- * onnxModel(config_name)
- * onnxModel(config_name).output
- * onnxModel("path/to/model")
- * onnxModel("path/to/model").output
- * onnxModel("path/to/model", "path/to/output")
- * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
- *
- * To the format expected by the backend:
- *
- * onnxModel(config_name).output
+ * Transforms instances of the onnxModel ranking feature and generates
+ * ONNX configuration if necessary.
*
* @author lesters
*/
@@ -49,92 +33,85 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
if (context.rankProfile() == null) return feature;
if (context.rankProfile().getSearch() == null) return feature;
- return transformFeature(feature, context.rankProfile());
+ return transformFeature(feature, context.rankProfile().getSearch());
}
- public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
- ImmutableSearch search = rankProfile.getSearch();
- final String featureName = feature.getName();
- if ( ! featureName.equals("onnxModel")) return feature;
+ public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
+ if (!feature.getName().equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " +
- "onnx-model config or an ONNX file.");
- if (arguments.expressions().size() > 3)
- throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
-
- // Check that the model configuration "onnx-model" exists. If not defined, it should have been added
- // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find
- // the actual ONNX file, which can happen if we are restarting or upgrading an application using an
- // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
-
- String modelConfigName = getModelConfigName(feature.reference());
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
+ throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
+ "onnx-model config or a ONNX file.");
+ if (arguments.expressions().size() > 2)
+ throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
+
+ // Validation that the file actually exists is handled when the file is added to file distribution.
+ // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
+
+ String modelConfigName;
+ OnnxModel onnxModel;
+ if (arguments.expressions().get(0) instanceof ReferenceNode) {
+ modelConfigName = arguments.expressions().get(0).toString();
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
+ }
+ } else if (arguments.expressions().get(0) instanceof ConstantNode) {
String path = asString(arguments.expressions().get(0));
- ModelName modelName = new ModelName(null, Path.fromString(path), true);
- ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile);
- FeatureArguments featureArguments = new FeatureArguments(arguments);
- return convertedModel.expression(featureArguments, null);
- }
-
- String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
- String output = getModelOutput(feature.reference(), defaultOutput);
- if (! onnxModel.getOutputMap().containsValue(output)) {
- throw new IllegalArgumentException(featureName + " argument '" + output +
- "' output not found in model '" + onnxModel.getFileName() + "'");
+ modelConfigName = asValidIdentifier(path);
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ } else {
+ throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
}
- return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
- }
- public static String getModelConfigName(Reference reference) {
- if (reference.arguments().size() > 0) {
- ExpressionNode expr = reference.arguments().expressions().get(0);
- if (expr instanceof ReferenceNode) { // refers to onnx-model config
- return expr.toString();
+ String output = null;
+ if (feature.getOutput() != null) {
+ output = feature.getOutput();
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(output, output);
}
- if (expr instanceof ConstantNode) { // refers to an file path
- return asValidIdentifier(expr);
+ } else if (arguments.expressions().size() > 1) {
+ String name = asString(arguments.expressions().get(1));
+ output = asValidIdentifier(name);
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(name, output);
}
}
- return null;
- }
- public static String getModelOutput(Reference reference, String defaultOutput) {
- if (reference.output() != null) {
- return reference.output();
- } else if (reference.arguments().expressions().size() == 2) {
- return asValidIdentifier(reference.arguments().expressions().get(1));
- } else if (reference.arguments().expressions().size() > 2) {
- return asValidIdentifier(reference.arguments().expressions().get(2));
- }
- return defaultOutput;
+ // Replace feature with name of config
+ ExpressionNode argument = new ReferenceNode(modelConfigName);
+ return new ReferenceNode("onnxModel", List.of(argument), output);
+
}
- public static String stripQuotes(String s) {
- if (isNotQuoteSign(s.codePointAt(0))) return s;
- if (isNotQuoteSign(s.codePointAt(s.length() - 1)))
- throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
- return s.substring(1, s.length()-1);
+ private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
+ return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
}
- public static String asValidIdentifier(String str) {
- return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ private static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
- private static String asValidIdentifier(ExpressionNode node) {
- return asValidIdentifier(asString(node));
+ private static String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
}
- private static boolean isNotQuoteSign(int c) {
- return c != '\'' && c != '"';
+ private static boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
}
- public static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ private static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
deleted file mode 100644
index afba88c135d..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchdefinition.processing;
-
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.application.api.DeployLogger;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankProfileRegistry;
-import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
-import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
-
-import java.util.Map;
-
-/**
- * Processes ONNX ranking features of the form:
- *
- * onnx("files/model.onnx", "path/to/output:1")
- *
- * And generates an "onnx-model" configuration as if it was defined in the schema:
- *
- * onnx-model files_model_onnx {
- * file: "files/model.onnx"
- * }
- *
- * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be
- * processed after this.
- *
- * @author lesters
- */
-public class OnnxModelConfigGenerator extends Processor {
-
- public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
- super(search, deployLogger, rankProfileRegistry, queryProfiles);
- }
-
- @Override
- public void process(boolean validate, boolean documentsOnly) {
- if (documentsOnly) return;
- for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) {
- if (profile.getFirstPhaseRanking() != null) {
- process(profile.getFirstPhaseRanking().getRoot());
- }
- if (profile.getSecondPhaseRanking() != null) {
- process(profile.getSecondPhaseRanking().getRoot());
- }
- for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
- process(function.getValue().function().getBody().getRoot());
- }
- for (ReferenceNode feature : profile.getSummaryFeatures()) {
- process(feature);
- }
- }
- }
-
- private void process(ExpressionNode node) {
- if (node instanceof ReferenceNode) {
- process((ReferenceNode)node);
- } else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode) node).children()) {
- process(child);
- }
- }
- }
-
- private void process(ReferenceNode feature) {
- if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
- if (feature.getArguments().size() > 0) {
- if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
- ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0);
- String path = OnnxModelTransformer.stripQuotes(node.sourceString());
- String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
-
- // Only add the configuration if the model can actually be found.
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
- path = ApplicationPackage.MODELS_DIR.append(path).toString();
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
- return;
- }
- }
-
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- search.onnxModels().add(onnxModel);
- }
- }
- }
- }
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
deleted file mode 100644
index bead2e7e7c9..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchdefinition.processing;
-
-import com.yahoo.cloud.config.ConfigserverConfig;
-import com.yahoo.component.Version;
-import com.yahoo.config.FileReference;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.application.api.DeployLogger;
-import com.yahoo.config.application.api.FileRegistry;
-import com.yahoo.path.Path;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfileRegistry;
-import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
-import com.yahoo.vespa.defaults.Defaults;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.file.Paths;
-import java.util.Map;
-import java.util.Optional;
-
-/**
- * Processes every "onnx-model" element in the schema. Parses the model file,
- * adds missing input and output mappings (assigning default names), and
- * adds tensor types to all model inputs and outputs.
- *
- * Must be processed before RankingExpressingTypeResolver.
- *
- * @author lesters
- */
-public class OnnxModelTypeResolver extends Processor {
-
- public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
- super(search, deployLogger, rankProfileRegistry, queryProfiles);
- }
-
- @Override
- public void process(boolean validate, boolean documentsOnly) {
- if (documentsOnly) return;
-
- for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) {
- OnnxModel modelConfig = entry.getValue();
- try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
-
- // Model inputs - if not defined, assumes a function is provided with a valid name
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
- String onnxInputName = valueInfo.getName();
- String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
- modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
- modelConfig.addInputType(onnxInputName, valueInfo.getType());
- }
-
- // Model outputs
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
- String onnxOutputName = valueInfo.getName();
- String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
- modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
- modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
- }
-
- // Set the first output as default
- if ( ! model.getGraph().getOutputList().isEmpty()) {
- modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
- }
-
- } catch (IOException e) {
- throw new IllegalArgumentException("Unable to parse ONNX model", e);
- }
- }
- }
-
- static boolean modelFileExists(String path, ApplicationPackage app) {
- Path pathInApplicationPackage = Path.fromString(path);
- if (getFile(pathInApplicationPackage, app).exists()) {
- return true;
- }
- if (getFileReference(pathInApplicationPackage, app).isPresent()) {
- return true;
- }
- return false;
- }
-
- private InputStream openModelFile(Path path) throws FileNotFoundException {
- ApplicationFile file;
- Optional<FileReference> reference;
- Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
-
- if ((file = getFile(path)).exists()) {
- return file.createInputStream();
- }
- if ((file = getFile(modelsPath)).exists()) {
- return file.createInputStream();
- }
- if ((reference = getFileReference(path)).isPresent()) {
- return openFromFileRepository(path, reference.get());
- }
- if ((reference = getFileReference(modelsPath)).isPresent()) {
- return openFromFileRepository(modelsPath, reference.get());
- }
-
- throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
- "in application package or file repository.");
- }
-
- private ApplicationFile getFile(Path path) {
- return getFile(path, search.applicationPackage());
- }
-
- private static ApplicationFile getFile(Path path, ApplicationPackage app) {
- return app.getFile(path);
- }
-
- private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
- return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
- }
-
- public static String getFileRepositoryPath(Path path, String fileReference) {
- ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
- String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
- return Paths.get(fileRefDir, fileReference, path.getName()).toString();
- }
-
- private Optional<FileReference> getFileReference(Path path) {
- return getFileReference(path, search.applicationPackage());
- }
-
- private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
- Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
- if (fileRegistry.isPresent()) {
- for (FileRegistry.Entry file : fileRegistry.get().export()) {
- if (file.relativePath.equals(path.toString())) {
- return Optional.of(file.reference);
- }
- }
- }
- return Optional.empty();
- }
-
- private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
- if (app == null) return Optional.empty();
- Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo);
- return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get()));
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index 1a3ef9e54b4..e8594c2a87f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -74,8 +74,6 @@ public class Processing {
ReferenceFieldsProcessor::new,
FastAccessValidator::new,
ReservedFunctionNames::new,
- OnnxModelConfigGenerator::new,
- OnnxModelTypeResolver::new,
RankingExpressionTypeResolver::new,
// These should be last:
IndexingValidation::new,
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index d5c5183b01f..c6c7969e466 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -1,19 +1,20 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation;
+import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.log.InvalidLogFormatException;
import java.util.logging.Level;
import com.yahoo.log.LogMessage;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver;
import com.yahoo.yolean.Exceptions;
import com.yahoo.system.ProcessExecuter;
import com.yahoo.text.StringUtilities;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.collections.Pair;
import com.yahoo.config.ConfigInstance;
+import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.config.search.ImportedFieldsConfig;
import com.yahoo.vespa.config.search.IndexschemaConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -30,6 +31,7 @@ import com.yahoo.vespa.model.search.SearchCluster;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
+import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.logging.Logger;
@@ -150,9 +152,12 @@ public class RankSetupValidator extends Validator {
// Assist verify-ranksetup in finding the actual ONNX model files
Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap();
if (models.values().size() > 0) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
List<String> config = new ArrayList<>(models.values().size() * 2);
for (OnnxModel model : models.values()) {
- String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference());
+ String modelFilename = Paths.get(model.getFileName()).getFileName().toString();
+ String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString();
config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference()));
config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 5ee6ed02e61..943fcbf6c1d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -150,7 +150,7 @@ public class ConvertedModel {
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
ExpressionFunction expression = selectExpression(arguments);
- if (sourceModel.isPresent() && context != null) // we should verify
+ if (sourceModel.isPresent()) // we should verify
verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
return expression.getBody().getRoot();
}
diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto
deleted file mode 100644
index dc6542867e0..00000000000
--- a/config-model/src/main/protobuf/onnx.proto
+++ /dev/null
@@ -1,464 +0,0 @@
-//
-// WARNING: This file is automatically generated! Please edit onnx.in.proto.
-//
-
-
-// Copyright (c) Facebook Inc. and Microsoft Corporation.
-// Licensed under the MIT license.
-
-syntax = "proto2";
-
-package onnx;
-
-// Overview
-//
-// ONNX is an open specification that is comprised of the following components:
-//
-// 1) A definition of an extensible computation graph model.
-// 2) Definitions of standard data types.
-// 3) Definitions of built-in operators.
-//
-// This document describes the syntax of models and their computation graphs,
-// as well as the standard data types. Together, they are referred to as the ONNX
-// Intermediate Representation, or 'IR' for short.
-//
-// The normative semantic specification of the ONNX IR is found in docs/IR.md.
-// Definitions of the built-in neural network operators may be found in docs/Operators.md.
-
-// Notes
-//
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
-// Protobuf compatibility
-//
-// To simplify framework compatibility, ONNX is defined using the subset of protobuf
-// that is compatible with both protobuf v2 and v3. This means that we do not use any
-// protobuf features that are only available in one of the two versions.
-//
-// Here are the most notable contortions we have to carry out to work around
-// these limitations:
-//
-// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
-// of key-value pairs, where order does not matter and duplicates
-// are not allowed.
-
-
-// Versioning
-//
-// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
-//
-// To be compatible with both proto2 and proto3, we will use a version number
-// that is not defined by the default value but an explicit enum number.
-enum Version {
- // proto3 requires the first enum value to be zero.
- // We add this just to appease the compiler.
- _START_VERSION = 0;
- // The version field is always serialized and we will use it to store the
- // version that the graph is generated from. This helps us set up version
- // control. We should use version as
- // xx(major) - xx(minor) - xxxx(bugfix)
- // and we are starting with 0x00000001 (0.0.1), which was the
- // version we published on Oct 10, 2017.
- IR_VERSION_2017_10_10 = 0x00000001;
-
- // IR_VERSION 0.0.2 published on Oct 30, 2017
- // - Added type discriminator to AttributeProto to support proto3 users
- IR_VERSION_2017_10_30 = 0x00000002;
-
- // IR VERSION 0.0.3 published on Nov 3, 2017
- // - For operator versioning:
- // - Added new message OperatorSetIdProto
- // - Added opset_import in ModelProto
- // - For vendor extensions, added domain in NodeProto
- IR_VERSION = 0x00000003;
-}
-
-// Attributes
-//
-// A named attribute containing either singular float, integer, string, graph,
-// and tensor values, or repeated float, integer, string, graph, and tensor values.
-// An AttributeProto MUST contain the name field, and *only one* of the
-// following content fields, effectively enforcing a C/C++ union equivalent.
-message AttributeProto {
-
- // Note: this enum is structurally identical to the OpSchema::AttrType
- // enum defined in schema.h. If you rev one, you likely need to rev the other.
- enum AttributeType {
- UNDEFINED = 0;
- FLOAT = 1;
- INT = 2;
- STRING = 3;
- TENSOR = 4;
- GRAPH = 5;
-
- FLOATS = 6;
- INTS = 7;
- STRINGS = 8;
- TENSORS = 9;
- GRAPHS = 10;
- }
-
- // The name field MUST be present for this version of the IR.
- optional string name = 1; // namespace Attribute
-
- // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
- // In this case, this AttributeProto does not contain data, and it's a reference of attribute
- // in parent scope.
- // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
- optional string ref_attr_name = 21;
-
- // A human-readable documentation for this attribute. Markdown is allowed.
- optional string doc_string = 13;
-
- // The type field MUST be present for this version of the IR.
- // For 0.0.1 versions of the IR, this field was not defined, and
- // implementations needed to use has_field hueristics to determine
- // which value field was in use. For IR_VERSION 0.0.2 or later, this
- // field MUST be set and match the f|i|s|t|... field in use. This
- // change was made to accomodate proto3 implementations.
- optional AttributeType type = 20; // discriminator that indicates which field below is in use
-
- // Exactly ONE of the following fields must be present for this version of the IR
- optional float f = 2; // float
- optional int64 i = 3; // int
- optional bytes s = 4; // UTF-8 string
- optional TensorProto t = 5; // tensor value
- optional GraphProto g = 6; // graph
- // Do not use field below, it's deprecated.
- // optional ValueProto v = 12; // value - subsumes everything but graph
-
- repeated float floats = 7; // list of floats
- repeated int64 ints = 8; // list of ints
- repeated bytes strings = 9; // list of UTF-8 strings
- repeated TensorProto tensors = 10; // list of tensors
- repeated GraphProto graphs = 11; // list of graph
-}
-
-// Defines information on value, including the name, the type, and
-// the shape of the value.
-message ValueInfoProto {
- // This field MUST be present in this version of the IR.
- optional string name = 1; // namespace Value
- // This field MUST be present in this version of the IR.
- optional TypeProto type = 2;
- // A human-readable documentation for this value. Markdown is allowed.
- optional string doc_string = 3;
-}
-
-// Nodes
-//
-// Computation graphs are made up of a DAG of nodes, which represent what is
-// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
-//
-// For example, it can be a node of type "Conv" that takes in an image, a filter
-// tensor and a bias tensor, and produces the convolved output.
-message NodeProto {
- repeated string input = 1; // namespace Value
- repeated string output = 2; // namespace Value
-
- // An optional identifier for this node in a graph.
- // This field MAY be absent in ths version of the IR.
- optional string name = 3; // namespace Node
-
- // The symbolic identifier of the Operator to execute.
- optional string op_type = 4; // namespace Operator
- // The domain of the OperatorSet that specifies the operator named by op_type.
- optional string domain = 7; // namespace Domain
-
- // Additional named attributes.
- repeated AttributeProto attribute = 5;
-
- // A human-readable documentation for this node. Markdown is allowed.
- optional string doc_string = 6;
-}
-
-// Models
-//
-// ModelProto is a top-level file/container format for bundling a ML model and
-// associating its computation graph with metadata.
-//
-// The semantics of the model are described by the associated GraphProto.
-message ModelProto {
- // The version of the IR this model targets. See Version enum above.
- // This field MUST be present.
- optional int64 ir_version = 1;
-
- // The OperatorSets this model relies on.
- // All ModelProtos MUST have at least one entry that
- // specifies which version of the ONNX OperatorSet is
- // being imported.
- //
- // All nodes in the ModelProto's graph will bind against the operator
- // with the same-domain/same-op_type operator with the HIGHEST version
- // in the referenced operator sets.
- repeated OperatorSetIdProto opset_import = 8;
-
- // The name of the framework or tool used to generate this model.
- // This field SHOULD be present to indicate which implementation/tool/framework
- // emitted the model.
- optional string producer_name = 2;
-
- // The version of the framework or tool used to generate this model.
- // This field SHOULD be present to indicate which implementation/tool/framework
- // emitted the model.
- optional string producer_version = 3;
-
- // Domain name of the model.
- // We use reverse domain names as name space indicators. For example:
- // `com.facebook.fair` or `com.microsoft.cognitiveservices`
- //
- // Together with `model_version` and GraphProto.name, this forms the unique identity of
- // the graph.
- optional string domain = 4;
-
- // The version of the graph encoded. See Version enum below.
- optional int64 model_version = 5;
-
- // A human-readable documentation for this model. Markdown is allowed.
- optional string doc_string = 6;
-
- // The parameterized graph that is evaluated to execute the model.
- optional GraphProto graph = 7;
-
- // Named metadata values; keys should be distinct.
- repeated StringStringEntryProto metadata_props = 14;
-};
-
-// StringStringEntryProto follows the pattern for cross-proto-version maps.
-// See https://developers.google.com/protocol-buffers/docs/proto3#maps
-message StringStringEntryProto {
- optional string key = 1;
- optional string value= 2;
-};
-
-// Graphs
-//
-// A graph defines the computational logic of a model and is comprised of a parameterized
-// list of nodes that form a directed acyclic graph based on their inputs and outputs.
-// This is the equivalent of the "network" or "graph" in many deep learning
-// frameworks.
-message GraphProto {
- // The nodes in the graph, sorted topologically.
- repeated NodeProto node = 1;
-
- // The name of the graph.
- optional string name = 2; // namespace Graph
-
- // A list of named tensor values, used to specify constant inputs of the graph.
- // Each TensorProto entry must have a distinct name (within the list) that
- // also appears in the input list.
- repeated TensorProto initializer = 5;
-
- // A human-readable documentation for this graph. Markdown is allowed.
- optional string doc_string = 10;
-
- // The inputs and outputs of the graph.
- repeated ValueInfoProto input = 11;
- repeated ValueInfoProto output = 12;
-
- // Information for the values in the graph. The ValueInfoProto.name's
- // must be distinct. It is optional for a value to appear in value_info list.
- repeated ValueInfoProto value_info = 13;
-
- // DO NOT USE the following fields, they were deprecated from earlier versions.
- // repeated string input = 3;
- // repeated string output = 4;
- // optional int64 ir_version = 6;
- // optional int64 producer_version = 7;
- // optional string producer_tag = 8;
- // optional string domain = 9;
-}
-
-// Tensors
-//
-// A serialized tensor value.
-message TensorProto {
- enum DataType {
- UNDEFINED = 0;
- // Basic types.
- FLOAT = 1; // float
- UINT8 = 2; // uint8_t
- INT8 = 3; // int8_t
- UINT16 = 4; // uint16_t
- INT16 = 5; // int16_t
- INT32 = 6; // int32_t
- INT64 = 7; // int64_t
- STRING = 8; // string
- BOOL = 9; // bool
-
- // Advanced types
- FLOAT16 = 10;
- DOUBLE = 11;
- UINT32 = 12;
- UINT64 = 13;
- COMPLEX64 = 14; // complex with float32 real and imaginary components
- COMPLEX128 = 15; // complex with float64 real and imaginary components
- // Future extensions go here.
- }
-
- // The shape of the tensor.
- repeated int64 dims = 1;
-
- // The data type of the tensor.
- optional DataType data_type = 2;
-
- // For very large tensors, we may want to store them in chunks, in which
- // case the following fields will specify the segment that is stored in
- // the current TensorProto.
- message Segment {
- optional int64 begin = 1;
- optional int64 end = 2;
- }
- optional Segment segment = 3;
-
- // Tensor content must be organized in row-major order.
- //
- // Depending on the data_type field, exactly one of the fields below with
- // name ending in _data is used to store the elements of the tensor.
-
- // For float and complex64 values
- // Complex64 tensors are encoded as a single array of floats,
- // with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
- // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
- // is encoded as [1.0, 2.0 ,3.0 ,4.0]
- // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
- repeated float float_data = 4 [packed = true];
-
- // For int32, uint8, int8, uint16, int16, bool, and float16 values
- // float16 values must be bit-wise converted to an uint16_t prior
- // to writing to the buffer.
- // When this field is present, the data_type field MUST be
- // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
- repeated int32 int32_data = 5 [packed = true];
-
- // For strings.
- // Each element of string_data is a UTF-8 encoded Unicode
- // string. No trailing null, no leading BOM. The protobuf "string"
- // scalar type is not used to match ML community conventions.
- // When this field is present, the data_type field MUST be STRING
- repeated bytes string_data = 6;
-
- // For int64.
- // When this field is present, the data_type field MUST be INT64
- repeated int64 int64_data = 7 [packed = true];
-
- // Optionally, a name for the tensor.
- optional string name = 8; // namespace Value
-
- // A human-readable documentation for this tensor. Markdown is allowed.
- optional string doc_string = 12;
-
- // Serializations can either use one of the fields above, or use this
- // raw bytes field. The only exception is the string case, where one is
- // required to store the content in the repeated bytes string_data field.
- //
- // When this raw_data field is used to store tensor value, elements MUST
- // be stored in as fixed-width, little-endian order.
- // Floating-point data types MUST be stored in IEEE 754 format.
- // Complex64 elements must be written as two consecutive FLOAT values, real component first.
- // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
- // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
- //
- // Note: the advantage of specific field rather than the raw_data field is
- // that in some cases (e.g. int data), protobuf does a better packing via
- // variable length storage, and may lead to smaller binary footprint.
- // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
- optional bytes raw_data = 9;
-
- // For double
- // Complex64 tensors are encoded as a single array of doubles,
- // with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
- // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
- // is encoded as [1.0, 2.0 ,3.0 ,4.0]
- // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
- repeated double double_data = 10 [packed = true];
-
- // For uint64 and uint32 values
- // When this field is present, the data_type field MUST be
- // UINT32 or UINT64
- repeated uint64 uint64_data = 11 [packed = true];
-}
-
-// Defines a tensor shape. A dimension can be either an integer value
-// or a symbolic variable. A symbolic variable represents an unknown
-// dimension.
-message TensorShapeProto {
- message Dimension {
- oneof value {
- int64 dim_value = 1;
- string dim_param = 2; // namespace Shape
- };
- // Standard denotation can optionally be used to denote tensor
- // dimensions with standard semantic descriptions to ensure
- // that operations are applied to the correct axis of a tensor.
- optional string denotation = 3;
- };
- repeated Dimension dim = 1;
-}
-
-// A set of pre-defined constants to be used as values for
-// the standard denotation field in TensorShapeProto.Dimension
-// for semantic description of the tensor dimension.
-message DenotationConstProto {
- // Describe a batch number dimension.
- optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
- // Describe a channel dimension.
- optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
- // Describe a time dimension.
- optional string DATA_TIME = 3 [default = "DATA_TIME"];
- // Describe a feature dimension. This is typically a feature
- // dimension in RNN and/or spatial dimension in CNN.
- optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
- // Describe a filter in-channel dimension. This is the dimension
- // that is identical (in size) to the channel dimension of the input
- // image feature maps.
- optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
- // Describe a filter out channel dimension. This is the dimension
- // that is identical (int size) to the channel dimension of the output
- // image feature maps.
- optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
- // Describe a filter spatial dimension.
- optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
-}
-
-// Types
-//
-// The standard ONNX data types.
-message TypeProto {
-
- message Tensor {
- // This field MUST NOT have the value of UNDEFINED
- // This field MUST be present for this version of the IR.
- optional TensorProto.DataType elem_type = 1;
- optional TensorShapeProto shape = 2;
- }
-
-
- oneof value {
- // The type of a tensor.
- Tensor tensor_type = 1;
-
- }
-}
-
-// Operator Sets
-//
-// OperatorSets are uniquely identified by a (domain, opset_version) pair.
-message OperatorSetIdProto {
- // The domain of the operator set being identified.
- // The empty string ("") or absence of this field implies the operator
- // set that is defined as part of the ONNX specification.
- // This field MUST be present in this version of the IR when referring to any other operator set.
- optional string domain = 1;
-
- // The version of the operator set being identified.
- // This field MUST be present in this version of the IR.
- optional int64 version = 2;
-} \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
index cc73f2daff5..e9575af6010 100644
--- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
+++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
@@ -18,7 +18,7 @@ search test {
}
function mnist_softmax_onnx() {
- expression: onnx_vespa("mnist_softmax")
+ expression: onnx("mnist_softmax")
}
function my_xgboost() {
diff --git a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py
deleted file mode 100755
index 55df3a557e9..00000000000
--- a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-import onnx
-from onnx import helper, TensorProto
-
-INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "sequence"])
-OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "sequence"])
-
-nodes = [helper.make_node('Identity', ['input'], ['output'])]
-graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT])
-model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py')
-onnx.save(model_def, 'dynamic_model.onnx')
diff --git a/config-model/src/test/integration/onnx-model/files/create_model.py b/config-model/src/test/integration/onnx-model/files/create_model.py
deleted file mode 100755
index 10ff92c2eda..00000000000
--- a/config-model/src/test/integration/onnx-model/files/create_model.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-import onnx
-from onnx import helper, TensorProto
-
-INPUT_1 = helper.make_tensor_value_info('first_input', TensorProto.FLOAT, [2])
-INPUT_2 = helper.make_tensor_value_info('second/input:0', TensorProto.FLOAT, [2])
-INPUT_3 = helper.make_tensor_value_info('third_input', TensorProto.FLOAT, [2])
-OUTPUT_1 = helper.make_tensor_value_info('path/to/output:0', TensorProto.FLOAT, [2])
-OUTPUT_2 = helper.make_tensor_value_info('path/to/output:1', TensorProto.FLOAT, [2])
-OUTPUT_3 = helper.make_tensor_value_info('path/to/output:2', TensorProto.FLOAT, [2])
-
-nodes = [
- helper.make_node(
- 'Add',
- ['first_input', 'second/input:0'],
- ['path/to/output:0'],
- ),
- helper.make_node(
- 'Add',
- ['third_input', 'second/input:0'],
- ['path/to/output:1']
- ),
- helper.make_node(
- 'Add',
- ['path/to/output:0', 'path/to/output:1'],
- ['path/to/output:2']
- ),
-]
-graph_def = helper.make_graph(
- nodes,
- 'simple_scoring',
- [INPUT_1, INPUT_2, INPUT_3],
- [OUTPUT_1, OUTPUT_2, OUTPUT_3]
-)
-model_def = helper.make_model(graph_def, producer_name='create_model.py')
-onnx.save(model_def, 'model.onnx')
diff --git a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx
deleted file mode 100644
index 6bbdad2d76e..00000000000
--- a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx
+++ /dev/null
@@ -1,13 +0,0 @@
-create_dynamic_model.py:x
-
-inputoutput"Identitysimple_scoringZ$
-input
-
-batch
-
-sequenceb%
-output
-
-batch
-
-sequenceB \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/files/model.onnx b/config-model/src/test/integration/onnx-model/files/model.onnx
deleted file mode 100644
index f3898205c6a..00000000000
--- a/config-model/src/test/integration/onnx-model/files/model.onnx
+++ /dev/null
@@ -1,34 +0,0 @@
-create_model.py:í
-4
- first_input
-second/input:0path/to/output:0"Add
-4
- third_input
-second/input:0path/to/output:1"Add
-;
-path/to/output:0
-path/to/output:1path/to/output:2"Addsimple_scoringZ
- first_input
-
-
-Z
-second/input:0
-
-
-Z
- third_input
-
-
-b
-path/to/output:0
-
-
-b
-path/to/output:1
-
-
-b
-path/to/output:2
-
-
-B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/files/summary_model.onnx b/config-model/src/test/integration/onnx-model/files/summary_model.onnx
deleted file mode 100644
index f3898205c6a..00000000000
--- a/config-model/src/test/integration/onnx-model/files/summary_model.onnx
+++ /dev/null
@@ -1,34 +0,0 @@
-create_model.py:í
-4
- first_input
-second/input:0path/to/output:0"Add
-4
- third_input
-second/input:0path/to/output:1"Add
-;
-path/to/output:0
-path/to/output:1path/to/output:2"Addsimple_scoringZ
- first_input
-
-
-Z
-second/input:0
-
-
-Z
- third_input
-
-
-b
-path/to/output:0
-
-
-b
-path/to/output:1
-
-
-b
-path/to/output:2
-
-
-B \ No newline at end of file
diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
index 6e9ba356293..0f0fa694e6f 100644
--- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
+++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
@@ -14,7 +14,7 @@ search test {
}
onnx-model my_model {
- file: files/model.onnx
+ file: files/ranking_model.onnx
input first_input: attribute(document_field)
input "second/input:0": constant(my_constant)
input "third_input": my_function
@@ -22,25 +22,19 @@ search test {
}
onnx-model another_model {
- file: files/model.onnx
+ file: files/ranking_model.onnx
input first_input: attribute(document_field)
input "second/input:0": constant(my_constant)
input "third_input": another_function
output "path/to/output:2": out
}
- onnx-model dynamic_model {
- file: files/dynamic_model.onnx
- input input: my_function
- output output: my_output
- }
-
rank-profile test_model_config {
function my_function() {
expression: tensor(d0[2])(1)
}
first-phase {
- expression: onnxModel(my_model).out{d0:1}
+ expression: onnxModel(my_model).out
}
}
@@ -55,7 +49,7 @@ search test {
expression: my_function()
}
first-phase {
- expression: onnxModel("files/model.onnx", "path/to/output:1"){d0:1}
+ expression: onnxModel("files/ranking_model.onnx", "path/to/output:1")
}
}
@@ -68,29 +62,9 @@ search test {
}
summary-features {
onnxModel(another_model).out
- onnxModel("files/summary_model.onnx", "path/to/output:2")
- }
- }
-
- rank-profile test_dynamic_model {
- function my_function() {
- expression: tensor(d0[1],d1[2])(d1)
- }
- first-phase {
- expression: onnxModel(dynamic_model){d0:0,d1:1}
+ onnxModel("files/ranking_model.onnx", "path/to/output:2")
}
- }
- rank-profile test_dynamic_model_2 {
- function my_function_2() {
- expression: tensor(d0[1],d1[3])(d1)
- }
- function my_function() {
- expression: my_function_2()
- }
- first-phase {
- expression: onnxModel(dynamic_model){d0:0,d1:2}
- }
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index 5060aafb55f..d9b0c70dfdd 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -25,61 +25,43 @@ public class RankingExpressionWithOnnxModelTestCase {
OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
((OnnxModelsConfig.Producer) db).getConfig(builder);
OnnxModelsConfig config = new OnnxModelsConfig(builder);
- assertEquals(5, config.model().size());
+ assertEquals(3, config.model().size());
- assertEquals("my_model", config.model(0).name());
- assertEquals(3, config.model(0).input().size());
- assertEquals("second/input:0", config.model(0).input(0).name());
- assertEquals("constant(my_constant)", config.model(0).input(0).source());
- assertEquals("first_input", config.model(0).input(1).name());
- assertEquals("attribute(document_field)", config.model(0).input(1).source());
- assertEquals("third_input", config.model(0).input(2).name());
- assertEquals("rankingExpression(my_function)", config.model(0).input(2).source());
- assertEquals(3, config.model(0).output().size());
- assertEquals("path/to/output:0", config.model(0).output(0).name());
- assertEquals("out", config.model(0).output(0).as());
- assertEquals("path/to/output:1", config.model(0).output(1).name());
- assertEquals("path_to_output_1", config.model(0).output(1).as());
- assertEquals("path/to/output:2", config.model(0).output(2).name());
- assertEquals("path_to_output_2", config.model(0).output(2).as());
-
- assertEquals("files_model_onnx", config.model(1).name());
+ assertEquals("my_model", config.model(1).name());
assertEquals(3, config.model(1).input().size());
- assertEquals(3, config.model(1).output().size());
+ assertEquals("first_input", config.model(1).input(0).name());
+ assertEquals("attribute(document_field)", config.model(1).input(0).source());
+ assertEquals("second/input:0", config.model(1).input(1).name());
+ assertEquals("constant(my_constant)", config.model(1).input(1).source());
+ assertEquals("third_input", config.model(1).input(2).name());
+ assertEquals("rankingExpression(my_function)", config.model(1).input(2).source());
+ assertEquals(1, config.model(1).output().size());
assertEquals("path/to/output:0", config.model(1).output(0).name());
- assertEquals("path_to_output_0", config.model(1).output(0).as());
- assertEquals("path/to/output:1", config.model(1).output(1).name());
- assertEquals("path_to_output_1", config.model(1).output(1).as());
- assertEquals("path/to/output:2", config.model(1).output(2).name());
- assertEquals("path_to_output_2", config.model(1).output(2).as());
- assertEquals("files_model_onnx", config.model(1).name());
+ assertEquals("out", config.model(1).output(0).as());
+
+ assertEquals("files_ranking_model_onnx", config.model(0).name());
+ assertEquals(0, config.model(0).input().size());
+ assertEquals(2, config.model(0).output().size());
+ assertEquals("path/to/output:1", config.model(0).output(0).name());
+ assertEquals("path_to_output_1", config.model(0).output(0).as());
+ assertEquals("path/to/output:2", config.model(0).output(1).name());
+ assertEquals("path_to_output_2", config.model(0).output(1).as());
assertEquals("another_model", config.model(2).name());
assertEquals("third_input", config.model(2).input(2).name());
assertEquals("rankingExpression(another_function)", config.model(2).input(2).source());
-
- assertEquals("files_summary_model_onnx", config.model(3).name());
- assertEquals(3, config.model(3).input().size());
- assertEquals(3, config.model(3).output().size());
-
- assertEquals("dynamic_model", config.model(4).name());
- assertEquals(1, config.model(4).input().size());
- assertEquals(1, config.model(4).output().size());
- assertEquals("rankingExpression(my_function)", config.model(4).input(0).source());
}
private void assertTransformedFeature(DocumentDatabase db) {
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);
- assertEquals(7, config.rankprofile().size());
+ assertEquals(5, config.rankprofile().size());
assertEquals("test_model_config", config.rankprofile(2).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name());
assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name());
- assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value());
- assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name());
- assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value());
+ assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value());
assertEquals("test_generated_model_config", config.rankprofile(3).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name());
@@ -87,28 +69,16 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name());
assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name());
assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name());
- assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value());
- assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name());
- assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value());
+ assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value());
assertEquals("test_summary_features", config.rankprofile(4).name());
assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name());
assertEquals("1", config.rankprofile(4).fef().property(3).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name());
- assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value());
+ assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name());
- assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value());
-
- assertEquals("test_dynamic_model", config.rankprofile(5).name());
- assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name());
- assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name());
- assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value());
-
- assertEquals("test_dynamic_model_2", config.rankprofile(6).name());
- assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name());
- assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value());
-
+ assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 40bf970a313..6bf69907609 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -70,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
- "onnx_vespa('mnist_softmax.onnx')",
+ "onnx('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -87,7 +87,7 @@ public class RankingExpressionWithOnnxTestCase {
queryProfile,
queryProfileType);
RankProfileSearchFixture search = fixtureWith("query(mytensor)",
- "onnx_vespa('mnist_softmax.onnx')",
+ "onnx('mnist_softmax.onnx')",
null,
null,
"Placeholder",
@@ -99,7 +99,7 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithDocumentFeature() {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
- "onnx_vespa('mnist_softmax.onnx')",
+ "onnx('mnist_softmax.onnx')",
null,
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
@@ -117,7 +117,7 @@ public class RankingExpressionWithOnnxTestCase {
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType);
RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
- "onnx_vespa('mnist_softmax.onnx')",
+ "onnx('mnist_softmax.onnx')",
"constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }",
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
@@ -129,21 +129,21 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testNestedOnnxReference() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "5 + sum(onnx_vespa('mnist_softmax.onnx'))");
+ "5 + sum(onnx('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutput() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx', 'layer_add')");
+ "onnx('mnist_softmax.onnx', 'layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
public void testOnnxReferenceWithSpecifiedOutputAndSignature() {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')");
+ "onnx('mnist_softmax.onnx', 'default.layer_add')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -155,7 +155,7 @@ public class RankingExpressionWithOnnxTestCase {
new QueryProfileRegistry(),
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: onnx_vespa('mnist_softmax.onnx')" +
+ " expression: onnx('mnist_softmax.onnx')" +
" }\n" +
" }");
search.compileRankProfile("my_profile", applicationDir.append("models"));
@@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase {
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx_vespa('mnist_softmax.onnx'): " +
+ "onnx('mnist_softmax.onnx'): " +
"Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
@@ -175,13 +175,13 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithWrongFunctionType() {
try {
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)",
- "onnx_vespa('mnist_softmax.onnx')");
+ "onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx_vespa('mnist_softmax.onnx'): " +
+ "onnx('mnist_softmax.onnx'): " +
"Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " +
"but this function returns tensor(d0[1],d5[10])",
Exceptions.toMessageString(expected));
@@ -192,13 +192,13 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceSpecifyingNonExistingOutput() {
try {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx', 'y')");
+ "onnx('mnist_softmax.onnx', 'y')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
- "onnx_vespa('mnist_softmax.onnx','y'): " +
+ "onnx('mnist_softmax.onnx','y'): " +
"No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add",
Exceptions.toMessageString(expected));
}
@@ -207,7 +207,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testImportingFromStoredExpressions() throws IOException {
RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx')");
+ "onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
// At this point the expression is stored - copy application to another location which do not have a models dir
@@ -218,7 +218,7 @@ public class RankingExpressionWithOnnxTestCase {
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
- "onnx_vespa('mnist_softmax.onnx')",
+ "onnx('mnist_softmax.onnx')",
null,
null,
"Placeholder",
@@ -243,7 +243,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx_vespa('mnist_softmax.onnx')" +
+ " expression: onnx('mnist_softmax.onnx')" +
" }\n" +
" }" +
" rank-profile my_profile_child inherits my_profile {\n" +
@@ -288,7 +288,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d0[3])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx_vespa('" + name + ".onnx')" +
+ " expression: onnx('" + name + ".onnx')" +
" }\n" +
" }";
final String functionName = "imported_ml_function_" + name + "_exp_output";
@@ -310,7 +310,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: tensor<float>(d0[3])(0.0)\n" +
" }\n" +
" first-phase {\n" +
- " expression: onnx_vespa('" + name + ".onnx')" +
+ " expression: onnx('" + name + ".onnx')" +
" }\n" +
" }" +
" rank-profile my_profile_child inherits my_profile {\n" +
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index 181ef6dffbd..4beaf6086a6 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -221,10 +221,5 @@
<artifactId>jdisc_http_service</artifactId>
<version>${project.version}</version>
</dependency>
- <dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- <version>${protobuf.version}</version>
- </dependency>
</dependencies>
</project>