diff options
Diffstat (limited to 'config-model')
23 files changed, 1288 insertions, 149 deletions
diff --git a/config-model/pom.xml b/config-model/pom.xml index 95e79fd09fb..c0751431d03 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -46,6 +46,11 @@ <scope>test</scope> </dependency> <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> + <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> <scope>provided</scope> @@ -498,6 +503,10 @@ <updateReleaseInfo>true</updateReleaseInfo> </configuration> </plugin> + <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + </plugin> </plugins> </build> </project> diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 4011ce43841..b153ff62e7d 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); } + // A reference to an ONNX model? + Optional<TensorType> onnxFeatureType = onnxFeatureType(reference); + if (onnxFeatureType.isPresent()) { + return onnxFeatureType.get(); + } + // A reference to a feature which returns a tensor? Optional<TensorType> featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(function); } + private Optional<TensorType> onnxFeatureType(Reference reference) { + if ( ! reference.name().equals("onnxModel")) + return Optional.empty(); + + if ( ! featureTypes.containsKey(reference)) { + String configOrFileName = reference.arguments().expressions().get(0).toString(); + + // Look up standardized format as added in RankProfile + String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); + String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); + + reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + if ( ! featureTypes.containsKey(reference)) { + throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); + } + } + + return Optional.of(featureTypes.get(reference)); + } + /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index c2fb2107604..58213186f78 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,14 +2,22 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; +import onnx.Onnx; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -21,16 +29,16 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; + private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private List<OnnxNameMapping> inputMap = new ArrayList<>(); - private List<OnnxNameMapping> outputMap = new ArrayList<>(); + private String defaultOutput = null; + private Map<String, String> inputMap = new HashMap<>(); + private Map<String, String> outputMap = new HashMap<>(); - public PathType getPathType() { - return pathType; - } - - private PathType pathType = PathType.FILE; + private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>(); + private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>(); + private Map<String, TensorType> vespaTypes = new HashMap<>(); public OnnxModel(String name) { this.name = name; @@ -49,21 +57,52 @@ public class OnnxModel { } public void setUri(String uri) { - Objects.requireNonNull(uri, "uri cannot be null"); - this.path = uri; - this.pathType = PathType.URI; + throw new IllegalArgumentException("URI for ONNX models are not currently supported"); + } + + public PathType getPathType() { + return pathType; + } + + public void setDefaultOutput(String onnxName) { + Objects.requireNonNull(onnxName, "Name cannot be null"); + this.defaultOutput = onnxName; } public void addInputNameMapping(String onnxName, String vespaName) { + addInputNameMapping(onnxName, vespaName, true); + } + + public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! inputMap.containsKey(onnxName)) { + inputMap.put(onnxName, vespaName); + } } public void addOutputNameMapping(String onnxName, String vespaName) { + addOutputNameMapping(onnxName, vespaName, true); + } + + public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! outputMap.containsKey(onnxName)) { + outputMap.put(onnxName, vespaName); + } + } + + public void addInputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + inputTypes.put(onnxName, type); + } + + public void addOutputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + outputTypes.put(onnxName, type); } /** Initiate sending of this constant to some services over file distribution */ @@ -76,11 +115,16 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } + public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } - public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } + public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + + public String getDefaultOutput() { + return defaultOutput; + } public void validate() { if (path == null || path.isEmpty()) @@ -90,23 +134,142 @@ public class OnnxModel { public String toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - public static class OnnxNameMapping { - private String onnxName; - private String vespaName; + /** + * Return the tensor type for an ONNX model for the given context. + * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output + * type depends on the input types for the given context (rank profile). + */ + public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { + Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); + if (onnxOutputType == null) { + throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); + } + if (containsSymbolicDimensionSizes(onnxOutputType)) { + return getTensorTypeWithSymbolicDimensions(onnxOutputType, context); + } + return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); + } + + private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { + Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context); + if (symbolicSizes.isEmpty()) { + return TensorType.empty; // Context is probably a rank profile not using this ONNX model + } + return typeFrom(onnxOutputType, symbolicSizes); + } + + private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) { + Map<String, Long> symbolicSizes = new HashMap<>(); + for (String onnxInputName : inputTypes.keySet()) { + + Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); + if ( ! containsSymbolicDimensionSizes(onnxType)) { + continue; + } + + Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); + if (vespaType.isEmpty()) { + return Collections.emptyMap(); + } + + var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); + var vespaDimensions = vespaType.get().dimensions(); + if (vespaDimensions.size() != onnxDimensions.size()) { + return Collections.emptyMap(); + } + + for (int i = 0; i < vespaDimensions.size(); ++i) { + if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) { + continue; + } + String symbolicName = onnxDimensions.get(i).getDimParam(); + Long size = vespaDimensions.get(i).size().get(); + if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { + throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " + + "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); + } + symbolicSizes.put(symbolicName, size); + } + } + return symbolicSizes; + } + + private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { + String source = inputMap.get(onnxInputName); + if (source != null) { + // Source is either a simple reference (query/attribute/constant)... + Optional<Reference> reference = Reference.simple(source); + if (reference.isPresent()) { + return Optional.of(context.getType(reference.get())); + } + // ... or a function + ExpressionFunction func = context.getFunction(source); + if (func != null) { + return Optional.of(func.getBody().type(context)); + } + } + return Optional.empty(); // if this context does not contain this input + } + + private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) { + return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue()); + } + + private static TensorType typeFrom(Onnx.TypeProto type) { + return typeFrom(type, null); + } + + private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... + Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + long onnxDimensionSize = onnxDimension.getDimValue(); + if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { + onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); + } + if (onnxDimensionSize == 0 && symbolicSizes != null) { + // This is for the case where all symbolic dimensions have + // different names, but can be resolved to a single dimension size. + Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); + if (unknownSizes.size() == 1) { + onnxDimensionSize = unknownSizes.iterator().next(); + } + } + if (onnxDimensionSize <= 0) { + throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + + "ONNX type: " + type + " to Vespa tensor type."); + } + builder.indexed(dimensionName, onnxDimensionSize); + } + return builder.build(); + } - private OnnxNameMapping(String onnxName, String vespaName) { - this.onnxName = onnxName; - this.vespaName = vespaName; + private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.FLOAT; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); } - public String getOnnxName() { return onnxName; } - public String getVespaName() { return vespaName; } - public void setVespaName(String vespaName) { this.vespaName = vespaName; } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index d309f48d6df..96c043bdb34 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; @@ -158,6 +159,10 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } + private Map<String, OnnxModel> onnxModels() { + return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); + } + private Stream<ImmutableSDField> allFields() { if (search == null) return Stream.empty(); if (allFieldsList == null) { @@ -821,6 +826,20 @@ public class RankProfile implements Cloneable { } } + // Add output types for ONNX models + for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { + String modelName = entry.getKey(); + OnnxModel model = entry.getValue(); + Arguments args = new Arguments(new ReferenceNode(modelName)); + + TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context); + context.setType(new Reference("onnxModel", args, null), defaultOutputType); + + for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { + TensorType type = model.getTensorType(mapping.getKey(), context); + context.setType(new Reference("onnxModel", args, mapping.getValue()), type); + } + } return context; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 84442fedc48..22a32c8fd65 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); modelBuilder.name(model.getName()); modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); - model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); + model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 87eaaf0387a..56a5d539906 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { - for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { - String source = mapping.getVespaName(); + for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { + String source = mapping.getValue(); if (functionNames.contains(source)) { - mapping.setVespaName("rankingExpression(" + source + ")"); + onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); } } } @@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>(); for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); - ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); + ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile); if (referenceNode != replacedNode) { replacedSummaryFeatures.add(replacedNode); i.remove(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index ec517768ea9..d23a8376e7a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if ( ! feature.getName().equals("onnx")) return feature; + if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature; try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index e1ad003e5bd..69cdae10e47 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,20 +1,36 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; +import com.yahoo.vespa.model.ml.ModelName; import java.util.List; /** - * Transforms instances of the onnxModel ranking feature and generates - * ONNX configuration if necessary. + * Transforms ONNX model features of the forms: + * + * onnxModel(config_name) + * onnxModel(config_name).output + * onnxModel("path/to/model") + * onnxModel("path/to/model").output + * onnxModel("path/to/model", "path/to/output") + * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused + * + * To the format expected by the backend: + * + * onnxModel(config_name).output * * @author lesters */ @@ -33,85 +49,92 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { if (context.rankProfile() == null) return feature; if (context.rankProfile().getSearch() == null) return feature; - return transformFeature(feature, context.rankProfile().getSearch()); + return transformFeature(feature, context.rankProfile()); } - public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { - if (!feature.getName().equals("onnxModel")) return feature; + public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) { + ImmutableSearch search = rankProfile.getSearch(); + final String featureName = feature.getName(); + if ( ! featureName.equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + - "onnx-model config or a ONNX file."); - if (arguments.expressions().size() > 2) - throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); - - // Validation that the file actually exists is handled when the file is added to file distribution. - // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. - - String modelConfigName; - OnnxModel onnxModel; - if (arguments.expressions().get(0) instanceof ReferenceNode) { - modelConfigName = arguments.expressions().get(0).toString(); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); - } - } else if (arguments.expressions().get(0) instanceof ConstantNode) { + throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " + + "onnx-model config or an ONNX file."); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments."); + + // Check that the model configuration "onnx-model" exists. If not defined, it should have been added + // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find + // the actual ONNX file, which can happen if we are restarting or upgrading an application using an + // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. + + String modelConfigName = getModelConfigName(feature.reference()); + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { String path = asString(arguments.expressions().get(0)); - modelConfigName = asValidIdentifier(path); - onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - search.onnxModels().add(onnxModel); - } - } else { - throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); + ModelName modelName = new ModelName(null, Path.fromString(path), true); + ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile); + FeatureArguments featureArguments = new FeatureArguments(arguments); + return convertedModel.expression(featureArguments, null); } - String output = null; - if (feature.getOutput() != null) { - output = feature.getOutput(); - if ( ! hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(output, output); + String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()); + String output = getModelOutput(feature.reference(), defaultOutput); + if (! onnxModel.getOutputMap().containsValue(output)) { + throw new IllegalArgumentException(featureName + " argument '" + output + + "' output not found in model '" + onnxModel.getFileName() + "'"); + } + return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); + } + + public static String getModelConfigName(Reference reference) { + if (reference.arguments().size() > 0) { + ExpressionNode expr = reference.arguments().expressions().get(0); + if (expr instanceof ReferenceNode) { // refers to onnx-model config + return expr.toString(); } - } else if (arguments.expressions().size() > 1) { - String name = asString(arguments.expressions().get(1)); - output = asValidIdentifier(name); - if ( ! hasOutputMapping(onnxModel, output)) { - onnxModel.addOutputNameMapping(name, output); + if (expr instanceof ConstantNode) { // refers to an file path + return asValidIdentifier(expr); } } + return null; + } - // Replace feature with name of config - ExpressionNode argument = new ReferenceNode(modelConfigName); - return new ReferenceNode("onnxModel", List.of(argument), output); - + public static String getModelOutput(Reference reference, String defaultOutput) { + if (reference.output() != null) { + return reference.output(); + } else if (reference.arguments().expressions().size() == 2) { + return asValidIdentifier(reference.arguments().expressions().get(1)); + } else if (reference.arguments().expressions().size() > 2) { + return asValidIdentifier(reference.arguments().expressions().get(2)); + } + return defaultOutput; } - private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { - return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); + public static String stripQuotes(String s) { + if (isNotQuoteSign(s.codePointAt(0))) return s; + if (isNotQuoteSign(s.codePointAt(s.length() - 1))) + throw new IllegalArgumentException("argument [" + s + "] is missing end quote"); + return s.substring(1, s.length()-1); } - private static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); + public static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } - private static String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); + private static String asValidIdentifier(ExpressionNode node) { + return asValidIdentifier(asString(node)); } - private static boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; + private static boolean isNotQuoteSign(int c) { + return c != '\'' && c != '"'; } - private static String asValidIdentifier(String str) { - return str.replaceAll("[^\\w\\d\\$@_]", "_"); + public static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java new file mode 100644 index 00000000000..afba88c135d --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java @@ -0,0 +1,97 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.vespa.model.container.search.QueryProfiles; + +import java.util.Map; + +/** + * Processes ONNX ranking features of the form: + * + * onnx("files/model.onnx", "path/to/output:1") + * + * And generates an "onnx-model" configuration as if it was defined in the schema: + * + * onnx-model files_model_onnx { + * file: "files/model.onnx" + * } + * + * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be + * processed after this. + * + * @author lesters + */ +public class OnnxModelConfigGenerator extends Processor { + + public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { + if (profile.getFirstPhaseRanking() != null) { + process(profile.getFirstPhaseRanking().getRoot()); + } + if (profile.getSecondPhaseRanking() != null) { + process(profile.getSecondPhaseRanking().getRoot()); + } + for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { + process(function.getValue().function().getBody().getRoot()); + } + for (ReferenceNode feature : profile.getSummaryFeatures()) { + process(feature); + } + } + } + + private void process(ExpressionNode node) { + if (node instanceof ReferenceNode) { + process((ReferenceNode)node); + } else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode) node).children()) { + process(child); + } + } + } + + private void process(ReferenceNode feature) { + if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) { + if (feature.getArguments().size() > 0) { + if (feature.getArguments().expressions().get(0) instanceof ConstantNode) { + ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0); + String path = OnnxModelTransformer.stripQuotes(node.sourceString()); + String modelConfigName = OnnxModelTransformer.asValidIdentifier(path); + + // Only add the configuration if the model can actually be found. + if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + path = ApplicationPackage.MODELS_DIR.append(path).toString(); + if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { + return; + } + } + + OnnxModel onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } + } + } + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java new file mode 100644 index 00000000000..bead2e7e7c9 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -0,0 +1,154 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchdefinition.processing; + +import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.component.Version; +import com.yahoo.config.FileReference; +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; +import com.yahoo.vespa.defaults.Defaults; +import com.yahoo.vespa.model.container.search.QueryProfiles; +import onnx.Onnx; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Paths; +import java.util.Map; +import java.util.Optional; + +/** + * Processes every "onnx-model" element in the schema. Parses the model file, + * adds missing input and output mappings (assigning default names), and + * adds tensor types to all model inputs and outputs. + * + * Must be processed before RankingExpressingTypeResolver. + * + * @author lesters + */ +public class OnnxModelTypeResolver extends Processor { + + public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + } + + @Override + public void process(boolean validate, boolean documentsOnly) { + if (documentsOnly) return; + + for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) { + OnnxModel modelConfig = entry.getValue(); + try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + + // Model inputs - if not defined, assumes a function is provided with a valid name + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { + String onnxInputName = valueInfo.getName(); + String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName); + modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false); + modelConfig.addInputType(onnxInputName, valueInfo.getType()); + } + + // Model outputs + for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { + String onnxOutputName = valueInfo.getName(); + String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName); + modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false); + modelConfig.addOutputType(onnxOutputName, valueInfo.getType()); + } + + // Set the first output as default + if ( ! model.getGraph().getOutputList().isEmpty()) { + modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName()); + } + + } catch (IOException e) { + throw new IllegalArgumentException("Unable to parse ONNX model", e); + } + } + } + + static boolean modelFileExists(String path, ApplicationPackage app) { + Path pathInApplicationPackage = Path.fromString(path); + if (getFile(pathInApplicationPackage, app).exists()) { + return true; + } + if (getFileReference(pathInApplicationPackage, app).isPresent()) { + return true; + } + return false; + } + + private InputStream openModelFile(Path path) throws FileNotFoundException { + ApplicationFile file; + Optional<FileReference> reference; + Path modelsPath = ApplicationPackage.MODELS_DIR.append(path); + + if ((file = getFile(path)).exists()) { + return file.createInputStream(); + } + if ((file = getFile(modelsPath)).exists()) { + return file.createInputStream(); + } + if ((reference = getFileReference(path)).isPresent()) { + return openFromFileRepository(path, reference.get()); + } + if ((reference = getFileReference(modelsPath)).isPresent()) { + return openFromFileRepository(modelsPath, reference.get()); + } + + throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " + + "in application package or file repository."); + } + + private ApplicationFile getFile(Path path) { + return getFile(path, search.applicationPackage()); + } + + private static ApplicationFile getFile(Path path, ApplicationPackage app) { + return app.getFile(path); + } + + private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException { + return new FileInputStream(new File(getFileRepositoryPath(path, reference.value()))); + } + + public static String getFileRepositoryPath(Path path, String fileReference) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); + return Paths.get(fileRefDir, fileReference, path.getName()).toString(); + } + + private Optional<FileReference> getFileReference(Path path) { + return getFileReference(path, search.applicationPackage()); + } + + private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) { + Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app); + if (fileRegistry.isPresent()) { + for (FileRegistry.Entry file : fileRegistry.get().export()) { + if (file.relativePath.equals(path.toString())) { + return Optional.of(file.reference); + } + } + } + return Optional.empty(); + } + + private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) { + if (app == null) return Optional.empty(); + Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo); + return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get())); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index e8594c2a87f..1a3ef9e54b4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -74,6 +74,8 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, + OnnxModelConfigGenerator::new, + OnnxModelTypeResolver::new, RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index c6c7969e466..d5c5183b01f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -1,20 +1,19 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; -import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; import java.util.logging.Level; import com.yahoo.log.LogMessage; import com.yahoo.searchdefinition.OnnxModel; +import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.collections.Pair; import com.yahoo.config.ConfigInstance; -import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.config.search.ImportedFieldsConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -31,7 +30,6 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.io.File; import java.io.IOException; import java.nio.file.Files; -import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; @@ -152,12 +150,9 @@ public class RankSetupValidator extends Validator { // Assist verify-ranksetup in finding the actual ONNX model files Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); if (models.values().size() > 0) { - ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults - String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); List<String> config = new ArrayList<>(models.values().size() * 2); for (OnnxModel model : models.values()) { - String modelFilename = Paths.get(model.getFileName()).getFileName().toString(); - String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString(); + String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference()); config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 943fcbf6c1d..5ee6ed02e61 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -150,7 +150,7 @@ public class ConvertedModel { */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { ExpressionFunction expression = selectExpression(arguments); - if (sourceModel.isPresent()) // we should verify + if (sourceModel.isPresent() && context != null) // we should verify verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); return expression.getBody().getRoot(); } diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto new file mode 100644 index 00000000000..dc6542867e0 --- /dev/null +++ b/config-model/src/main/protobuf/onnx.proto @@ -0,0 +1,464 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// Copyright (c) Facebook Inc. and Microsoft Corporation. +// Licensed under the MIT license. + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. We should use version as + // xx(major) - xx(minor) - xxxx(bugfix) + // and we are starting with 0x00000001 (0.0.1), which was the + // version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x00000001; + + // IR_VERSION 0.0.2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x00000002; + + // IR VERSION 0.0.3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION = 0x00000003; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // also appears in the input list. + repeated TensorProto initializer = 5; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // Advanced types + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + optional DataType data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // For double + // Complex64 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// A set of pre-defined constants to be used as values for +// the standard denotation field in TensorShapeProto.Dimension +// for semantic description of the tensor dimension. +message DenotationConstProto { + // Describe a batch number dimension. + optional string DATA_BATCH = 1 [default = "DATA_BATCH"]; + // Describe a channel dimension. + optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"]; + // Describe a time dimension. + optional string DATA_TIME = 3 [default = "DATA_TIME"]; + // Describe a feature dimension. This is typically a feature + // dimension in RNN and/or spatial dimension in CNN. + optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"]; + // Describe a filter in-channel dimension. This is the dimension + // that is identical (in size) to the channel dimension of the input + // image feature maps. + optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"]; + // Describe a filter out channel dimension. This is the dimension + // that is identical (int size) to the channel dimension of the output + // image feature maps. + optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"]; + // Describe a filter spatial dimension. + optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"]; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST be present for this version of the IR. + optional TensorProto.DataType elem_type = 1; + optional TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + } +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +}
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd index e9575af6010..cc73f2daff5 100644 --- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd +++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd @@ -18,7 +18,7 @@ search test { } function mnist_softmax_onnx() { - expression: onnx("mnist_softmax") + expression: onnx_vespa("mnist_softmax") } function my_xgboost() { diff --git a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py new file mode 100755 index 00000000000..55df3a557e9 --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py @@ -0,0 +1,12 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "sequence"]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "sequence"]) + +nodes = [helper.make_node('Identity', ['input'], ['output'])] +graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT]) +model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py') +onnx.save(model_def, 'dynamic_model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/create_model.py b/config-model/src/test/integration/onnx-model/files/create_model.py new file mode 100755 index 00000000000..10ff92c2eda --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/create_model.py @@ -0,0 +1,37 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('first_input', TensorProto.FLOAT, [2]) +INPUT_2 = helper.make_tensor_value_info('second/input:0', TensorProto.FLOAT, [2]) +INPUT_3 = helper.make_tensor_value_info('third_input', TensorProto.FLOAT, [2]) +OUTPUT_1 = helper.make_tensor_value_info('path/to/output:0', TensorProto.FLOAT, [2]) +OUTPUT_2 = helper.make_tensor_value_info('path/to/output:1', TensorProto.FLOAT, [2]) +OUTPUT_3 = helper.make_tensor_value_info('path/to/output:2', TensorProto.FLOAT, [2]) + +nodes = [ + helper.make_node( + 'Add', + ['first_input', 'second/input:0'], + ['path/to/output:0'], + ), + helper.make_node( + 'Add', + ['third_input', 'second/input:0'], + ['path/to/output:1'] + ), + helper.make_node( + 'Add', + ['path/to/output:0', 'path/to/output:1'], + ['path/to/output:2'] + ), +] +graph_def = helper.make_graph( + nodes, + 'simple_scoring', + [INPUT_1, INPUT_2, INPUT_3], + [OUTPUT_1, OUTPUT_2, OUTPUT_3] +) +model_def = helper.make_model(graph_def, producer_name='create_model.py') +onnx.save(model_def, 'model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx new file mode 100644 index 00000000000..6bbdad2d76e --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx @@ -0,0 +1,13 @@ +create_dynamic_model.py:x + +inputoutput"Identitysimple_scoringZ$ +input + +batch + +sequenceb% +output + +batch + +sequenceB
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/model.onnx b/config-model/src/test/integration/onnx-model/files/model.onnx new file mode 100644 index 00000000000..f3898205c6a --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/model.onnx @@ -0,0 +1,34 @@ +create_model.py:í +4 +first_input +second/input:0path/to/output:0"Add +4 +third_input +second/input:0path/to/output:1"Add +; +path/to/output:0 +path/to/output:1path/to/output:2"Addsimple_scoringZ +first_input + + +Z +second/input:0 + + +Z +third_input + + +b +path/to/output:0 + + +b +path/to/output:1 + + +b +path/to/output:2 + + +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/summary_model.onnx b/config-model/src/test/integration/onnx-model/files/summary_model.onnx new file mode 100644 index 00000000000..f3898205c6a --- /dev/null +++ b/config-model/src/test/integration/onnx-model/files/summary_model.onnx @@ -0,0 +1,34 @@ +create_model.py:í +4 +first_input +second/input:0path/to/output:0"Add +4 +third_input +second/input:0path/to/output:1"Add +; +path/to/output:0 +path/to/output:1path/to/output:2"Addsimple_scoringZ +first_input + + +Z +second/input:0 + + +Z +third_input + + +b +path/to/output:0 + + +b +path/to/output:1 + + +b +path/to/output:2 + + +B
\ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd index 0f0fa694e6f..6e9ba356293 100644 --- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd +++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd @@ -14,7 +14,7 @@ search test { } onnx-model my_model { - file: files/ranking_model.onnx + file: files/model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": my_function @@ -22,19 +22,25 @@ search test { } onnx-model another_model { - file: files/ranking_model.onnx + file: files/model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": another_function output "path/to/output:2": out } + onnx-model dynamic_model { + file: files/dynamic_model.onnx + input input: my_function + output output: my_output + } + rank-profile test_model_config { function my_function() { expression: tensor(d0[2])(1) } first-phase { - expression: onnxModel(my_model).out + expression: onnxModel(my_model).out{d0:1} } } @@ -49,7 +55,7 @@ search test { expression: my_function() } first-phase { - expression: onnxModel("files/ranking_model.onnx", "path/to/output:1") + expression: onnxModel("files/model.onnx", "path/to/output:1"){d0:1} } } @@ -62,9 +68,29 @@ search test { } summary-features { onnxModel(another_model).out - onnxModel("files/ranking_model.onnx", "path/to/output:2") + onnxModel("files/summary_model.onnx", "path/to/output:2") + } + } + + rank-profile test_dynamic_model { + function my_function() { + expression: tensor(d0[1],d1[2])(d1) + } + first-phase { + expression: onnxModel(dynamic_model){d0:0,d1:1} } + } + rank-profile test_dynamic_model_2 { + function my_function_2() { + expression: tensor(d0[1],d1[3])(d1) + } + function my_function() { + expression: my_function_2() + } + first-phase { + expression: onnxModel(dynamic_model){d0:0,d1:2} + } } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index d9b0c70dfdd..5060aafb55f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -25,43 +25,61 @@ public class RankingExpressionWithOnnxModelTestCase { OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); - assertEquals(3, config.model().size()); + assertEquals(5, config.model().size()); - assertEquals("my_model", config.model(1).name()); + assertEquals("my_model", config.model(0).name()); + assertEquals(3, config.model(0).input().size()); + assertEquals("second/input:0", config.model(0).input(0).name()); + assertEquals("constant(my_constant)", config.model(0).input(0).source()); + assertEquals("first_input", config.model(0).input(1).name()); + assertEquals("attribute(document_field)", config.model(0).input(1).source()); + assertEquals("third_input", config.model(0).input(2).name()); + assertEquals("rankingExpression(my_function)", config.model(0).input(2).source()); + assertEquals(3, config.model(0).output().size()); + assertEquals("path/to/output:0", config.model(0).output(0).name()); + assertEquals("out", config.model(0).output(0).as()); + assertEquals("path/to/output:1", config.model(0).output(1).name()); + assertEquals("path_to_output_1", config.model(0).output(1).as()); + assertEquals("path/to/output:2", config.model(0).output(2).name()); + assertEquals("path_to_output_2", config.model(0).output(2).as()); + + assertEquals("files_model_onnx", config.model(1).name()); assertEquals(3, config.model(1).input().size()); - assertEquals("first_input", config.model(1).input(0).name()); - assertEquals("attribute(document_field)", config.model(1).input(0).source()); - assertEquals("second/input:0", config.model(1).input(1).name()); - assertEquals("constant(my_constant)", config.model(1).input(1).source()); - assertEquals("third_input", config.model(1).input(2).name()); - assertEquals("rankingExpression(my_function)", config.model(1).input(2).source()); - assertEquals(1, config.model(1).output().size()); + assertEquals(3, config.model(1).output().size()); assertEquals("path/to/output:0", config.model(1).output(0).name()); - assertEquals("out", config.model(1).output(0).as()); - - assertEquals("files_ranking_model_onnx", config.model(0).name()); - assertEquals(0, config.model(0).input().size()); - assertEquals(2, config.model(0).output().size()); - assertEquals("path/to/output:1", config.model(0).output(0).name()); - assertEquals("path_to_output_1", config.model(0).output(0).as()); - assertEquals("path/to/output:2", config.model(0).output(1).name()); - assertEquals("path_to_output_2", config.model(0).output(1).as()); + assertEquals("path_to_output_0", config.model(1).output(0).as()); + assertEquals("path/to/output:1", config.model(1).output(1).name()); + assertEquals("path_to_output_1", config.model(1).output(1).as()); + assertEquals("path/to/output:2", config.model(1).output(2).name()); + assertEquals("path_to_output_2", config.model(1).output(2).as()); + assertEquals("files_model_onnx", config.model(1).name()); assertEquals("another_model", config.model(2).name()); assertEquals("third_input", config.model(2).input(2).name()); assertEquals("rankingExpression(another_function)", config.model(2).input(2).source()); + + assertEquals("files_summary_model_onnx", config.model(3).name()); + assertEquals(3, config.model(3).input().size()); + assertEquals(3, config.model(3).output().size()); + + assertEquals("dynamic_model", config.model(4).name()); + assertEquals(1, config.model(4).input().size()); + assertEquals(1, config.model(4).output().size()); + assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); } private void assertTransformedFeature(DocumentDatabase db) { RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(5, config.rankprofile().size()); + assertEquals(7, config.rankprofile().size()); assertEquals("test_model_config", config.rankprofile(2).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); - assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name()); + assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); assertEquals("test_generated_model_config", config.rankprofile(3).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); @@ -69,16 +87,28 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name()); assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); - assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value()); + assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name()); + assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); assertEquals("test_summary_features", config.rankprofile(4).name()); assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); assertEquals("1", config.rankprofile(4).fef().property(3).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); - assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value()); + assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); - assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value()); + assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value()); + + assertEquals("test_dynamic_model", config.rankprofile(5).name()); + assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); + + assertEquals("test_dynamic_model_2", config.rankprofile(6).name()); + assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); + assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); + } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 6bf69907609..40bf970a313 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -70,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -87,7 +87,7 @@ public class RankingExpressionWithOnnxTestCase { queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("query(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, null, "Placeholder", @@ -99,7 +99,7 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -117,7 +117,7 @@ public class RankingExpressionWithOnnxTestCase { "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor<float>(d0[1],d1[784]) }", "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -129,21 +129,21 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testNestedOnnxReference() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "5 + sum(onnx('mnist_softmax.onnx'))"); + "5 + sum(onnx_vespa('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutput() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutputAndSignature() { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'default.layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -155,7 +155,7 @@ public class RankingExpressionWithOnnxTestCase { new QueryProfileRegistry(), " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: onnx('mnist_softmax.onnx')" + + " expression: onnx_vespa('mnist_softmax.onnx')" + " }\n" + " }"); search.compileRankProfile("my_profile", applicationDir.append("models")); @@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx'): " + + "onnx_vespa('mnist_softmax.onnx'): " + "Model refers input 'Placeholder' of type tensor<float>(d0[1],d1[784]) but this function is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); @@ -175,13 +175,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithWrongFunctionType() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)", - "onnx('mnist_softmax.onnx')"); + "onnx_vespa('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx'): " + + "onnx_vespa('mnist_softmax.onnx'): " + "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[1],d1[784]), " + "but this function returns tensor(d0[1],d5[10])", Exceptions.toMessageString(expected)); @@ -192,13 +192,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceSpecifyingNonExistingOutput() { try { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx', 'y')"); + "onnx_vespa('mnist_softmax.onnx', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx('mnist_softmax.onnx','y'): " + + "onnx_vespa('mnist_softmax.onnx','y'): " + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", Exceptions.toMessageString(expected)); } @@ -207,7 +207,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[1],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')"); + "onnx_vespa('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -218,7 +218,7 @@ public class RankingExpressionWithOnnxTestCase { storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", - "onnx('mnist_softmax.onnx')", + "onnx_vespa('mnist_softmax.onnx')", null, null, "Placeholder", @@ -243,7 +243,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d1[10],d2[784])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('mnist_softmax.onnx')" + + " expression: onnx_vespa('mnist_softmax.onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + @@ -288,7 +288,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + + " expression: onnx_vespa('" + name + ".onnx')" + " }\n" + " }"; final String functionName = "imported_ml_function_" + name + "_exp_output"; @@ -310,7 +310,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor<float>(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx('" + name + ".onnx')" + + " expression: onnx_vespa('" + name + ".onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + |