diff options
author | Lester Solbakken <lesters@oath.com> | 2020-10-29 15:31:14 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-10-29 15:31:14 +0100 |
commit | 2425b3562e5a84e6caf44228712cfade4c8583a2 (patch) | |
tree | 6ea1dff3bd3ba0ea4accb91d5e30447eddaf030a | |
parent | fbd8ca020a6d97882c554585d911889d1b9f69ea (diff) |
Store generated model info for ZK
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java | 2 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 178 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | 30 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java | 5 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java | 135 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java | 15 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java | 389 | ||||
-rw-r--r-- | config-model/src/test/integration/onnx-model/schemas/test.sd (renamed from config-model/src/test/integration/onnx-model/searchdefinitions/test.sd) | 0 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java | 52 |
9 files changed, 500 insertions, 306 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index b153ff62e7d..ab42e4d821a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -95,7 +95,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement // are there other cases we would like to resolve globally? } - @Override + @Override public TensorType getType(Reference reference) { // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 5e8b8579ee6..64338e24a8d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -3,21 +3,16 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.ml.OnnxModelInfo; import com.yahoo.vespa.model.utils.FileSender; -import onnx.Onnx; import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; import java.util.Objects; -import java.util.Optional; -import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -32,14 +27,10 @@ public class OnnxModel { private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private String defaultOutput = null; + private OnnxModelInfo modelInfo = null; private Map<String, String> inputMap = new HashMap<>(); private Map<String, String> outputMap = new HashMap<>(); - private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>(); - private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>(); - private Map<String, TensorType> vespaTypes = new HashMap<>(); - public OnnxModel(String name) { this.name = name; } @@ -64,11 +55,6 @@ public class OnnxModel { return pathType; } - public void setDefaultOutput(String onnxName) { - Objects.requireNonNull(onnxName, "Name cannot be null"); - this.defaultOutput = onnxName; - } - public void addInputNameMapping(String onnxName, String vespaName) { addInputNameMapping(onnxName, vespaName, true); } @@ -93,16 +79,9 @@ public class OnnxModel { } } - public void addInputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - inputTypes.put(onnxName, type); - } - - public void addOutputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - outputTypes.put(onnxName, type); + public void setModelInfo(OnnxModelInfo modelInfo) { + Objects.requireNonNull(modelInfo, "Onnx model info cannot be null"); + this.modelInfo = modelInfo; } /** Initiate sending of this constant to some services over file distribution */ @@ -123,7 +102,11 @@ public class OnnxModel { public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } public String getDefaultOutput() { - return defaultOutput; + return modelInfo != null ? modelInfo.getDefaultOutput() : ""; + } + + TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { + return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; } public void validate() { @@ -140,145 +123,4 @@ public class OnnxModel { return b.toString(); } - /** - * Return the tensor type for an ONNX model for the given context. - * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output - * type depends on the input types for the given context (rank profile). - */ - public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { - Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); - if (onnxOutputType == null) { - throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); - } - if (allDimensionSizesAreKnown(onnxOutputType)) { - return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); - } - return getTensorTypeWithUnknownDimensions(onnxOutputType, context); - } - - private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) { - return type.getTensorType().getShape().getDimList().stream().noneMatch(d -> - (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1); - } - - private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { - long unboundSize = 0; - Map<String, Long> symbolicSizes = new HashMap<>(); - - for (String onnxInputName : inputTypes.keySet()) { - Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); - if (allDimensionSizesAreKnown(onnxType)) { - continue; - } - - Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); - if (vespaType.isEmpty()) { - return TensorType.empty; - } - - var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); - var vespaDimensions = vespaType.get().dimensions(); - if (vespaDimensions.size() != onnxDimensions.size()) { - return TensorType.empty; - } - - for (int i = 0; i < vespaDimensions.size(); ++i) { - if (vespaDimensions.get(i).size().isEmpty()) { - continue; - } - Long size = vespaDimensions.get(i).size().get(); - - // Handle dimensions with size -1 - typically batch dimensions - if (onnxDimensions.get(i).getDimValue() == -1) { - if (unboundSize != 0 && unboundSize != size) { - throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " + - "for type '" + onnxOutputType + "' in ONNX model '" + name + "'"); - } - unboundSize = size; - - // Handle dimensions with symbolic names - } else if (onnxDimensions.get(i).hasDimParam()) { - String symbolicName = onnxDimensions.get(i).getDimParam(); - if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { - throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + - symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); - } - symbolicSizes.put(symbolicName, size); - } - } - } - return typeFrom(onnxOutputType, symbolicSizes, unboundSize); - } - - private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { - String source = inputMap.get(onnxInputName); - if (source != null) { - // Source is either a simple reference (query/attribute/constant)... - Optional<Reference> reference = Reference.simple(source); - if (reference.isPresent()) { - return Optional.of(context.getType(reference.get())); - } - // ... or a function - ExpressionFunction func = context.getFunction(source); - if (func != null) { - return Optional.of(func.getBody().type(context)); - } - } - return Optional.empty(); // if this context does not contain this input - } - - private static TensorType typeFrom(Onnx.TypeProto type) { - return typeFrom(type, null, 0); - } - - private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) { - String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... - Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - long onnxDimensionSize = onnxDimension.getDimValue(); - if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { - onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); - } - if (onnxDimensionSize == 0 && symbolicSizes != null) { - // This is for the case where all symbolic dimensions have - // different names, but can be resolved to a single dimension size. - Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); - if (unknownSizes.size() == 1) { - onnxDimensionSize = unknownSizes.iterator().next(); - } - } - if (onnxDimensionSize < 0) { - onnxDimensionSize = unboundSize; - } - if (onnxDimensionSize <= 0) { - throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + - "ONNX type: " + type + " to Vespa tensor type."); - } - builder.indexed(dimensionName, onnxDimensionSize); - } - return builder.build(); - } - - private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT8: return TensorType.Value.FLOAT; - case INT16: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.FLOAT; - case UINT8: return TensorType.Value.FLOAT; - case UINT16: return TensorType.Value.FLOAT; - case UINT32: return TensorType.Value.FLOAT; - case UINT64: return TensorType.Value.FLOAT; - default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); - } - } - } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 96c043bdb34..9b129eb66ce 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -831,18 +831,44 @@ public class RankProfile implements Cloneable { String modelName = entry.getKey(); OnnxModel model = entry.getValue(); Arguments args = new Arguments(new ReferenceNode(modelName)); + Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context); - TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context); + TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes); context.setType(new Reference("onnxModel", args, null), defaultOutputType); for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { - TensorType type = model.getTensorType(mapping.getKey(), context); + TensorType type = model.getTensorType(mapping.getKey(), inputTypes); context.setType(new Reference("onnxModel", args, mapping.getValue()), type); } } return context; } + private Map<String, TensorType> resolveOnnxInputTypes(OnnxModel model, MapEvaluationTypeContext context) { + Map<String, TensorType> inputTypes = new HashMap<>(); + for (String onnxInputName : model.getInputMap().keySet()) { + resolveOnnxInputType(onnxInputName, model, context).ifPresent(type -> inputTypes.put(onnxInputName, type)); + } + return inputTypes; + } + + private Optional<TensorType> resolveOnnxInputType(String onnxInputName, OnnxModel model, MapEvaluationTypeContext context) { + String source = model.getInputMap().get(onnxInputName); + if (source != null) { + // Source is either a simple reference (query/attribute/constant)... + Optional<Reference> reference = Reference.simple(source); + if (reference.isPresent()) { + return Optional.of(context.getType(reference.get())); + } + // ... or a function + ExpressionFunction func = context.getFunction(source); + if (func != null) { + return Optional.of(func.getBody().type(context)); + } + } + return Optional.empty(); // if this context does not contain this input + } + private void addAttributeFeatureTypes(ImmutableSDField field, Map<Reference, TensorType> featureTypes) { Attribute attribute = field.getAttribute(); field.getAttributes().forEach((k, a) -> { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java index afba88c135d..70ad3b255e3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -14,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.vespa.model.container.search.QueryProfiles; +import com.yahoo.vespa.model.ml.OnnxModelInfo; import java.util.Map; @@ -77,9 +78,9 @@ public class OnnxModelConfigGenerator extends Processor { String modelConfigName = OnnxModelTransformer.asValidIdentifier(path); // Only add the configuration if the model can actually be found. - if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + if ( ! OnnxModelInfo.modelExists(path, search.applicationPackage())) { path = ApplicationPackage.MODELS_DIR.append(path).toString(); - if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + if ( ! OnnxModelInfo.modelExists(path, search.applicationPackage())) { return; } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java index bead2e7e7c9..8e92b1980ac 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -2,35 +2,18 @@ package com.yahoo.searchdefinition.processing; -import com.yahoo.cloud.config.ConfigserverConfig; -import com.yahoo.component.Version; -import com.yahoo.config.FileReference; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; -import com.yahoo.config.application.api.FileRegistry; -import com.yahoo.path.Path; import com.yahoo.searchdefinition.OnnxModel; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; -import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.model.container.search.QueryProfiles; -import onnx.Onnx; - -import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Paths; -import java.util.Map; -import java.util.Optional; +import com.yahoo.vespa.model.ml.OnnxModelInfo; /** - * Processes every "onnx-model" element in the schema. Parses the model file, - * adds missing input and output mappings (assigning default names), and - * adds tensor types to all model inputs and outputs. + * Processes every "onnx-model" element in the schema. Associates model type + * information by retrieving from either the ONNX model file directly or from + * preprocessed information in ZK. Adds missing input and output mappings + * (assigning default names). * * Must be processed before RankingExpressingTypeResolver. * @@ -46,109 +29,19 @@ public class OnnxModelTypeResolver extends Processor { public void process(boolean validate, boolean documentsOnly) { if (documentsOnly) return; - for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) { - OnnxModel modelConfig = entry.getValue(); - try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) { - Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - - // Model inputs - if not defined, assumes a function is provided with a valid name - for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { - String onnxInputName = valueInfo.getName(); - String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName); - modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false); - modelConfig.addInputType(onnxInputName, valueInfo.getType()); - } - - // Model outputs - for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { - String onnxOutputName = valueInfo.getName(); - String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName); - modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false); - modelConfig.addOutputType(onnxOutputName, valueInfo.getType()); - } - - // Set the first output as default - if ( ! model.getGraph().getOutputList().isEmpty()) { - modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName()); - } + for (OnnxModel onnxModel : search.onnxModels().asMap().values()) { + OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), search.applicationPackage()); - } catch (IOException e) { - throw new IllegalArgumentException("Unable to parse ONNX model", e); + // Add any missing input and output fields that were not specified in the onnx-model configuration + for (String onnxName : onnxModelInfo.getInputs()) { + onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); } - } - } - - static boolean modelFileExists(String path, ApplicationPackage app) { - Path pathInApplicationPackage = Path.fromString(path); - if (getFile(pathInApplicationPackage, app).exists()) { - return true; - } - if (getFileReference(pathInApplicationPackage, app).isPresent()) { - return true; - } - return false; - } - - private InputStream openModelFile(Path path) throws FileNotFoundException { - ApplicationFile file; - Optional<FileReference> reference; - Path modelsPath = ApplicationPackage.MODELS_DIR.append(path); - - if ((file = getFile(path)).exists()) { - return file.createInputStream(); - } - if ((file = getFile(modelsPath)).exists()) { - return file.createInputStream(); - } - if ((reference = getFileReference(path)).isPresent()) { - return openFromFileRepository(path, reference.get()); - } - if ((reference = getFileReference(modelsPath)).isPresent()) { - return openFromFileRepository(modelsPath, reference.get()); - } - - throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " + - "in application package or file repository."); - } - - private ApplicationFile getFile(Path path) { - return getFile(path, search.applicationPackage()); - } - - private static ApplicationFile getFile(Path path, ApplicationPackage app) { - return app.getFile(path); - } - - private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException { - return new FileInputStream(new File(getFileRepositoryPath(path, reference.value()))); - } - - public static String getFileRepositoryPath(Path path, String fileReference) { - ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults - String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); - return Paths.get(fileRefDir, fileReference, path.getName()).toString(); - } - - private Optional<FileReference> getFileReference(Path path) { - return getFileReference(path, search.applicationPackage()); - } - - private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) { - Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app); - if (fileRegistry.isPresent()) { - for (FileRegistry.Entry file : fileRegistry.get().export()) { - if (file.relativePath.equals(path.toString())) { - return Optional.of(file.reference); - } + for (String onnxName : onnxModelInfo.getOutputs()) { + onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); } - } - return Optional.empty(); - } - private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) { - if (app == null) return Optional.empty(); - Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo); - return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get())); + onnxModel.setModelInfo(onnxModelInfo); + } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index d5c5183b01f..00797876395 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -1,13 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; +import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; -import java.util.logging.Level; import com.yahoo.log.LogMessage; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver; +import com.yahoo.vespa.defaults.Defaults; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; @@ -30,9 +31,11 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.io.File; import java.io.IOException; import java.nio.file.Files; +import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; +import java.util.logging.Level; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -152,7 +155,7 @@ public class RankSetupValidator extends Validator { if (models.values().size() > 0) { List<String> config = new ArrayList<>(models.values().size() * 2); for (OnnxModel model : models.values()) { - String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference()); + String modelPath = getFileRepositoryPath(model.getFilePath(), model.getFileReference()); config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } @@ -160,6 +163,12 @@ public class RankSetupValidator extends Validator { } } + public static String getFileRepositoryPath(Path path, String fileReference) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); + return Paths.get(fileRefDir, fileReference, path.getName()).toString(); + } + private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException { IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java new file mode 100644 index 00000000000..7526a8a8595 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -0,0 +1,389 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.TensorType; +import onnx.Onnx; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Model information (input and output types) for an ONNX model. + * This encapsulates the difference between reading ONNX model information + * - from a file application package, where we can read the ONNX model directly + * - from a ZK application package, where the file is unavailable and models are read from + * generated files stored in file distribution or ZooKeeper. + * + * @author lesters + */ +public class OnnxModelInfo { + + private final String defaultOutput; + private final Map<String, OnnxTypeInfo> inputs; + private final Map<String, OnnxTypeInfo> outputs; + private final Map<String, TensorType> vespaTypes = new HashMap<>(); + + private OnnxModelInfo(Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + this.inputs = Collections.unmodifiableMap(inputs); + this.outputs = Collections.unmodifiableMap(outputs); + this.defaultOutput = defaultOutput; + } + + public Set<String> getInputs() { + return inputs.keySet(); + } + + public Set<String> getOutputs() { + return outputs.keySet(); + } + + public String getDefaultOutput() { + return defaultOutput; + } + + /** + * Return the tensor type for an ONNX model for the given context. + * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output + * type depends on the input types for the given context (rank profile). + */ + public TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { + OnnxTypeInfo onnxTypeInfo = outputs.get(onnxName); + if (onnxTypeInfo == null) { + throw new IllegalArgumentException("Could not find type for output '" + onnxName + "'"); + } + if (onnxTypeInfo.containsUnknownDimensionSizes()) { + Set<Long> unboundSizes = new HashSet<>(); + Map<String, Long> symbolicSizes = new HashMap<>(); + resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes); + return onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes); + } + return vespaTypes.computeIfAbsent(onnxName, v -> onnxTypeInfo.toVespaTensorType()); + } + + private void resolveUnknownDimensionSizes(Map<String, TensorType> inputTypes, + Map<String, Long> symbolicSizes, + Set<Long> unboundSizes) + { + for (Map.Entry<String, OnnxTypeInfo> input : inputs.entrySet()) { + String onnxName = input.getKey(); + OnnxTypeInfo onnxType = input.getValue(); + TensorType vespaType = inputTypes.get(onnxName); + if (vespaType == null || vespaType.dimensions().size() != onnxType.dimensions().size()) { + continue; + } + + for (int i = 0; i < vespaType.dimensions().size(); ++i) { + if (vespaType.dimensions().get(i).size().isEmpty()) { + continue; + } + Long size = vespaType.dimensions().get(i).size().get(); + + // Handle dimensions with size -1 - typically batch dimensions + if (onnxType.dimensions().get(i).getSize() == -1) { + unboundSizes.add(size); + if (unboundSizes.size() > 1) { + throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " + + "for type '" + onnxType + "'"); + } + + // Handle dimensions with symbolic names + } else if (onnxType.dimensions().get(i).hasSymbolicName()) { + String symbolicName = onnxType.dimensions().get(i).getSymbolicName(); + if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { + throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + + symbolicName + "' for input '" + onnxName + "'"); + } + symbolicSizes.put(symbolicName, size); + } + } + } + } + + static public OnnxModelInfo load(String path, ApplicationPackage app) { + Path pathInApplicationPackage = Path.fromString(path); + if (app.getFile(pathInApplicationPackage).exists()) { + return loadFromFile(pathInApplicationPackage, app); + } + if (app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists()) { + return loadFromGeneratedInfo(pathInApplicationPackage, app); + } + throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file"); + } + + static public boolean modelExists(String path, ApplicationPackage app) { + Path pathInApplicationPackage = Path.fromString(path); + if (app.getFile(pathInApplicationPackage).exists()) { + return true; + } + if (app.getFile(generatedModelInfoPath(Path.fromString(path))).exists()) { + return true; + } + return false; + } + + static private OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) { + try (InputStream inputStream = app.getFile(path).createInputStream()) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + String json = onnxModelToJson(model); + storeGeneratedInfo(json, path, app); + return jsonToModelInfo(json); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse ONNX model", e); + } + } + + static private OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) { + try { + String json = readGeneratedInfo(path, app); + return jsonToModelInfo(json); + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse ONNX model", e); + } + } + + static private String readGeneratedInfo(Path path, ApplicationPackage app) throws IOException { + ApplicationFile file = app.getFile(generatedModelInfoPath(path)); + return IOUtils.readAll(file.createReader()); + } + + static private void storeGeneratedInfo(String json, Path path, ApplicationPackage app) throws IOException { + IOUtils.writeFile(app.getFileReference(generatedModelInfoPath(path)), json, false); + } + + static private Path generatedModelInfoPath(Path path) { + String fileName = asValidIdentifier(path.getRelative()) + ".modelinfo.json"; + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); + } + + static private String onnxModelToJson(Onnx.ModelProto model) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); + g.writeStartObject(); + + g.writeArrayFieldStart("inputs"); + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { + onnxTypeToJson(g, valueInfo); + } + g.writeEndArray(); + + g.writeArrayFieldStart("outputs"); + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { + onnxTypeToJson(g, valueInfo); + } + g.writeEndArray(); + + g.writeEndObject(); + g.close(); + return out.toString(); + } + + static public OnnxModelInfo jsonToModelInfo(String json) throws IOException { + ObjectMapper m = new ObjectMapper(); + JsonNode root = m.readTree(json); + Map<String, OnnxTypeInfo> inputs = new HashMap<>(); + Map<String, OnnxTypeInfo> outputs = new HashMap<>(); + String defaultOutput = ""; + + for (JsonNode input : root.get("inputs")) { + inputs.put(input.get("name").textValue(), jsonToTypeInfo(input)); + } + for (JsonNode output : root.get("outputs")) { + outputs.put(output.get("name").textValue(), jsonToTypeInfo(output)); + } + if (root.get("outputs").has(0)) { + defaultOutput = root.get("outputs").get(0).get("name").textValue(); + } + return new OnnxModelInfo(inputs, outputs, defaultOutput); + } + + static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { + g.writeStartObject(); + g.writeStringField("name", valueInfo.getName()); + g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType())); + g.writeArrayFieldStart("dim"); + for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) { + g.writeStartObject(); + if (dim.hasDimParam()) { + g.writeStringField("type", "param"); + g.writeStringField("size", dim.getDimParam()); + } else { + g.writeStringField("type", "value"); + g.writeNumberField("size", dim.getDimValue()); + } + g.writeEndObject(); + } + g.writeEndArray(); + g.writeEndObject(); + } + + static private OnnxTypeInfo jsonToTypeInfo(JsonNode node) { + TensorType.Value valueType = stringToValueType(node.get("type").textValue()); + OnnxTypeInfo type = new OnnxTypeInfo(valueType); + for (JsonNode dim : node.get("dim")) { + if (dim.get("type").textValue().equals("param")) { + type.addDimension(dim.get("size").textValue()); + } else { + type.addDimension(dim.get("size").longValue()); + } + } + return type; + } + + private static String onnxValueTypeToString(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return "float"; + case DOUBLE: return "double"; + // Imperfect conversion, for now: + case BOOL: return "float"; + case INT8: return "float"; + case INT16: return "float"; + case INT32: return "float"; + case INT64: return "float"; + case UINT8: return "float"; + case UINT16: return "float"; + case UINT32: return "float"; + case UINT64: return "float"; + default: + throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + + private static TensorType.Value stringToValueType(String type) { + switch (type) { + case "float": return TensorType.Value.FLOAT; + case "double": return TensorType.Value.DOUBLE; + default: + throw new IllegalArgumentException("Unknown tensor value type: " + type); + } + } + + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); + } + + + private static class OnnxTypeInfo { + private final TensorType.Value valueType; + private final List<OnnxDimensionInfo> dimensions = new ArrayList<>(); + + OnnxTypeInfo(TensorType.Value valueType) { + this.valueType = valueType; + } + + void addDimension(long value) { + dimensions.add(new OnnxDimensionInfo(value)); + } + + void addDimension(String param) { + dimensions.add(new OnnxDimensionInfo(param)); + } + + boolean containsUnknownDimensionSizes() { + return dimensions.stream().anyMatch(OnnxDimensionInfo::unknownDimensionSize); + } + + TensorType.Value valueType() { + return valueType; + } + + List<OnnxDimensionInfo> dimensions() { + return dimensions; + } + + TensorType toVespaTensorType() { + return toVespaTensorType(null, null); + } + + TensorType toVespaTensorType(Map<String, Long> symbolicSizes, Set<Long> unboundSizes) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... + TensorType.Builder builder = new TensorType.Builder(valueType); + for (int i = 0; i < dimensions.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + OnnxDimensionInfo onnxDimension = dimensions.get(i); + long onnxDimensionSize = onnxDimension.getSize(); + if (onnxDimension.hasSymbolicName() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getSymbolicName())) { + onnxDimensionSize = symbolicSizes.get(onnxDimension.getSymbolicName()); + } + if (onnxDimensionSize == 0 && symbolicSizes != null) { + // This is for the case where all symbolic dimensions have + // different names, but can be resolved to a single dimension size. + Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); + if (unknownSizes.size() == 1) { + onnxDimensionSize = unknownSizes.iterator().next(); + } + } + if (onnxDimensionSize < 0 && unboundSizes != null && unboundSizes.size() > 0) { + onnxDimensionSize = unboundSizes.iterator().next(); + } + if (onnxDimensionSize <= 0) { + return TensorType.empty; // Unable to determine type - probably out of context + } + builder.indexed(dimensionName, onnxDimensionSize); + } + return builder.build(); + } + + @Override + public String toString() { + return "(" + valueType.id() + ")" + + "[" + dimensions.stream().map(OnnxDimensionInfo::toString).collect(Collectors.joining(",")) + "]"; + } + + } + + private static class OnnxDimensionInfo { + private final long size; + private final String symbolicName; + + OnnxDimensionInfo(long size) { + this.size = size; + this.symbolicName = null; + } + + OnnxDimensionInfo(String symbolicName) { + this.size = 0; + this.symbolicName = symbolicName; + } + + long getSize() { + return size; + } + + String getSymbolicName() { + return symbolicName; + } + + boolean hasSymbolicName() { + return symbolicName != null; + } + + boolean unknownDimensionSize() { + return hasSymbolicName() || size <= 0; + } + + @Override + public String toString() { + return hasSymbolicName() ? "\"" + symbolicName + "\"" : Long.toString(size); + } + } + +} diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/schemas/test.sd index a87222e77ee..a87222e77ee 100644 --- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd +++ b/config-model/src/test/integration/onnx-model/schemas/test.sd diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index 4eb8681c374..f8a379b4027 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -1,27 +1,61 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.search.DocumentDatabase; import com.yahoo.vespa.model.search.IndexedSearchCluster; -import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; +import org.junit.After; import org.junit.Test; import static org.junit.Assert.assertEquals; public class RankingExpressionWithOnnxModelTestCase { + private final Path applicationDir = Path.fromString("src/test/integration/onnx-model/"); + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + @Test - public void testOnnxModelFeature() { - VespaModel model = new VespaModelCreatorWithFilePkg("src/test/integration/onnx-model").create(); - DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); - assertTransformedFeature(db); - assertGeneratedConfig(db); + public void testOnnxModelFeature() throws Exception { + VespaModel model = loadModel(applicationDir); + assertTransformedFeature(model); + assertGeneratedConfig(model); + + Path storedApplicationDir = applicationDir.append("copy"); + try { + storedApplicationDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedApplicationDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append("schemas").toFile(), storedApplicationDir.append("schemas").toFile()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + + VespaModel storedModel = loadModel(storedApplicationDir); + assertTransformedFeature(storedModel); + assertGeneratedConfig(storedModel); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDir.toFile()); + } } - private void assertGeneratedConfig(DocumentDatabase db) { + private VespaModel loadModel(Path path) throws Exception { + FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile()); + DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).build(); + return new VespaModel(state); + } + + private void assertGeneratedConfig(VespaModel model) { + DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); @@ -72,10 +106,10 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals(1, config.model(4).output().size()); assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); - } - private void assertTransformedFeature(DocumentDatabase db) { + private void assertTransformedFeature(VespaModel model) { + DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); |