summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java27
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java219
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java19
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java141
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java97
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java154
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java2
12 files changed, 102 insertions, 582 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index b153ff62e7d..4011ce43841 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -159,12 +158,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
}
- // A reference to an ONNX model?
- Optional<TensorType> onnxFeatureType = onnxFeatureType(reference);
- if (onnxFeatureType.isPresent()) {
- return onnxFeatureType.get();
- }
-
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -217,26 +210,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(function);
}
- private Optional<TensorType> onnxFeatureType(Reference reference) {
- if ( ! reference.name().equals("onnxModel"))
- return Optional.empty();
-
- if ( ! featureTypes.containsKey(reference)) {
- String configOrFileName = reference.arguments().expressions().get(0).toString();
-
- // Look up standardized format as added in RankProfile
- String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
- String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
-
- reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
- if ( ! featureTypes.containsKey(reference)) {
- throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
- }
- }
-
- return Optional.of(featureTypes.get(reference));
- }
-
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 58213186f78..c2fb2107604 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -2,22 +2,14 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
-import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
-import onnx.Onnx;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
+import java.util.List;
import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -29,16 +21,16 @@ public class OnnxModel {
public enum PathType {FILE, URI};
private final String name;
- private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private String defaultOutput = null;
- private Map<String, String> inputMap = new HashMap<>();
- private Map<String, String> outputMap = new HashMap<>();
+ private List<OnnxNameMapping> inputMap = new ArrayList<>();
+ private List<OnnxNameMapping> outputMap = new ArrayList<>();
- private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
- private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
- private Map<String, TensorType> vespaTypes = new HashMap<>();
+ public PathType getPathType() {
+ return pathType;
+ }
+
+ private PathType pathType = PathType.FILE;
public OnnxModel(String name) {
this.name = name;
@@ -57,52 +49,21 @@ public class OnnxModel {
}
public void setUri(String uri) {
- throw new IllegalArgumentException("URI for ONNX models are not currently supported");
- }
-
- public PathType getPathType() {
- return pathType;
- }
-
- public void setDefaultOutput(String onnxName) {
- Objects.requireNonNull(onnxName, "Name cannot be null");
- this.defaultOutput = onnxName;
+ Objects.requireNonNull(uri, "uri cannot be null");
+ this.path = uri;
+ this.pathType = PathType.URI;
}
public void addInputNameMapping(String onnxName, String vespaName) {
- addInputNameMapping(onnxName, vespaName, true);
- }
-
- public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! inputMap.containsKey(onnxName)) {
- inputMap.put(onnxName, vespaName);
- }
+ this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
public void addOutputNameMapping(String onnxName, String vespaName) {
- addOutputNameMapping(onnxName, vespaName, true);
- }
-
- public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! outputMap.containsKey(onnxName)) {
- outputMap.put(onnxName, vespaName);
- }
- }
-
- public void addInputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- inputTypes.put(onnxName, type);
- }
-
- public void addOutputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- outputTypes.put(onnxName, type);
+ this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
/** Initiate sending of this constant to some services over file distribution */
@@ -115,16 +76,11 @@ public class OnnxModel {
public String getName() { return name; }
public String getFileName() { return path; }
- public Path getFilePath() { return Path.fromString(path); }
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
- public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
- public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
-
- public String getDefaultOutput() {
- return defaultOutput;
- }
+ public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
+ public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
public void validate() {
if (path == null || path.isEmpty())
@@ -134,142 +90,23 @@ public class OnnxModel {
public String toString() {
StringBuilder b = new StringBuilder();
b.append("onnx-model '").append(name)
- .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
- .append("' with ref '").append(fileReference)
- .append("'");
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
return b.toString();
}
- /**
- * Return the tensor type for an ONNX model for the given context.
- * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
- * type depends on the input types for the given context (rank profile).
- */
- public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
- Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
- if (onnxOutputType == null) {
- throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
- }
- if (containsSymbolicDimensionSizes(onnxOutputType)) {
- return getTensorTypeWithSymbolicDimensions(onnxOutputType, context);
- }
- return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
- }
-
- private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
- Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context);
- if (symbolicSizes.isEmpty()) {
- return TensorType.empty; // Context is probably a rank profile not using this ONNX model
- }
- return typeFrom(onnxOutputType, symbolicSizes);
- }
-
- private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) {
- Map<String, Long> symbolicSizes = new HashMap<>();
- for (String onnxInputName : inputTypes.keySet()) {
-
- Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
- if ( ! containsSymbolicDimensionSizes(onnxType)) {
- continue;
- }
-
- Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
- if (vespaType.isEmpty()) {
- return Collections.emptyMap();
- }
-
- var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
- var vespaDimensions = vespaType.get().dimensions();
- if (vespaDimensions.size() != onnxDimensions.size()) {
- return Collections.emptyMap();
- }
-
- for (int i = 0; i < vespaDimensions.size(); ++i) {
- if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) {
- continue;
- }
- String symbolicName = onnxDimensions.get(i).getDimParam();
- Long size = vespaDimensions.get(i).size().get();
- if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
- throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " +
- "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
- }
- symbolicSizes.put(symbolicName, size);
- }
- }
- return symbolicSizes;
- }
-
- private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
- String source = inputMap.get(onnxInputName);
- if (source != null) {
- // Source is either a simple reference (query/attribute/constant)...
- Optional<Reference> reference = Reference.simple(source);
- if (reference.isPresent()) {
- return Optional.of(context.getType(reference.get()));
- }
- // ... or a function
- ExpressionFunction func = context.getFunction(source);
- if (func != null) {
- return Optional.of(func.getBody().type(context));
- }
- }
- return Optional.empty(); // if this context does not contain this input
- }
-
- private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) {
- return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue());
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type) {
- return typeFrom(type, null);
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) {
- String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- long onnxDimensionSize = onnxDimension.getDimValue();
- if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
- onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
- }
- if (onnxDimensionSize == 0 && symbolicSizes != null) {
- // This is for the case where all symbolic dimensions have
- // different names, but can be resolved to a single dimension size.
- Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
- if (unknownSizes.size() == 1) {
- onnxDimensionSize = unknownSizes.iterator().next();
- }
- }
- if (onnxDimensionSize <= 0) {
- throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
- "ONNX type: " + type + " to Vespa tensor type.");
- }
- builder.indexed(dimensionName, onnxDimensionSize);
- }
- return builder.build();
- }
+ public static class OnnxNameMapping {
+ private String onnxName;
+ private String vespaName;
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT8: return TensorType.Value.FLOAT;
- case INT16: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.FLOAT;
- case UINT8: return TensorType.Value.FLOAT;
- case UINT16: return TensorType.Value.FLOAT;
- case UINT32: return TensorType.Value.FLOAT;
- case UINT64: return TensorType.Value.FLOAT;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
+ private OnnxNameMapping(String onnxName, String vespaName) {
+ this.onnxName = onnxName;
+ this.vespaName = vespaName;
}
+ public String getOnnxName() { return onnxName; }
+ public String getVespaName() { return vespaName; }
+ public void setVespaName(String vespaName) { this.vespaName = vespaName; }
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 96c043bdb34..d309f48d6df 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -18,7 +18,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
@@ -159,10 +158,6 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
- private Map<String, OnnxModel> onnxModels() {
- return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
- }
-
private Stream<ImmutableSDField> allFields() {
if (search == null) return Stream.empty();
if (allFieldsList == null) {
@@ -826,20 +821,6 @@ public class RankProfile implements Cloneable {
}
}
- // Add output types for ONNX models
- for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
- String modelName = entry.getKey();
- OnnxModel model = entry.getValue();
- Arguments args = new Arguments(new ReferenceNode(modelName));
-
- TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
- context.setType(new Reference("onnxModel", args, null), defaultOutputType);
-
- for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
- TensorType type = model.getTensorType(mapping.getKey(), context);
- context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
- }
- }
return context;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 22a32c8fd65..84442fedc48 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
modelBuilder.name(model.getName());
modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
- model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName())));
+ model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName())));
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 56a5d539906..87eaaf0387a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) {
- for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
- String source = mapping.getValue();
+ for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) {
+ String source = mapping.getVespaName();
if (functionNames.contains(source)) {
- onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")");
+ mapping.setVespaName("rankingExpression(" + source + ")");
}
}
}
@@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>();
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
- ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile);
+ ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch());
if (referenceNode != replacedNode) {
replacedSummaryFeatures.add(replacedNode);
i.remove();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index d23a8376e7a..ec517768ea9 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature;
+ if ( ! feature.getName().equals("onnx")) return feature;
try {
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index 69cdae10e47..e1ad003e5bd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,36 +1,20 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import com.yahoo.path.Path;
import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.vespa.model.ml.ConvertedModel;
-import com.yahoo.vespa.model.ml.FeatureArguments;
-import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
/**
- * Transforms ONNX model features of the forms:
- *
- * onnxModel(config_name)
- * onnxModel(config_name).output
- * onnxModel("path/to/model")
- * onnxModel("path/to/model").output
- * onnxModel("path/to/model", "path/to/output")
- * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
- *
- * To the format expected by the backend:
- *
- * onnxModel(config_name).output
+ * Transforms instances of the onnxModel ranking feature and generates
+ * ONNX configuration if necessary.
*
* @author lesters
*/
@@ -49,92 +33,85 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
if (context.rankProfile() == null) return feature;
if (context.rankProfile().getSearch() == null) return feature;
- return transformFeature(feature, context.rankProfile());
+ return transformFeature(feature, context.rankProfile().getSearch());
}
- public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
- ImmutableSearch search = rankProfile.getSearch();
- final String featureName = feature.getName();
- if ( ! featureName.equals("onnxModel")) return feature;
+ public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
+ if (!feature.getName().equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " +
- "onnx-model config or an ONNX file.");
- if (arguments.expressions().size() > 3)
- throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
-
- // Check that the model configuration "onnx-model" exists. If not defined, it should have been added
- // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find
- // the actual ONNX file, which can happen if we are restarting or upgrading an application using an
- // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
-
- String modelConfigName = getModelConfigName(feature.reference());
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
+ throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
+ "onnx-model config or a ONNX file.");
+ if (arguments.expressions().size() > 2)
+ throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
+
+ // Validation that the file actually exists is handled when the file is added to file distribution.
+ // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
+
+ String modelConfigName;
+ OnnxModel onnxModel;
+ if (arguments.expressions().get(0) instanceof ReferenceNode) {
+ modelConfigName = arguments.expressions().get(0).toString();
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
+ }
+ } else if (arguments.expressions().get(0) instanceof ConstantNode) {
String path = asString(arguments.expressions().get(0));
- ModelName modelName = new ModelName(null, Path.fromString(path), true);
- ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile);
- FeatureArguments featureArguments = new FeatureArguments(arguments);
- return convertedModel.expression(featureArguments, null);
- }
-
- String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
- String output = getModelOutput(feature.reference(), defaultOutput);
- if (! onnxModel.getOutputMap().containsValue(output)) {
- throw new IllegalArgumentException(featureName + " argument '" + output +
- "' output not found in model '" + onnxModel.getFileName() + "'");
+ modelConfigName = asValidIdentifier(path);
+ onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
+ onnxModel = new OnnxModel(modelConfigName, path);
+ search.onnxModels().add(onnxModel);
+ }
+ } else {
+ throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
}
- return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
- }
- public static String getModelConfigName(Reference reference) {
- if (reference.arguments().size() > 0) {
- ExpressionNode expr = reference.arguments().expressions().get(0);
- if (expr instanceof ReferenceNode) { // refers to onnx-model config
- return expr.toString();
+ String output = null;
+ if (feature.getOutput() != null) {
+ output = feature.getOutput();
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(output, output);
}
- if (expr instanceof ConstantNode) { // refers to an file path
- return asValidIdentifier(expr);
+ } else if (arguments.expressions().size() > 1) {
+ String name = asString(arguments.expressions().get(1));
+ output = asValidIdentifier(name);
+ if ( ! hasOutputMapping(onnxModel, output)) {
+ onnxModel.addOutputNameMapping(name, output);
}
}
- return null;
- }
- public static String getModelOutput(Reference reference, String defaultOutput) {
- if (reference.output() != null) {
- return reference.output();
- } else if (reference.arguments().expressions().size() == 2) {
- return asValidIdentifier(reference.arguments().expressions().get(1));
- } else if (reference.arguments().expressions().size() > 2) {
- return asValidIdentifier(reference.arguments().expressions().get(2));
- }
- return defaultOutput;
+ // Replace feature with name of config
+ ExpressionNode argument = new ReferenceNode(modelConfigName);
+ return new ReferenceNode("onnxModel", List.of(argument), output);
+
}
- public static String stripQuotes(String s) {
- if (isNotQuoteSign(s.codePointAt(0))) return s;
- if (isNotQuoteSign(s.codePointAt(s.length() - 1)))
- throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
- return s.substring(1, s.length()-1);
+ private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
+ return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
}
- public static String asValidIdentifier(String str) {
- return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ private static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
- private static String asValidIdentifier(ExpressionNode node) {
- return asValidIdentifier(asString(node));
+ private static String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
}
- private static boolean isNotQuoteSign(int c) {
- return c != '\'' && c != '"';
+ private static boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
}
- public static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ private static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
deleted file mode 100644
index afba88c135d..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ /dev/null
@@ -1,97 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchdefinition.processing;
-
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.application.api.DeployLogger;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankProfileRegistry;
-import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
-import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
-
-import java.util.Map;
-
-/**
- * Processes ONNX ranking features of the form:
- *
- * onnx("files/model.onnx", "path/to/output:1")
- *
- * And generates an "onnx-model" configuration as if it was defined in the schema:
- *
- * onnx-model files_model_onnx {
- * file: "files/model.onnx"
- * }
- *
- * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be
- * processed after this.
- *
- * @author lesters
- */
-public class OnnxModelConfigGenerator extends Processor {
-
- public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
- super(search, deployLogger, rankProfileRegistry, queryProfiles);
- }
-
- @Override
- public void process(boolean validate, boolean documentsOnly) {
- if (documentsOnly) return;
- for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) {
- if (profile.getFirstPhaseRanking() != null) {
- process(profile.getFirstPhaseRanking().getRoot());
- }
- if (profile.getSecondPhaseRanking() != null) {
- process(profile.getSecondPhaseRanking().getRoot());
- }
- for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
- process(function.getValue().function().getBody().getRoot());
- }
- for (ReferenceNode feature : profile.getSummaryFeatures()) {
- process(feature);
- }
- }
- }
-
- private void process(ExpressionNode node) {
- if (node instanceof ReferenceNode) {
- process((ReferenceNode)node);
- } else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode) node).children()) {
- process(child);
- }
- }
- }
-
- private void process(ReferenceNode feature) {
- if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) {
- if (feature.getArguments().size() > 0) {
- if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
- ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0);
- String path = OnnxModelTransformer.stripQuotes(node.sourceString());
- String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
-
- // Only add the configuration if the model can actually be found.
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
- path = ApplicationPackage.MODELS_DIR.append(path).toString();
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
- return;
- }
- }
-
- OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- search.onnxModels().add(onnxModel);
- }
- }
- }
- }
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
deleted file mode 100644
index bead2e7e7c9..00000000000
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchdefinition.processing;
-
-import com.yahoo.cloud.config.ConfigserverConfig;
-import com.yahoo.component.Version;
-import com.yahoo.config.FileReference;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.application.api.DeployLogger;
-import com.yahoo.config.application.api.FileRegistry;
-import com.yahoo.path.Path;
-import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.RankProfileRegistry;
-import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
-import com.yahoo.vespa.defaults.Defaults;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.file.Paths;
-import java.util.Map;
-import java.util.Optional;
-
-/**
- * Processes every "onnx-model" element in the schema. Parses the model file,
- * adds missing input and output mappings (assigning default names), and
- * adds tensor types to all model inputs and outputs.
- *
- * Must be processed before RankingExpressingTypeResolver.
- *
- * @author lesters
- */
-public class OnnxModelTypeResolver extends Processor {
-
- public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
- super(search, deployLogger, rankProfileRegistry, queryProfiles);
- }
-
- @Override
- public void process(boolean validate, boolean documentsOnly) {
- if (documentsOnly) return;
-
- for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) {
- OnnxModel modelConfig = entry.getValue();
- try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
-
- // Model inputs - if not defined, assumes a function is provided with a valid name
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
- String onnxInputName = valueInfo.getName();
- String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
- modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
- modelConfig.addInputType(onnxInputName, valueInfo.getType());
- }
-
- // Model outputs
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
- String onnxOutputName = valueInfo.getName();
- String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
- modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
- modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
- }
-
- // Set the first output as default
- if ( ! model.getGraph().getOutputList().isEmpty()) {
- modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
- }
-
- } catch (IOException e) {
- throw new IllegalArgumentException("Unable to parse ONNX model", e);
- }
- }
- }
-
- static boolean modelFileExists(String path, ApplicationPackage app) {
- Path pathInApplicationPackage = Path.fromString(path);
- if (getFile(pathInApplicationPackage, app).exists()) {
- return true;
- }
- if (getFileReference(pathInApplicationPackage, app).isPresent()) {
- return true;
- }
- return false;
- }
-
- private InputStream openModelFile(Path path) throws FileNotFoundException {
- ApplicationFile file;
- Optional<FileReference> reference;
- Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
-
- if ((file = getFile(path)).exists()) {
- return file.createInputStream();
- }
- if ((file = getFile(modelsPath)).exists()) {
- return file.createInputStream();
- }
- if ((reference = getFileReference(path)).isPresent()) {
- return openFromFileRepository(path, reference.get());
- }
- if ((reference = getFileReference(modelsPath)).isPresent()) {
- return openFromFileRepository(modelsPath, reference.get());
- }
-
- throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
- "in application package or file repository.");
- }
-
- private ApplicationFile getFile(Path path) {
- return getFile(path, search.applicationPackage());
- }
-
- private static ApplicationFile getFile(Path path, ApplicationPackage app) {
- return app.getFile(path);
- }
-
- private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
- return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
- }
-
- public static String getFileRepositoryPath(Path path, String fileReference) {
- ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
- String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
- return Paths.get(fileRefDir, fileReference, path.getName()).toString();
- }
-
- private Optional<FileReference> getFileReference(Path path) {
- return getFileReference(path, search.applicationPackage());
- }
-
- private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
- Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
- if (fileRegistry.isPresent()) {
- for (FileRegistry.Entry file : fileRegistry.get().export()) {
- if (file.relativePath.equals(path.toString())) {
- return Optional.of(file.reference);
- }
- }
- }
- return Optional.empty();
- }
-
- private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
- if (app == null) return Optional.empty();
- Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo);
- return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get()));
- }
-
-}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index 1a3ef9e54b4..e8594c2a87f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -74,8 +74,6 @@ public class Processing {
ReferenceFieldsProcessor::new,
FastAccessValidator::new,
ReservedFunctionNames::new,
- OnnxModelConfigGenerator::new,
- OnnxModelTypeResolver::new,
RankingExpressionTypeResolver::new,
// These should be last:
IndexingValidation::new,
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index d5c5183b01f..c6c7969e466 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -1,19 +1,20 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation;
+import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.log.InvalidLogFormatException;
import java.util.logging.Level;
import com.yahoo.log.LogMessage;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver;
import com.yahoo.yolean.Exceptions;
import com.yahoo.system.ProcessExecuter;
import com.yahoo.text.StringUtilities;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.collections.Pair;
import com.yahoo.config.ConfigInstance;
+import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.config.search.ImportedFieldsConfig;
import com.yahoo.vespa.config.search.IndexschemaConfig;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -30,6 +31,7 @@ import com.yahoo.vespa.model.search.SearchCluster;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
+import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.logging.Logger;
@@ -150,9 +152,12 @@ public class RankSetupValidator extends Validator {
// Assist verify-ranksetup in finding the actual ONNX model files
Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap();
if (models.values().size() > 0) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
List<String> config = new ArrayList<>(models.values().size() * 2);
for (OnnxModel model : models.values()) {
- String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference());
+ String modelFilename = Paths.get(model.getFileName()).getFileName().toString();
+ String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString();
config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference()));
config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 5ee6ed02e61..943fcbf6c1d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -150,7 +150,7 @@ public class ConvertedModel {
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
ExpressionFunction expression = selectExpression(arguments);
- if (sourceModel.isPresent() && context != null) // we should verify
+ if (sourceModel.isPresent()) // we should verify
verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
return expression.getBody().getRoot();
}