aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-10-29 15:31:14 +0100
committerLester Solbakken <lesters@oath.com>2020-10-29 15:31:14 +0100
commit2425b3562e5a84e6caf44228712cfade4c8583a2 (patch)
tree6ea1dff3bd3ba0ea4accb91d5e30447eddaf030a
parentfbd8ca020a6d97882c554585d911889d1b9f69ea (diff)
Store generated model info for ZK
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java178
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java30
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java5
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java135
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java15
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java389
-rw-r--r--config-model/src/test/integration/onnx-model/schemas/test.sd (renamed from config-model/src/test/integration/onnx-model/searchdefinitions/test.sd)0
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java52
9 files changed, 500 insertions, 306 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index b153ff62e7d..ab42e4d821a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -95,7 +95,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
// are there other cases we would like to resolve globally?
}
- @Override
+ @Override
public TensorType getType(Reference reference) {
// computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 5e8b8579ee6..64338e24a8d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -3,21 +3,16 @@ package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
+import com.yahoo.vespa.model.ml.OnnxModelInfo;
import com.yahoo.vespa.model.utils.FileSender;
-import onnx.Onnx;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -32,14 +27,10 @@ public class OnnxModel {
private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private String defaultOutput = null;
+ private OnnxModelInfo modelInfo = null;
private Map<String, String> inputMap = new HashMap<>();
private Map<String, String> outputMap = new HashMap<>();
- private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
- private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
- private Map<String, TensorType> vespaTypes = new HashMap<>();
-
public OnnxModel(String name) {
this.name = name;
}
@@ -64,11 +55,6 @@ public class OnnxModel {
return pathType;
}
- public void setDefaultOutput(String onnxName) {
- Objects.requireNonNull(onnxName, "Name cannot be null");
- this.defaultOutput = onnxName;
- }
-
public void addInputNameMapping(String onnxName, String vespaName) {
addInputNameMapping(onnxName, vespaName, true);
}
@@ -93,16 +79,9 @@ public class OnnxModel {
}
}
- public void addInputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- inputTypes.put(onnxName, type);
- }
-
- public void addOutputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- outputTypes.put(onnxName, type);
+ public void setModelInfo(OnnxModelInfo modelInfo) {
+ Objects.requireNonNull(modelInfo, "Onnx model info cannot be null");
+ this.modelInfo = modelInfo;
}
/** Initiate sending of this constant to some services over file distribution */
@@ -123,7 +102,11 @@ public class OnnxModel {
public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
public String getDefaultOutput() {
- return defaultOutput;
+ return modelInfo != null ? modelInfo.getDefaultOutput() : "";
+ }
+
+ TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
+ return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty;
}
public void validate() {
@@ -140,145 +123,4 @@ public class OnnxModel {
return b.toString();
}
- /**
- * Return the tensor type for an ONNX model for the given context.
- * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
- * type depends on the input types for the given context (rank profile).
- */
- public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
- Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
- if (onnxOutputType == null) {
- throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
- }
- if (allDimensionSizesAreKnown(onnxOutputType)) {
- return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
- }
- return getTensorTypeWithUnknownDimensions(onnxOutputType, context);
- }
-
- private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) {
- return type.getTensorType().getShape().getDimList().stream().noneMatch(d ->
- (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1);
- }
-
- private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
- long unboundSize = 0;
- Map<String, Long> symbolicSizes = new HashMap<>();
-
- for (String onnxInputName : inputTypes.keySet()) {
- Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
- if (allDimensionSizesAreKnown(onnxType)) {
- continue;
- }
-
- Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
- if (vespaType.isEmpty()) {
- return TensorType.empty;
- }
-
- var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
- var vespaDimensions = vespaType.get().dimensions();
- if (vespaDimensions.size() != onnxDimensions.size()) {
- return TensorType.empty;
- }
-
- for (int i = 0; i < vespaDimensions.size(); ++i) {
- if (vespaDimensions.get(i).size().isEmpty()) {
- continue;
- }
- Long size = vespaDimensions.get(i).size().get();
-
- // Handle dimensions with size -1 - typically batch dimensions
- if (onnxDimensions.get(i).getDimValue() == -1) {
- if (unboundSize != 0 && unboundSize != size) {
- throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " +
- "for type '" + onnxOutputType + "' in ONNX model '" + name + "'");
- }
- unboundSize = size;
-
- // Handle dimensions with symbolic names
- } else if (onnxDimensions.get(i).hasDimParam()) {
- String symbolicName = onnxDimensions.get(i).getDimParam();
- if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
- throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" +
- symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
- }
- symbolicSizes.put(symbolicName, size);
- }
- }
- }
- return typeFrom(onnxOutputType, symbolicSizes, unboundSize);
- }
-
- private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
- String source = inputMap.get(onnxInputName);
- if (source != null) {
- // Source is either a simple reference (query/attribute/constant)...
- Optional<Reference> reference = Reference.simple(source);
- if (reference.isPresent()) {
- return Optional.of(context.getType(reference.get()));
- }
- // ... or a function
- ExpressionFunction func = context.getFunction(source);
- if (func != null) {
- return Optional.of(func.getBody().type(context));
- }
- }
- return Optional.empty(); // if this context does not contain this input
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type) {
- return typeFrom(type, null, 0);
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) {
- String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- long onnxDimensionSize = onnxDimension.getDimValue();
- if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
- onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
- }
- if (onnxDimensionSize == 0 && symbolicSizes != null) {
- // This is for the case where all symbolic dimensions have
- // different names, but can be resolved to a single dimension size.
- Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
- if (unknownSizes.size() == 1) {
- onnxDimensionSize = unknownSizes.iterator().next();
- }
- }
- if (onnxDimensionSize < 0) {
- onnxDimensionSize = unboundSize;
- }
- if (onnxDimensionSize <= 0) {
- throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
- "ONNX type: " + type + " to Vespa tensor type.");
- }
- builder.indexed(dimensionName, onnxDimensionSize);
- }
- return builder.build();
- }
-
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT8: return TensorType.Value.FLOAT;
- case INT16: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.FLOAT;
- case UINT8: return TensorType.Value.FLOAT;
- case UINT16: return TensorType.Value.FLOAT;
- case UINT32: return TensorType.Value.FLOAT;
- case UINT64: return TensorType.Value.FLOAT;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
- }
- }
-
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 96c043bdb34..9b129eb66ce 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -831,18 +831,44 @@ public class RankProfile implements Cloneable {
String modelName = entry.getKey();
OnnxModel model = entry.getValue();
Arguments args = new Arguments(new ReferenceNode(modelName));
+ Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context);
- TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
+ TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes);
context.setType(new Reference("onnxModel", args, null), defaultOutputType);
for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
- TensorType type = model.getTensorType(mapping.getKey(), context);
+ TensorType type = model.getTensorType(mapping.getKey(), inputTypes);
context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
}
}
return context;
}
+ private Map<String, TensorType> resolveOnnxInputTypes(OnnxModel model, MapEvaluationTypeContext context) {
+ Map<String, TensorType> inputTypes = new HashMap<>();
+ for (String onnxInputName : model.getInputMap().keySet()) {
+ resolveOnnxInputType(onnxInputName, model, context).ifPresent(type -> inputTypes.put(onnxInputName, type));
+ }
+ return inputTypes;
+ }
+
+ private Optional<TensorType> resolveOnnxInputType(String onnxInputName, OnnxModel model, MapEvaluationTypeContext context) {
+ String source = model.getInputMap().get(onnxInputName);
+ if (source != null) {
+ // Source is either a simple reference (query/attribute/constant)...
+ Optional<Reference> reference = Reference.simple(source);
+ if (reference.isPresent()) {
+ return Optional.of(context.getType(reference.get()));
+ }
+ // ... or a function
+ ExpressionFunction func = context.getFunction(source);
+ if (func != null) {
+ return Optional.of(func.getBody().type(context));
+ }
+ }
+ return Optional.empty(); // if this context does not contain this input
+ }
+
private void addAttributeFeatureTypes(ImmutableSDField field, Map<Reference, TensorType> featureTypes) {
Attribute attribute = field.getAttribute();
field.getAttributes().forEach((k, a) -> {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
index afba88c135d..70ad3b255e3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -14,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.vespa.model.container.search.QueryProfiles;
+import com.yahoo.vespa.model.ml.OnnxModelInfo;
import java.util.Map;
@@ -77,9 +78,9 @@ public class OnnxModelConfigGenerator extends Processor {
String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
// Only add the configuration if the model can actually be found.
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
+ if ( ! OnnxModelInfo.modelExists(path, search.applicationPackage())) {
path = ApplicationPackage.MODELS_DIR.append(path).toString();
- if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) {
+ if ( ! OnnxModelInfo.modelExists(path, search.applicationPackage())) {
return;
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
index bead2e7e7c9..8e92b1980ac 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -2,35 +2,18 @@
package com.yahoo.searchdefinition.processing;
-import com.yahoo.cloud.config.ConfigserverConfig;
-import com.yahoo.component.Version;
-import com.yahoo.config.FileReference;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
-import com.yahoo.config.application.api.FileRegistry;
-import com.yahoo.path.Path;
import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
-import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
-import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.model.container.search.QueryProfiles;
-import onnx.Onnx;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.file.Paths;
-import java.util.Map;
-import java.util.Optional;
+import com.yahoo.vespa.model.ml.OnnxModelInfo;
/**
- * Processes every "onnx-model" element in the schema. Parses the model file,
- * adds missing input and output mappings (assigning default names), and
- * adds tensor types to all model inputs and outputs.
+ * Processes every "onnx-model" element in the schema. Associates model type
+ * information by retrieving from either the ONNX model file directly or from
+ * preprocessed information in ZK. Adds missing input and output mappings
+ * (assigning default names).
*
* Must be processed before RankingExpressingTypeResolver.
*
@@ -46,109 +29,19 @@ public class OnnxModelTypeResolver extends Processor {
public void process(boolean validate, boolean documentsOnly) {
if (documentsOnly) return;
- for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) {
- OnnxModel modelConfig = entry.getValue();
- try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) {
- Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
-
- // Model inputs - if not defined, assumes a function is provided with a valid name
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
- String onnxInputName = valueInfo.getName();
- String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName);
- modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false);
- modelConfig.addInputType(onnxInputName, valueInfo.getType());
- }
-
- // Model outputs
- for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
- String onnxOutputName = valueInfo.getName();
- String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName);
- modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false);
- modelConfig.addOutputType(onnxOutputName, valueInfo.getType());
- }
-
- // Set the first output as default
- if ( ! model.getGraph().getOutputList().isEmpty()) {
- modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName());
- }
+ for (OnnxModel onnxModel : search.onnxModels().asMap().values()) {
+ OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), search.applicationPackage());
- } catch (IOException e) {
- throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ // Add any missing input and output fields that were not specified in the onnx-model configuration
+ for (String onnxName : onnxModelInfo.getInputs()) {
+ onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
}
- }
- }
-
- static boolean modelFileExists(String path, ApplicationPackage app) {
- Path pathInApplicationPackage = Path.fromString(path);
- if (getFile(pathInApplicationPackage, app).exists()) {
- return true;
- }
- if (getFileReference(pathInApplicationPackage, app).isPresent()) {
- return true;
- }
- return false;
- }
-
- private InputStream openModelFile(Path path) throws FileNotFoundException {
- ApplicationFile file;
- Optional<FileReference> reference;
- Path modelsPath = ApplicationPackage.MODELS_DIR.append(path);
-
- if ((file = getFile(path)).exists()) {
- return file.createInputStream();
- }
- if ((file = getFile(modelsPath)).exists()) {
- return file.createInputStream();
- }
- if ((reference = getFileReference(path)).isPresent()) {
- return openFromFileRepository(path, reference.get());
- }
- if ((reference = getFileReference(modelsPath)).isPresent()) {
- return openFromFileRepository(modelsPath, reference.get());
- }
-
- throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " +
- "in application package or file repository.");
- }
-
- private ApplicationFile getFile(Path path) {
- return getFile(path, search.applicationPackage());
- }
-
- private static ApplicationFile getFile(Path path, ApplicationPackage app) {
- return app.getFile(path);
- }
-
- private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException {
- return new FileInputStream(new File(getFileRepositoryPath(path, reference.value())));
- }
-
- public static String getFileRepositoryPath(Path path, String fileReference) {
- ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
- String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
- return Paths.get(fileRefDir, fileReference, path.getName()).toString();
- }
-
- private Optional<FileReference> getFileReference(Path path) {
- return getFileReference(path, search.applicationPackage());
- }
-
- private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) {
- Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app);
- if (fileRegistry.isPresent()) {
- for (FileRegistry.Entry file : fileRegistry.get().export()) {
- if (file.relativePath.equals(path.toString())) {
- return Optional.of(file.reference);
- }
+ for (String onnxName : onnxModelInfo.getOutputs()) {
+ onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
}
- }
- return Optional.empty();
- }
- private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) {
- if (app == null) return Optional.empty();
- Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo);
- return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get()));
+ onnxModel.setModelInfo(onnxModelInfo);
+ }
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index d5c5183b01f..00797876395 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -1,13 +1,14 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation;
+import com.yahoo.cloud.config.ConfigserverConfig;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.IOUtils;
import com.yahoo.log.InvalidLogFormatException;
-import java.util.logging.Level;
import com.yahoo.log.LogMessage;
+import com.yahoo.path.Path;
import com.yahoo.searchdefinition.OnnxModel;
-import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver;
+import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.yolean.Exceptions;
import com.yahoo.system.ProcessExecuter;
import com.yahoo.text.StringUtilities;
@@ -30,9 +31,11 @@ import com.yahoo.vespa.model.search.SearchCluster;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
+import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.logging.Logger;
+import java.util.logging.Level;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -152,7 +155,7 @@ public class RankSetupValidator extends Validator {
if (models.values().size() > 0) {
List<String> config = new ArrayList<>(models.values().size() * 2);
for (OnnxModel model : models.values()) {
- String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference());
+ String modelPath = getFileRepositoryPath(model.getFilePath(), model.getFileReference());
config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference()));
config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath));
}
@@ -160,6 +163,12 @@ public class RankSetupValidator extends Validator {
}
}
+ public static String getFileRepositoryPath(Path path, String fileReference) {
+ ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults
+ String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir());
+ return Paths.get(fileRefDir, fileReference, path.getName()).toString();
+ }
+
private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException {
IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
new file mode 100644
index 00000000000..7526a8a8595
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -0,0 +1,389 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import com.fasterxml.jackson.core.JsonEncoding;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.TensorType;
+import onnx.Onnx;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Model information (input and output types) for an ONNX model.
+ * This encapsulates the difference between reading ONNX model information
+ * - from a file application package, where we can read the ONNX model directly
+ * - from a ZK application package, where the file is unavailable and models are read from
+ * generated files stored in file distribution or ZooKeeper.
+ *
+ * @author lesters
+ */
+public class OnnxModelInfo {
+
+ private final String defaultOutput;
+ private final Map<String, OnnxTypeInfo> inputs;
+ private final Map<String, OnnxTypeInfo> outputs;
+ private final Map<String, TensorType> vespaTypes = new HashMap<>();
+
+ private OnnxModelInfo(Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ this.inputs = Collections.unmodifiableMap(inputs);
+ this.outputs = Collections.unmodifiableMap(outputs);
+ this.defaultOutput = defaultOutput;
+ }
+
+ public Set<String> getInputs() {
+ return inputs.keySet();
+ }
+
+ public Set<String> getOutputs() {
+ return outputs.keySet();
+ }
+
+ public String getDefaultOutput() {
+ return defaultOutput;
+ }
+
+ /**
+ * Return the tensor type for an ONNX model for the given context.
+ * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
+ * type depends on the input types for the given context (rank profile).
+ */
+ public TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
+ OnnxTypeInfo onnxTypeInfo = outputs.get(onnxName);
+ if (onnxTypeInfo == null) {
+ throw new IllegalArgumentException("Could not find type for output '" + onnxName + "'");
+ }
+ if (onnxTypeInfo.containsUnknownDimensionSizes()) {
+ Set<Long> unboundSizes = new HashSet<>();
+ Map<String, Long> symbolicSizes = new HashMap<>();
+ resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes);
+ return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes);
+ }
+ return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType());
+ }
+
+ private void resolveUnknownDimensionSizes(Map<String, TensorType> inputTypes,
+ Map<String, Long> symbolicSizes,
+ Set<Long> unboundSizes)
+ {
+ for (Map.Entry<String, OnnxTypeInfo> input : inputs.entrySet()) {
+ String onnxName = input.getKey();
+ OnnxTypeInfo onnxType = input.getValue();
+ TensorType vespaType = inputTypes.get(onnxName);
+ if (vespaType == null || vespaType.dimensions().size() != onnxType.dimensions().size()) {
+ continue;
+ }
+
+ for (int i = 0; i < vespaType.dimensions().size(); ++i) {
+ if (vespaType.dimensions().get(i).size().isEmpty()) {
+ continue;
+ }
+ Long size = vespaType.dimensions().get(i).size().get();
+
+ // Handle dimensions with size -1 - typically batch dimensions
+ if (onnxType.dimensions().get(i).getSize() == -1) {
+ unboundSizes.add(size);
+ if (unboundSizes.size() > 1) {
+ throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " +
+ "for type '" + onnxType + "'");
+ }
+
+ // Handle dimensions with symbolic names
+ } else if (onnxType.dimensions().get(i).hasSymbolicName()) {
+ String symbolicName = onnxType.dimensions().get(i).getSymbolicName();
+ if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
+ throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" +
+ symbolicName + "' for input '" + onnxName + "'");
+ }
+ symbolicSizes.put(symbolicName, size);
+ }
+ }
+ }
+ }
+
+ static public OnnxModelInfo load(String path, ApplicationPackage app) {
+ Path pathInApplicationPackage = Path.fromString(path);
+ if (app.getFile(pathInApplicationPackage).exists()) {
+ return loadFromFile(pathInApplicationPackage, app);
+ }
+ if (app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists()) {
+ return loadFromGeneratedInfo(pathInApplicationPackage, app);
+ }
+ throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file");
+ }
+
+ static public boolean modelExists(String path, ApplicationPackage app) {
+ Path pathInApplicationPackage = Path.fromString(path);
+ if (app.getFile(pathInApplicationPackage).exists()) {
+ return true;
+ }
+ if (app.getFile(generatedModelInfoPath(Path.fromString(path))).exists()) {
+ return true;
+ }
+ return false;
+ }
+
+ static private OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) {
+ try (InputStream inputStream = app.getFile(path).createInputStream()) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ String json = onnxModelToJson(model);
+ storeGeneratedInfo(json, path, app);
+ return jsonToModelInfo(json);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ }
+ }
+
+ static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) {
+ try {
+ String json = readGeneratedInfo(path, app);
+ return jsonToModelInfo(json);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Unable to parse ONNX model", e);
+ }
+ }
+
+ static private String readGeneratedInfo(Path path, ApplicationPackage app) throws IOException {
+ ApplicationFile file = app.getFile(generatedModelInfoPath(path));
+ return IOUtils.readAll(file.createReader());
+ }
+
+ static private void storeGeneratedInfo(String json, Path path, ApplicationPackage app) throws IOException {
+ IOUtils.writeFile(app.getFileReference(generatedModelInfoPath(path)), json, false);
+ }
+
+ static private Path generatedModelInfoPath(Path path) {
+ String fileName = asValidIdentifier(path.getRelative()) + ".modelinfo.json";
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
+ }
+
+ static private String onnxModelToJson(Onnx.ModelProto model) throws IOException {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
+ g.writeStartObject();
+
+ g.writeArrayFieldStart("inputs");
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
+ onnxTypeToJson(g, valueInfo);
+ }
+ g.writeEndArray();
+
+ g.writeArrayFieldStart("outputs");
+ for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
+ onnxTypeToJson(g, valueInfo);
+ }
+ g.writeEndArray();
+
+ g.writeEndObject();
+ g.close();
+ return out.toString();
+ }
+
+ static public OnnxModelInfo jsonToModelInfo(String json) throws IOException {
+ ObjectMapper m = new ObjectMapper();
+ JsonNode root = m.readTree(json);
+ Map<String, OnnxTypeInfo> inputs = new HashMap<>();
+ Map<String, OnnxTypeInfo> outputs = new HashMap<>();
+ String defaultOutput = "";
+
+ for (JsonNode input : root.get("inputs")) {
+ inputs.put(input.get("name").textValue(), jsonToTypeInfo(input));
+ }
+ for (JsonNode output : root.get("outputs")) {
+ outputs.put(output.get("name").textValue(), jsonToTypeInfo(output));
+ }
+ if (root.get("outputs").has(0)) {
+ defaultOutput = root.get("outputs").get(0).get("name").textValue();
+ }
+ return new OnnxModelInfo(inputs, outputs, defaultOutput);
+ }
+
+ static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
+ g.writeStartObject();
+ g.writeStringField("name", valueInfo.getName());
+ g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType()));
+ g.writeArrayFieldStart("dim");
+ for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) {
+ g.writeStartObject();
+ if (dim.hasDimParam()) {
+ g.writeStringField("type", "param");
+ g.writeStringField("size", dim.getDimParam());
+ } else {
+ g.writeStringField("type", "value");
+ g.writeNumberField("size", dim.getDimValue());
+ }
+ g.writeEndObject();
+ }
+ g.writeEndArray();
+ g.writeEndObject();
+ }
+
+ static private OnnxTypeInfo jsonToTypeInfo(JsonNode node) {
+ TensorType.Value valueType = stringToValueType(node.get("type").textValue());
+ OnnxTypeInfo type = new OnnxTypeInfo(valueType);
+ for (JsonNode dim : node.get("dim")) {
+ if (dim.get("type").textValue().equals("param")) {
+ type.addDimension(dim.get("size").textValue());
+ } else {
+ type.addDimension(dim.get("size").longValue());
+ }
+ }
+ return type;
+ }
+
+ private static String onnxValueTypeToString(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return "float";
+ case DOUBLE: return "double";
+ // Imperfect conversion, for now:
+ case BOOL: return "float";
+ case INT8: return "float";
+ case INT16: return "float";
+ case INT32: return "float";
+ case INT64: return "float";
+ case UINT8: return "float";
+ case UINT16: return "float";
+ case UINT32: return "float";
+ case UINT64: return "float";
+ default:
+ throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
+ private static TensorType.Value stringToValueType(String type) {
+ switch (type) {
+ case "float": return TensorType.Value.FLOAT;
+ case "double": return TensorType.Value.DOUBLE;
+ default:
+ throw new IllegalArgumentException("Unknown tensor value type: " + type);
+ }
+ }
+
+ public static String asValidIdentifier(String str) {
+ return str.replaceAll("[^\\w\\d\\$@_]", "_");
+ }
+
+
+ private static class OnnxTypeInfo {
+ private final TensorType.Value valueType;
+ private final List<OnnxDimensionInfo> dimensions = new ArrayList<>();
+
+ OnnxTypeInfo(TensorType.Value valueType) {
+ this.valueType = valueType;
+ }
+
+ void addDimension(long value) {
+ dimensions.add(new OnnxDimensionInfo(value));
+ }
+
+ void addDimension(String param) {
+ dimensions.add(new OnnxDimensionInfo(param));
+ }
+
+ boolean containsUnknownDimensionSizes() {
+ return dimensions.stream().anyMatch(OnnxDimensionInfo::unknownDimensionSize);
+ }
+
+ TensorType.Value valueType() {
+ return valueType;
+ }
+
+ List<OnnxDimensionInfo> dimensions() {
+ return dimensions;
+ }
+
+ TensorType toVespaTensorType() {
+ return toVespaTensorType(null, null);
+ }
+
+ TensorType toVespaTensorType(Map<String, Long> symbolicSizes, Set<Long> unboundSizes) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
+ TensorType.Builder builder = new TensorType.Builder(valueType);
+ for (int i = 0; i < dimensions.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ OnnxDimensionInfo onnxDimension = dimensions.get(i);
+ long onnxDimensionSize = onnxDimension.getSize();
+ if (onnxDimension.hasSymbolicName() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getSymbolicName())) {
+ onnxDimensionSize = symbolicSizes.get(onnxDimension.getSymbolicName());
+ }
+ if (onnxDimensionSize == 0 && symbolicSizes != null) {
+ // This is for the case where all symbolic dimensions have
+ // different names, but can be resolved to a single dimension size.
+ Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
+ if (unknownSizes.size() == 1) {
+ onnxDimensionSize = unknownSizes.iterator().next();
+ }
+ }
+ if (onnxDimensionSize < 0 && unboundSizes != null && unboundSizes.size() > 0) {
+ onnxDimensionSize = unboundSizes.iterator().next();
+ }
+ if (onnxDimensionSize <= 0) {
+ return TensorType.empty; // Unable to determine type - probably out of context
+ }
+ builder.indexed(dimensionName, onnxDimensionSize);
+ }
+ return builder.build();
+ }
+
+ @Override
+ public String toString() {
+ return "(" + valueType.id() + ")" +
+ "[" + dimensions.stream().map(OnnxDimensionInfo::toString).collect(Collectors.joining(",")) + "]";
+ }
+
+ }
+
+ private static class OnnxDimensionInfo {
+ private final long size;
+ private final String symbolicName;
+
+ OnnxDimensionInfo(long size) {
+ this.size = size;
+ this.symbolicName = null;
+ }
+
+ OnnxDimensionInfo(String symbolicName) {
+ this.size = 0;
+ this.symbolicName = symbolicName;
+ }
+
+ long getSize() {
+ return size;
+ }
+
+ String getSymbolicName() {
+ return symbolicName;
+ }
+
+ boolean hasSymbolicName() {
+ return symbolicName != null;
+ }
+
+ boolean unknownDimensionSize() {
+ return hasSymbolicName() || size <= 0;
+ }
+
+ @Override
+ public String toString() {
+ return hasSymbolicName() ? "\"" + symbolicName + "\"" : Long.toString(size);
+ }
+ }
+
+}
diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd
index a87222e77ee..a87222e77ee 100644
--- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd
+++ b/config-model/src/test/integration/onnx-model/schemas/test.sd
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index 4eb8681c374..f8a379b4027 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -1,27 +1,61 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.search.DocumentDatabase;
import com.yahoo.vespa.model.search.IndexedSearchCluster;
-import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg;
+import org.junit.After;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class RankingExpressionWithOnnxModelTestCase {
+ private final Path applicationDir = Path.fromString("src/test/integration/onnx-model/");
+
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
@Test
- public void testOnnxModelFeature() {
- VespaModel model = new VespaModelCreatorWithFilePkg("src/test/integration/onnx-model").create();
- DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0);
- assertTransformedFeature(db);
- assertGeneratedConfig(db);
+ public void testOnnxModelFeature() throws Exception {
+ VespaModel model = loadModel(applicationDir);
+ assertTransformedFeature(model);
+ assertGeneratedConfig(model);
+
+ Path storedApplicationDir = applicationDir.append("copy");
+ try {
+ storedApplicationDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedApplicationDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append("schemas").toFile(), storedApplicationDir.append("schemas").toFile());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+
+ VespaModel storedModel = loadModel(storedApplicationDir);
+ assertTransformedFeature(storedModel);
+ assertGeneratedConfig(storedModel);
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDir.toFile());
+ }
}
- private void assertGeneratedConfig(DocumentDatabase db) {
+ private VespaModel loadModel(Path path) throws Exception {
+ FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile());
+ DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).build();
+ return new VespaModel(state);
+ }
+
+ private void assertGeneratedConfig(VespaModel model) {
+ DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0);
OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
((OnnxModelsConfig.Producer) db).getConfig(builder);
OnnxModelsConfig config = new OnnxModelsConfig(builder);
@@ -72,10 +106,10 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals(1, config.model(4).output().size());
assertEquals("rankingExpression(my_function)", config.model(4).input(0).source());
-
}
- private void assertTransformedFeature(DocumentDatabase db) {
+ private void assertTransformedFeature(VespaModel model) {
+ DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0);
RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder();
((RankProfilesConfig.Producer) db).getConfig(builder);
RankProfilesConfig config = new RankProfilesConfig(builder);