diff options
author | Lester Solbakken <lesters@oath.com> | 2020-10-25 12:26:56 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-10-25 12:26:56 +0100 |
commit | 38e2a6a325db457456e04ce8385f23b12a5da54d (patch) | |
tree | e5e5906f0692831240bd898c9378e948c68a5d02 /config-model/src/main/java | |
parent | 899f7210569b4f43c1531a4f4c12507b41a7f4f7 (diff) |
Revert "Revert "Add type resolving for ONNX models""
This reverts commit 882d574ab53e8d10a2a8765a64487c20661dc63f.
Diffstat (limited to 'config-model/src/main/java')
12 files changed, 582 insertions, 102 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 4011ce43841..b153ff62e7d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); } + // A reference to an ONNX model? + Optional<TensorType> onnxFeatureType = onnxFeatureType(reference); + if (onnxFeatureType.isPresent()) { + return onnxFeatureType.get(); + } + // A reference to a feature which returns a tensor? Optional<TensorType> featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(function); } + private Optional<TensorType> onnxFeatureType(Reference reference) { + if ( ! reference.name().equals("onnxModel")) + return Optional.empty(); + + if ( ! featureTypes.containsKey(reference)) { + String configOrFileName = reference.arguments().expressions().get(0).toString(); + + // Look up standardized format as added in RankProfile + String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); + String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); + + reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + if ( ! featureTypes.containsKey(reference)) { + throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); + } + } + + return Optional.of(featureTypes.get(reference)); + } + /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index c2fb2107604..58213186f78 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,14 +2,22 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; +import onnx.Onnx; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -21,16 +29,16 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; + private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private List<OnnxNameMapping> inputMap = new ArrayList<>(); - private List<OnnxNameMapping> outputMap = new ArrayList<>(); + private String defaultOutput = null; + private Map<String, String> inputMap = new HashMap<>(); + private Map<String, String> outputMap = new HashMap<>(); - public PathType getPathType() { - return pathType; - } - - private PathType pathType = PathType.FILE; + private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>(); + private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>(); + private Map<String, TensorType> vespaTypes = new HashMap<>(); public OnnxModel(String name) { this.name = name; @@ -49,21 +57,52 @@ public class OnnxModel { } public void setUri(String uri) { - Objects.requireNonNull(uri, "uri cannot be null"); - this.path = uri; - this.pathType = PathType.URI; + throw new IllegalArgumentException("URI for ONNX models are not currently supported"); + } + + public PathType getPathType() { + return pathType; + } + + public void setDefaultOutput(String onnxName) { + Objects.requireNonNull(onnxName, "Name cannot be null"); + this.defaultOutput = onnxName; } public void addInputNameMapping(String onnxName, String vespaName) { + addInputNameMapping(onnxName, vespaName, true); + } + + public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! inputMap.containsKey(onnxName)) { + inputMap.put(onnxName, vespaName); + } } public void addOutputNameMapping(String onnxName, String vespaName) { + addOutputNameMapping(onnxName, vespaName, true); + } + + public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! outputMap.containsKey(onnxName)) { + outputMap.put(onnxName, vespaName); + } + } + + public void addInputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + inputTypes.put(onnxName, type); + } + + public void addOutputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + outputTypes.put(onnxName, type); } /** Initiate sending of this constant to some services over file distribution */ @@ -76,11 +115,16 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } + public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } - public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } + public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + + public String getDefaultOutput() { + return defaultOutput; + } public void validate() { if (path == null || path.isEmpty()) @@ -90,23 +134,142 @@ public class OnnxModel { public String toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - public static class OnnxNameMapping { - private String onnxName; - private String vespaName; + /** + * Return the tensor type for an ONNX model for the given context. + * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output + * type depends on the input types for the given context (rank profile). + */ + public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { + Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); + if (onnxOutputType == null) { + throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); + } + if (containsSymbolicDimensionSizes(onnxOutputType)) { + return getTensorTypeWithSymbolicDimensions(onnxOutputType, context); + } + return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); + } + + private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { + Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context); + if (symbolicSizes.isEmpty()) { + return TensorType.empty; // Context is probably a rank profile not using this ONNX model + } + return typeFrom(onnxOutputType, symbolicSizes); + } + + private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) { + Map<String, Long> symbolicSizes = new HashMap<>(); + for (String onnxInputName : inputTypes.keySet()) { + + Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); + if ( ! containsSymbolicDimensionSizes(onnxType)) { + continue; + } + + Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); + if (vespaType.isEmpty()) { + return Collections.emptyMap(); + } + + var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); + var vespaDimensions = vespaType.get().dimensions(); + if (vespaDimensions.size() != onnxDimensions.size()) { + return Collections.emptyMap(); + } + + for (int i = 0; i < vespaDimensions.size(); ++i) { + if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) { + continue; + } + String symbolicName = onnxDimensions.get(i).getDimParam(); + Long size = vespaDimensions.get(i).size().get(); + if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { + throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " + + "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); + } + symbolicSizes.put(symbolicName, size); + } + } + return symbolicSizes; + } + + private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { + String source = inputMap.get(onnxInputName); + if (source != null) { + // Source is either a simple reference (query/attribute/constant)... + Optional<Reference> reference = Reference.simple(source); + if (reference.isPresent()) { + return Optional.of(context.getType(reference.get())); + } + // ... or a function + ExpressionFunction func = context.getFunction(source); + if (func != null) { + return Optional.of(func.getBody().type(context)); + } + } + return Optional.empty(); // if this context does not contain this input + } + + private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) { + return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue()); + } + + private static TensorType typeFrom(Onnx.TypeProto type) { + return typeFrom(type, null); + } + + private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... + Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + long onnxDimensionSize = onnxDimension.getDimValue(); + if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { + onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); + } + if (onnxDimensionSize == 0 && symbolicSizes != null) { + // This is for the case where all symbolic dimensions have + // different names, but can be resolved to a single dimension size. + Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); + if (unknownSizes.size() == 1) { + onnxDimensionSize = unknownSizes.iterator().next(); + } + } + if (onnxDimensionSize <= 0) { + throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + + "ONNX type: " + type + " to Vespa tensor type."); + } + builder.indexed(dimensionName, onnxDimensionSize); + } + return builder.build(); + } - private OnnxNameMapping(String onnxName, String vespaName) { - this.onnxName = onnxName; - this.vespaName = vespaName; + private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.FLOAT; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); } - public String getOnnxName() { return onnxName; } - public String getVespaName() { return vespaName; } - public void setVespaName(String vespaName) { this.vespaName = vespaName; } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index d309f48d6df..96c043bdb34 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; @@ -158,6 +159,10 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } + private Map<String, OnnxModel> onnxModels() { + return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); + } + private Stream<ImmutableSDField> allFields() { if (search == null) return Stream.empty(); if (allFieldsList == null) { @@ -821,6 +826,20 @@ public class RankProfile implements Cloneable { } } + // Add output types for ONNX models + for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { + String modelName = entry.getKey(); + OnnxModel model = entry.getValue(); + Arguments args = new Arguments(new ReferenceNode(modelName)); + + TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context); + context.setType(new Reference("onnxModel", args, null), defaultOutputType); + + for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { + TensorType type = model.getTensorType(mapping.getKey(), context); + context.setType(new Reference("onnxModel", args, mapping.getValue()), type); + } + } return context; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 84442fedc48..22a32c8fd65 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); modelBuilder.name(model.getName()); modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); - model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); + model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 87eaaf0387a..56a5d539906 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { - for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { - String source = mapping.getVespaName(); + for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { + String source = mapping.getValue(); if (functionNames.contains(source)) { - mapping.setVespaName("rankingExpression(" + source + ")"); + onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); } } } @@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>(); for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); - ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); + ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile); if (referenceNode != replacedNode) { replacedSummaryFeatures.add(replacedNode); i.remove(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index ec517768ea9..d23a8376e7a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if ( ! feature.getName().equals("onnx")) return feature; + if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature; try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index e1ad003e5bd..69cdae10e47 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,20 +1,36 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; +import com.yahoo.vespa.model.ml.ModelName; import java.util.List; /** - * Transforms instances of the onnxModel ranking feature and generates - * ONNX configuration if necessary. + * Transforms ONNX model features of the forms: + * + * onnxModel(config_name) + * onnxModel(config_name).output + * onnxModel("path/to/model") + * onnxModel("path/to/model").output + * onnxModel("path/to/model", "path/to/output") + * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused + * + * To the format expected by the backend: + * + * onnxModel(config_name).output * * @author lesters */ @@ -33,85 +49,92 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { if (context.rankProfile() == null) return feature; if (context.rankProfile().getSearch() == null) return feature; - return transformFeature(feature, context.rankProfile().getSearch()); + return transformFeature(feature, context.rankProfile()); } - public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { - if (!feature.getName().equals("onnxModel")) return feature; + public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) { + ImmutableSearch search = rankProfile.getSearch(); + final String featureName = feature.getName(); + if ( ! featureName.equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + - "onnx-model config or a ONNX file."); - if (arguments.expressions().size() > 2) - throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); - - // Validation that the file actually exists is handled when the file is added to file distribution. - // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. - - String modelConfigName; - OnnxModel onnxModel; - if (arguments.expressions().get(0) instanceof ReferenceNode) { - modelConfigName = arguments.expressions().get(0).toString(); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); - } - } else if (arguments.expressions().get(0) instanceof ConstantNode) { + throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " + + "onnx-model config or an ONNX file."); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments."); + + // Check that the model configuration "onnx-model" exists. If not defined, it should have been added + // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find + // the actual ONNX file, which can happen if we are restarting or upgrading an application using an + // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. + + String modelConfigName = getModelConfigName(feature.reference()); + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { String path = asString(arguments.expressions().get(0)); - modelConfigName = asValidIdentifier(path); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - search.onnxModels().add(onnxModel); - } - } else { - throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); + ModelName modelName = new ModelName(null, Path.fromString(path), true); + ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile); + FeatureArguments featureArguments = new FeatureArguments(arguments); + return convertedModel.expression(featureArguments, null); } - String output = null; - if (feature.getOutput() != null) { - output = feature.getOutput(); - if ( ! hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(output, output); + String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()); + String output = getModelOutput(feature.reference(), defaultOutput); + if (! onnxModel.getOutputMap().containsValue(output)) { + throw new IllegalArgumentException(featureName + " argument '" + output + + "' output not found in model '" + onnxModel.getFileName() + "'"); + } + return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); + } + + public static String getModelConfigName(Reference reference) { + if (reference.arguments().size() > 0) { + ExpressionNode expr = reference.arguments().expressions().get(0); + if (expr instanceof ReferenceNode) { // refers to onnx-model config + return expr.toString(); } - } else if (arguments.expressions().size() > 1) { - String name = asString(arguments.expressions().get(1)); - output = asValidIdentifier(name); - if ( ! hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(name, output); + if (expr instanceof ConstantNode) { // refers to an file path + return asValidIdentifier(expr); } } + return null; + } - // Replace feature with name of config - ExpressionNode argument = new ReferenceNode(modelConfigName); - return new ReferenceNode("onnxModel", List.of(argument), output); - + public static String getModelOutput(Reference reference, String defaultOutput) { + if (reference.output() != null) { + return reference.output(); + } else if (reference.arguments().expressions().size() == 2) { + return asValidIdentifier(reference.arguments().expressions().get(1)); + } else if (reference.arguments().expressions().size() > 2) { + return asValidIdentifier(reference.arguments().expressions().get(2)); + } + return defaultOutput; } - private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { - return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); + public static String stripQuotes(String s) { + if (isNotQuoteSign(s.codePointAt(0))) return s; + if (isNotQuoteSign(s.codePointAt(s.length() - 1))) + throw new IllegalArgumentException("argument [" + s + "] is missing end quote"); + return s.substring(1, s.length()-1); } - private static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } - private static String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); + private static String asValidIdentifier(ExpressionNode node) { + return asValidIdentifier(asString(node)); } - private static boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; + private static boolean isNotQuoteSign(int c) { + return c != '\'' && c != '"'; } - private static String asValidIdentifier(String str) { - return str.replaceAll("[^\\w\\d\\$@_]", "_"); + public static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java new file mode 100644 index 00000000000..afba88c135d --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -0,0 +1,97 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.vespa.model.container.search.QueryProfiles; + +import java.util.Map; + +/** + * Processes ONNX ranking features of the form: + * + * onnx("files/model.onnx", "path/to/output:1") + * + * And generates an "onnx-model" configuration as if it was defined in the schema: + * + * onnx-model files_model_onnx { + * file: "files/model.onnx" + * } + * + * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be + * processed after this. + * + * @author lesters + */ +public class OnnxModelConfigGenerator extends Processor { + + public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { + if (profile.getFirstPhaseRanking() != null) { + process(profile.getFirstPhaseRanking().getRoot()); + } + if (profile.getSecondPhaseRanking() != null) { + process(profile.getSecondPhaseRanking().getRoot()); + } + for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { + process(function.getValue().function().getBody().getRoot()); + } + for (ReferenceNode feature : profile.getSummaryFeatures()) { + process(feature); + } + } + } + + private void process(ExpressionNode node) { + if (node instanceof ReferenceNode) { + process((ReferenceNode)node); + } else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode) node).children()) { + process(child); + } + } + } + + private void process(ReferenceNode feature) { + if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) { + if (feature.getArguments().size() > 0) { + if (feature.getArguments().expressions().get(0) instanceof ConstantNode) { + ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0); + String path = OnnxModelTransformer.stripQuotes(node.sourceString()); + String modelConfigName = OnnxModelTransformer.asValidIdentifier(path); + + // Only add the configuration if the model can actually be found. + if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + path = ApplicationPackage.MODELS_DIR.append(path).toString(); + if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + return; + } + } + + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java new file mode 100644 index 00000000000..bead2e7e7c9 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -0,0 +1,154 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchdefinition.processing; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.component.Version; +import com.yahoo.config.FileReference; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.vespa.defaults.Defaults; +import com.yahoo.vespa.model.container.search.QueryProfiles; +import onnx.Onnx; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Paths; +import java.util.Map; +import java.util.Optional; + +/** + * Processes every "onnx-model" element in the schema. Parses the model file, + * adds missing input and output mappings (assigning default names), and + * adds tensor types to all model inputs and outputs. + * + * Must be processed before RankingExpressingTypeResolver. + * + * @author lesters + */ +public class OnnxModelTypeResolver extends Processor { + + public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + + for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) { + OnnxModel modelConfig = entry.getValue(); + try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + + // Model inputs - if not defined, assumes a function is provided with a valid name + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { + String onnxInputName = valueInfo.getName(); + String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName); + modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false); + modelConfig.addInputType(onnxInputName, valueInfo.getType()); + } + + // Model outputs + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { + String onnxOutputName = valueInfo.getName(); + String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName); + modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false); + modelConfig.addOutputType(onnxOutputName, valueInfo.getType()); + } + + // Set the first output as default + if ( ! model.getGraph().getOutputList().isEmpty()) { + modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName()); + } + + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse ONNX model", e); + } + } + } + + static boolean modelFileExists(String path, ApplicationPackage app) { + Path pathInApplicationPackage = Path.fromString(path); + if (getFile(pathInApplicationPackage, app).exists()) { + return true; + } + if (getFileReference(pathInApplicationPackage, app).isPresent()) { + return true; + } + return false; + } + + private InputStream openModelFile(Path path) throws FileNotFoundException { + ApplicationFile file; + Optional<FileReference> reference; + Path modelsPath = ApplicationPackage.MODELS_DIR.append(path); + + if ((file = getFile(path)).exists()) { + return file.createInputStream(); + } + if ((file = getFile(modelsPath)).exists()) { + return file.createInputStream(); + } + if ((reference = getFileReference(path)).isPresent()) { + return openFromFileRepository(path, reference.get()); + } + if ((reference = getFileReference(modelsPath)).isPresent()) { + return openFromFileRepository(modelsPath, reference.get()); + } + + throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " + + "in application package or file repository."); + } + + private ApplicationFile getFile(Path path) { + return getFile(path, search.applicationPackage()); + } + + private static ApplicationFile getFile(Path path, ApplicationPackage app) { + return app.getFile(path); + } + + private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException { + return new FileInputStream(new File(getFileRepositoryPath(path, reference.value()))); + } + + public static String getFileRepositoryPath(Path path, String fileReference) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); + return Paths.get(fileRefDir, fileReference, path.getName()).toString(); + } + + private Optional<FileReference> getFileReference(Path path) { + return getFileReference(path, search.applicationPackage()); + } + + private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) { + Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app); + if (fileRegistry.isPresent()) { + for (FileRegistry.Entry file : fileRegistry.get().export()) { + if (file.relativePath.equals(path.toString())) { + return Optional.of(file.reference); + } + } + } + return Optional.empty(); + } + + private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) { + if (app == null) return Optional.empty(); + Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo); + return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get())); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index e8594c2a87f..1a3ef9e54b4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -74,6 +74,8 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, + OnnxModelConfigGenerator::new, + OnnxModelTypeResolver::new, RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index c6c7969e466..d5c5183b01f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -1,20 +1,19 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; -import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; import java.util.logging.Level; import com.yahoo.log.LogMessage; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.collections.Pair; import com.yahoo.config.ConfigInstance; -import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.config.search.ImportedFieldsConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -31,7 +30,6 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.io.File; import java.io.IOException; import java.nio.file.Files; -import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; @@ -152,12 +150,9 @@ public class RankSetupValidator extends Validator { // Assist verify-ranksetup in finding the actual ONNX model files Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); if (models.values().size() > 0) { - ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults - String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); List<String> config = new ArrayList<>(models.values().size() * 2); for (OnnxModel model : models.values()) { - String modelFilename = Paths.get(model.getFileName()).getFileName().toString(); - String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString(); + String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference()); config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 943fcbf6c1d..5ee6ed02e61 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -150,7 +150,7 @@ public class ConvertedModel { */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { ExpressionFunction expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we should verify + if (sourceModel.isPresent() && context != null) // we should verify verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); return expression.getBody().getRoot(); } |