summary | refs | log | tree | commit | diff | stats
path: root/config-model/src/main/java/com/yahoo/searchdefinition
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java | 27
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 228
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | 19
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java | 4
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java | 8
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java | 2
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java | 141
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java | 97
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java | 154
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java | 2
10 files changed, 588 insertions, 94 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 4011ce43841..b153ff62e7d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
}
+ // A reference to an ONNX model?
+ Optional<TensorType> onnxFeatureType = onnxFeatureType(reference);
+ if (onnxFeatureType.isPresent()) {
+ return onnxFeatureType.get();
+ }
+
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(function);
}
+ private Optional<TensorType> onnxFeatureType(Reference reference) {
+ if ( ! reference.name().equals("onnxModel"))
+ return Optional.empty();
+
+ if ( ! featureTypes.containsKey(reference)) {
+ String configOrFileName = reference.arguments().expressions().get(0).toString();
+
+ // Look up standardized format as added in RankProfile
+ String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
+ String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
+
+ reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
+ if ( ! featureTypes.containsKey(reference)) {
+ throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
+ }
+ }
+
+ return Optional.of(featureTypes.get(reference));
+ }
+
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index c2fb2107604..5e8b8579ee6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -2,14 +2,22 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
+import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
+import onnx.Onnx;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.List;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -21,16 +29,16 @@ public class OnnxModel {
public enum PathType {FILE, URI};
private final String name;
+ private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private List<OnnxNameMapping> inputMap = new ArrayList<>();
- private List<OnnxNameMapping> outputMap = new ArrayList<>();
-
- public PathType getPathType() {
- return pathType;
- }
+ private String defaultOutput = null;
+ private Map<String, String> inputMap = new HashMap<>();
+ private Map<String, String> outputMap = new HashMap<>();
- private PathType pathType = PathType.FILE;
+ private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
+ private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
+ private Map<String, TensorType> vespaTypes = new HashMap<>();
public OnnxModel(String name) {
this.name = name;
@@ -49,21 +57,52 @@ public class OnnxModel {
}
public void setUri(String uri) {
- Objects.requireNonNull(uri, "uri cannot be null");
- this.path = uri;
- this.pathType = PathType.URI;
+ throw new IllegalArgumentException("URI for ONNX models are not currently supported");
+ }
+
+ public PathType getPathType() {
+ return pathType;
+ }
+
+ public void setDefaultOutput(String onnxName) {
+ Objects.requireNonNull(onnxName, "Name cannot be null");
+ this.defaultOutput = onnxName;
}
public void addInputNameMapping(String onnxName, String vespaName) {
+ addInputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ if (overwrite || ! inputMap.containsKey(onnxName)) {
+ inputMap.put(onnxName, vespaName);
+ }
}
public void addOutputNameMapping(String onnxName, String vespaName) {
+ addOutputNameMapping(onnxName, vespaName, true);
+ }
+
+ public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ if (overwrite || ! outputMap.containsKey(onnxName)) {
+ outputMap.put(onnxName, vespaName);
+ }
+ }
+
+ public void addInputType(String onnxName, Onnx.TypeProto type) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(type, "Tensor type cannot be null");
+ inputTypes.put(onnxName, type);
+ }
+
+ public void addOutputType(String onnxName, Onnx.TypeProto type) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(type, "Tensor type cannot be null");
+ outputTypes.put(onnxName, type);
}
/** Initiate sending of this constant to some services over file distribution */
@@ -76,11 +115,16 @@ public class OnnxModel {
public String getName() { return name; }
public String getFileName() { return path; }
+ public Path getFilePath() { return Path.fromString(path); }
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
- public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
- public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
+ public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
+ public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
+
+ public String getDefaultOutput() {
+ return defaultOutput;
+ }
public void validate() {
if (path == null || path.isEmpty())
@@ -90,23 +134,151 @@ public class OnnxModel {
public String toString() {
StringBuilder b = new StringBuilder();
b.append("onnx-model '").append(name)
- .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
- .append("' with ref '").append(fileReference)
- .append("'");
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
return b.toString();
}
- public static class OnnxNameMapping {
- private String onnxName;
- private String vespaName;
+ /**
+ * Return the tensor type for an ONNX model for the given context.
+ * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
+ * type depends on the input types for the given context (rank profile).
+ */
+ public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
+ Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
+ if (onnxOutputType == null) {
+ throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
+ }
+ if (allDimensionSizesAreKnown(onnxOutputType)) {
+ return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
+ }
+ return getTensorTypeWithUnknownDimensions(onnxOutputType, context);
+ }
+
+ private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) {
+ return type.getTensorType().getShape().getDimList().stream().noneMatch(d ->
+ (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1);
+ }
+
+ private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
+ long unboundSize = 0;
+ Map<String, Long> symbolicSizes = new HashMap<>();
+
+ for (String onnxInputName : inputTypes.keySet()) {
+ Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
+ if (allDimensionSizesAreKnown(onnxType)) {
+ continue;
+ }
+
+ Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
+ if (vespaType.isEmpty()) {
+ return TensorType.empty;
+ }
+
+ var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
+ var vespaDimensions = vespaType.get().dimensions();
+ if (vespaDimensions.size() != onnxDimensions.size()) {
+ return TensorType.empty;
+ }
+
+ for (int i = 0; i < vespaDimensions.size(); ++i) {
+ if (vespaDimensions.get(i).size().isEmpty()) {
+ continue;
+ }
+ Long size = vespaDimensions.get(i).size().get();
+
+ // Handle dimensions with size -1 - typically batch dimensions
+ if (onnxDimensions.get(i).getDimValue() == -1) {
+ if (unboundSize != 0 && unboundSize != size) {
+ throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " +
+ "for type '" + onnxOutputType + "' in ONNX model '" + name + "'");
+ }
+ unboundSize = size;
+
+ // Handle dimensions with symbolic names
+ } else if (onnxDimensions.get(i).hasDimParam()) {
+ String symbolicName = onnxDimensions.get(i).getDimParam();
+ if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
+ throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" +
+ symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
+ }
+ symbolicSizes.put(symbolicName, size);
+ }
+ }
+ }
+ return typeFrom(onnxOutputType, symbolicSizes, unboundSize);
+ }
+
+ private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
+ String source = inputMap.get(onnxInputName);
+ if (source != null) {
+ // Source is either a simple reference (query/attribute/constant)...
+ Optional<Reference> reference = Reference.simple(source);
+ if (reference.isPresent()) {
+ return Optional.of(context.getType(reference.get()));
+ }
+ // ... or a function
+ ExpressionFunction func = context.getFunction(source);
+ if (func != null) {
+ return Optional.of(func.getBody().type(context));
+ }
+ }
+ return Optional.empty(); // if this context does not contain this input
+ }
+
+ private static TensorType typeFrom(Onnx.TypeProto type) {
+ return typeFrom(type, null, 0);
+ }
+
+ private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ long onnxDimensionSize = onnxDimension.getDimValue();
+ if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
+ onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
+ }
+ if (onnxDimensionSize == 0 && symbolicSizes != null) {
+ // This is for the case where all symbolic dimensions have
+ // different names, but can be resolved to a single dimension size.
+ Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
+ if (unknownSizes.size() == 1) {
+ onnxDimensionSize = unknownSizes.iterator().next();
+ }
+ }
+ if (onnxDimensionSize < 0) {
+ onnxDimensionSize = unboundSize;
+ }
+ if (onnxDimensionSize <= 0) {
+ throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
+ "ONNX type: " + type + " to Vespa tensor type.");
+ }
+ builder.indexed(dimensionName, onnxDimensionSize);
+ }
+ return builder.build();
+ }
- private OnnxNameMapping(String onnxName, String vespaName) {
- this.onnxName = onnxName;
- this.vespaName = vespaName;
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.FLOAT;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
}
- public String getOnnxName() { return onnxName; }
- public String getVespaName() { return vespaName; }
- public void setVespaName(String vespaName) { this.vespaName = vespaName; }
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index d309f48d6df..96c043bdb34 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
@@ -158,6 +159,10 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
+ private Map<String, OnnxModel> onnxModels() {
+ return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
+ }
+
private Stream<ImmutableSDField> allFields() {
if (search == null) return Stream.empty();
if (allFieldsList == null) {
@@ -821,6 +826,20 @@ public class RankProfile implements Cloneable {
}
}
+ // Add output types for ONNX models
+ for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
+ String modelName = entry.getKey();
+ OnnxModel model = entry.getValue();
+ Arguments args = new Arguments(new ReferenceNode(modelName));
+
+ TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
+ context.setType(new Reference("onnxModel", args, null), defaultOutputType);
+
+ for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
+ TensorType type = model.getTensorType(mapping.getKey(), context);
+ context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
+ }
+ }
return context;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 84442fedc48..22a32c8fd65 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
modelBuilder.name(model.getName());
modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName())));
- model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName())));
+ model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
+ model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 87eaaf0387a..56a5d539906 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<String> functionNames = rankProfile.getFunctions().keySet();
if (functionNames.isEmpty()) return;
for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) {
- for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) {
- String source = mapping.getVespaName();
+ for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) {
+ String source = mapping.getValue();
if (functionNames.contains(source)) {
- mapping.setVespaName("rankingExpression(" + source + ")");
+ onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")");
}
}
}
@@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>();
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
- ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch());
+ ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile);
if (referenceNode != replacedNode) {
replacedSummaryFeatures.add(replacedNode);
i.remove();
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index ec517768ea9..d23a8376e7a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if ( ! feature.getName().equals("onnx")) return feature;
+ if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature;
try {
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index e1ad003e5bd..69cdae10e47 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -1,20 +1,36 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.yahoo.path.Path;
import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.vespa.model.ml.ConvertedModel;
+import com.yahoo.vespa.model.ml.FeatureArguments;
+import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
/**
- * Transforms instances of the onnxModel ranking feature and generates
- * ONNX configuration if necessary.
+ * Transforms ONNX model features of the forms:
+ *
+ * onnxModel(config_name)
+ * onnxModel(config_name).output
+ * onnxModel("path/to/model")
+ * onnxModel("path/to/model").output
+ * onnxModel("path/to/model", "path/to/output")
+ * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
+ *
+ * To the format expected by the backend:
+ *
+ * onnxModel(config_name).output
*
* @author lesters
*/
@@ -33,85 +49,92 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
if (context.rankProfile() == null) return feature;
if (context.rankProfile().getSearch() == null) return feature;
- return transformFeature(feature, context.rankProfile().getSearch());
+ return transformFeature(feature, context.rankProfile());
}
- public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
- if (!feature.getName().equals("onnxModel")) return feature;
+ public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) {
+ ImmutableSearch search = rankProfile.getSearch();
+ final String featureName = feature.getName();
+ if ( ! featureName.equals("onnxModel")) return feature;
Arguments arguments = feature.getArguments();
if (arguments.isEmpty())
- throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " +
- "onnx-model config or a ONNX file.");
- if (arguments.expressions().size() > 2)
- throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
-
- // Validation that the file actually exists is handled when the file is added to file distribution.
- // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator.
-
- String modelConfigName;
- OnnxModel onnxModel;
- if (arguments.expressions().get(0) instanceof ReferenceNode) {
- modelConfigName = arguments.expressions().get(0).toString();
- onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
- }
- } else if (arguments.expressions().get(0) instanceof ConstantNode) {
+ throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " +
+ "onnx-model config or an ONNX file.");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
+
+ // Check that the model configuration "onnx-model" exists. If not defined, it should have been added
+ // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find
+ // the actual ONNX file, which can happen if we are restarting or upgrading an application using an
+ // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store.
+
+ String modelConfigName = getModelConfigName(feature.reference());
+ OnnxModel onnxModel = search.onnxModels().get(modelConfigName);
+ if (onnxModel == null) {
String path = asString(arguments.expressions().get(0));
- modelConfigName = asValidIdentifier(path);
- onnxModel = search.onnxModels().get(modelConfigName);
- if (onnxModel == null) {
- onnxModel = new OnnxModel(modelConfigName, path);
- search.onnxModels().add(onnxModel);
- }
- } else {
- throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
+ ModelName modelName = new ModelName(null, Path.fromString(path), true);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile);
+ FeatureArguments featureArguments = new FeatureArguments(arguments);
+ return convertedModel.expression(featureArguments, null);
}
- String output = null;
- if (feature.getOutput() != null) {
- output = feature.getOutput();
- if ( ! hasOutputMapping(onnxModel, output)) {
- onnxModel.addOutputNameMapping(output, output);
+ String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
+ String output = getModelOutput(feature.reference(), defaultOutput);
+ if (! onnxModel.getOutputMap().containsValue(output)) {
+ throw new IllegalArgumentException(featureName + " argument '" + output +
+ "' output not found in model '" + onnxModel.getFileName() + "'");
+ }
+ return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
+ }
+
+ public static String getModelConfigName(Reference reference) {
+ if (reference.arguments().size() > 0) {
+ ExpressionNode expr = reference.arguments().expressions().get(0);
+ if (expr instanceof ReferenceNode) { // refers to onnx-model config
+ return expr.toString();
}
- } else if (arguments.expressions().size() > 1) {
- String name = asString(arguments.expressions().get(1));
- output = asValidIdentifier(name);
- if ( ! hasOutputMapping(onnxModel, output)) {
- onnxModel.addOutputNameMapping(name, output);
+ if (expr instanceof ConstantNode) { // refers to a file path
+ return asValidIdentifier(expr);
}
}
+ return null;
+ }
- // Replace feature with name of config
- ExpressionNode argument = new ReferenceNode(modelConfigName);
- return new ReferenceNode("onnxModel", List.of(argument), output);
-
+ public static String getModelOutput(Reference reference, String defaultOutput) {
+ if (reference.output() != null) {
+ return reference.output();
+ } else if (reference.arguments().expressions().size() == 2) {
+ return asValidIdentifier(reference.arguments().expressions().get(1));
+ } else if (reference.arguments().expressions().size() > 2) {
+ return asValidIdentifier(reference.arguments().expressions().get(2));
+ }
+ return defaultOutput;
}
- private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
- return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
+ public static String stripQuotes(String s) {
+ if (isNotQuoteSign(s.codePointAt(0))) return s;
+ if (isNotQuoteSign(s.codePointAt(s.length() - 1)))
+ throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
+ return s.substring(1, s.length()-1);
}
- private static String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ public static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
}
- private static String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
+ private static String asValidIdentifier(ExpressionNode node) {
+ return asValidIdentifier(asString(node));
}
- private static boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
+ private static boolean isNotQuoteSign(int c) {
+ return c != '\'' && c != '"';
}
- private static String asValidIdentifier(String str) {
- return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ public static String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
new file mode 100644
index 00000000000..afba88c135d
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -0,0 +1,97 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+
+import java.util.Map;
+
+/**
+ * Finds ONNX ranking features written directly in ranking expressions, such as
+ *
+ *     onnx("files/model.onnx", "path/to/output:1")
+ *
+ * and registers an "onnx-model" configuration for each of them, equivalent to
+ * this declaration in the schema:
+ *
+ *     onnx-model files_model_onnx {
+ *         file: "files/model.onnx"
+ *     }
+ *
+ * The inputs and outputs of the model are resolved by OnnxModelTypeResolver,
+ * which therefore must run after this processor.
+ *
+ * @author lesters
+ */
+public class OnnxModelConfigGenerator extends Processor {
+
+    public OnnxModelConfigGenerator(Search search,
+                                    DeployLogger deployLogger,
+                                    RankProfileRegistry rankProfileRegistry,
+                                    QueryProfiles queryProfiles) {
+        super(search, deployLogger, rankProfileRegistry, queryProfiles);
+    }
+
+    /**
+     * Scans the first phase and second phase expressions, the function bodies and the
+     * summary features of every rank profile of this search definition.
+     * Does nothing when only documents are processed.
+     */
+    @Override
+    public void process(boolean validate, boolean documentsOnly) {
+        if (documentsOnly) return;
+
+        for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) {
+            if (profile.getFirstPhaseRanking() != null)
+                scan(profile.getFirstPhaseRanking().getRoot());
+
+            if (profile.getSecondPhaseRanking() != null)
+                scan(profile.getSecondPhaseRanking().getRoot());
+
+            for (RankProfile.RankingExpressionFunction function : profile.getFunctions().values())
+                scan(function.function().getBody().getRoot());
+
+            for (ReferenceNode summaryFeature : profile.getSummaryFeatures())
+                generateConfigFor(summaryFeature);
+        }
+    }
+
+    /**
+     * Walks an expression tree. Reference nodes are handed to generateConfigFor
+     * and are not descended into; other composite nodes are traversed child by child.
+     */
+    private void scan(ExpressionNode node) {
+        if (node instanceof ReferenceNode) {
+            generateConfigFor((ReferenceNode) node);
+            return;
+        }
+        if (node instanceof CompositeNode) {
+            for (ExpressionNode child : ((CompositeNode) node).children())
+                scan(child);
+        }
+    }
+
+    /**
+     * Adds an onnx-model configuration for the given feature if it is an "onnx" or
+     * "onnxModel" feature whose first argument is a constant naming a model file
+     * that can be found, and no configuration with that name exists already.
+     */
+    private void generateConfigFor(ReferenceNode feature) {
+        String featureName = feature.getName();
+        if ( ! featureName.equals("onnxModel") && ! featureName.equals("onnx")) return;
+        if (feature.getArguments().size() < 1) return;
+
+        ExpressionNode firstArgument = feature.getArguments().expressions().get(0);
+        if ( ! (firstArgument instanceof ConstantNode)) return;
+
+        String path = OnnxModelTransformer.stripQuotes(((ConstantNode) firstArgument).sourceString());
+        // The configuration name is derived from the path as written in the expression,
+        // before any fallback to the models directory below.
+        String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
+
+        // Only add the configuration if the model can actually be found.
+        if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
+            path = ApplicationPackage.MODELS_DIR.append(path).toString();
+            if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) return;
+        }
+
+        if (search.onnxModels().get(modelConfigName) == null)
+            search.onnxModels().add(new OnnxModel(modelConfigName, path));
+    }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
new file mode 100644
index 00000000000..bead2e7e7c9
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -0,0 +1,154 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.component.Version;
+import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.DeployLogger;
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.OnnxModel;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
+import com.yahoo.vespa.defaults.Defaults;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
+import onnx.Onnx;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Paths;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Processes every "onnx-model" element in the schema. Parses the model file,
+ * adds missing input and output mappings (assigning default names), and
+ * adds tensor types to all model inputs and outputs.
+ *
+ * Must be processed before RankingExpressionTypeResolver.
+ *
+ * @author lesters
+ */
+public class OnnxModelTypeResolver extends Processor {
+
+    public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) {
+        super(search, deployLogger, rankProfileRegistry, queryProfiles);
+    }
+
+    /**
+     * Parses the file of each configured ONNX model and records name mappings and types
+     * for its graph inputs and outputs. Does nothing when only documents are processed.
+     *
+     * @throws IllegalArgumentException if a model file cannot be found, read or parsed
+     */
+    @Override
+    public void process(boolean validate, boolean documentsOnly) {
+        if (documentsOnly) return;
+
+        for (OnnxModel modelConfig : search.onnxModels().asMap().values()) {
+            try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
+                Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+
+                // Model inputs - if not defined, assumes a function is provided with a valid name
+                for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
+                    String onnxInputName = valueInfo.getName();
+                    String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
+                    modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
+                    modelConfig.addInputType(onnxInputName, valueInfo.getType());
+                }
+
+                // Model outputs
+                for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
+                    String onnxOutputName = valueInfo.getName();
+                    String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
+                    modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
+                    modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
+                }
+
+                // Set the first output as default
+                if ( ! model.getGraph().getOutputList().isEmpty()) {
+                    modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
+                }
+
+            } catch (IOException e) {
+                // Name the offending file so the failing model can be identified.
+                throw new IllegalArgumentException("Unable to parse ONNX model \"" + modelConfig.getFilePath() + "\"", e);
+            }
+        }
+    }
+
+    /**
+     * Returns whether the given path names a file in the application package
+     * or an entry in the latest file registry of the application package.
+     *
+     * @param path the path of the model file, relative to the application package root
+     * @param app the application package to look in; null yields false
+     */
+    static boolean modelFileExists(String path, ApplicationPackage app) {
+        // getLatestFileRegistry already treats a missing application package as
+        // "nothing found"; do the same here rather than failing on app.getFile.
+        if (app == null) return false;
+
+        Path pathInApplicationPackage = Path.fromString(path);
+        return getFile(pathInApplicationPackage, app).exists()
+               || getFileReference(pathInApplicationPackage, app).isPresent();
+    }
+
+    /**
+     * Opens the model file. The lookup order is: the path in the application package,
+     * the path under the models directory in the application package, then the same
+     * two paths in the file registry.
+     *
+     * @throws IllegalArgumentException if the file is found in none of these places
+     */
+    private InputStream openModelFile(Path path) throws FileNotFoundException {
+        ApplicationFile file;
+        Optional<FileReference> reference;
+        Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
+
+        if ((file = getFile(path)).exists()) {
+            return file.createInputStream();
+        }
+        if ((file = getFile(modelsPath)).exists()) {
+            return file.createInputStream();
+        }
+        if ((reference = getFileReference(path)).isPresent()) {
+            return openFromFileRepository(path, reference.get());
+        }
+        if ((reference = getFileReference(modelsPath)).isPresent()) {
+            return openFromFileRepository(modelsPath, reference.get());
+        }
+
+        throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
+                                           "in application package or file repository.");
+    }
+
+    private ApplicationFile getFile(Path path) {
+        return getFile(path, search.applicationPackage());
+    }
+
+    private static ApplicationFile getFile(Path path, ApplicationPackage app) {
+        return app.getFile(path);
+    }
+
+    private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
+        return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
+    }
+
+    /**
+     * Returns the location of a file in the file repository:
+     * the file references directory under Vespa home, then the file reference, then the file name.
+     * The file references directory is taken from a default config server config.
+     */
+    public static String getFileRepositoryPath(Path path, String fileReference) {
+        ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder());  // assume defaults
+        String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
+        return Paths.get(fileRefDir, fileReference, path.getName()).toString();
+    }
+
+    private Optional<FileReference> getFileReference(Path path) {
+        return getFileReference(path, search.applicationPackage());
+    }
+
+    /** Returns the reference of the entry in the latest file registry whose relative path equals the given path. */
+    private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
+        Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
+        if (fileRegistry.isPresent()) {
+            for (FileRegistry.Entry file : fileRegistry.get().export()) {
+                if (file.relativePath.equals(path.toString())) {
+                    return Optional.of(file.reference);
+                }
+            }
+        }
+        return Optional.empty();
+    }
+
+    /** Returns the file registry of the highest version, or empty if there is no application package or no registry. */
+    private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
+        if (app == null) return Optional.empty();
+        // Fetch the registries once and use the same map for both the version search and the lookup.
+        Map<Version, FileRegistry> registries = app.getFileRegistries();
+        return registries.keySet().stream().max(Version::compareTo).map(registries::get);
+    }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
index e8594c2a87f..1a3ef9e54b4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java
@@ -74,6 +74,8 @@ public class Processing {
ReferenceFieldsProcessor::new,
FastAccessValidator::new,
ReservedFunctionNames::new,
+ OnnxModelConfigGenerator::new,
+ OnnxModelTypeResolver::new,
RankingExpressionTypeResolver::new,
// These should be last:
IndexingValidation::new,