From 882d574ab53e8d10a2a8765a64487c20661dc63f Mon Sep 17 00:00:00 2001 From: Arnstein Ressem Date: Sat, 24 Oct 2020 11:20:18 +0200 Subject: Revert "Add type resolving for ONNX models" --- config-model/pom.xml | 9 - .../searchdefinition/MapEvaluationTypeContext.java | 27 -- .../java/com/yahoo/searchdefinition/OnnxModel.java | 219 ++-------- .../com/yahoo/searchdefinition/RankProfile.java | 19 - .../searchdefinition/derived/RankProfileList.java | 4 +- .../searchdefinition/derived/RawRankProfile.java | 8 +- .../expressiontransforms/OnnxFeatureConverter.java | 2 +- .../expressiontransforms/OnnxModelTransformer.java | 141 +++---- .../processing/OnnxModelConfigGenerator.java | 97 ----- .../processing/OnnxModelTypeResolver.java | 154 ------- .../searchdefinition/processing/Processing.java | 2 - .../application/validation/RankSetupValidator.java | 9 +- .../com/yahoo/vespa/model/ml/ConvertedModel.java | 2 +- config-model/src/main/protobuf/onnx.proto | 464 --------------------- .../ml_models/searchdefinitions/test.sd | 2 +- .../onnx-model/files/create_dynamic_model.py | 12 - .../integration/onnx-model/files/create_model.py | 37 -- .../onnx-model/files/dynamic_model.onnx | 13 - .../test/integration/onnx-model/files/model.onnx | 34 -- .../onnx-model/files/summary_model.onnx | 34 -- .../onnx-model/searchdefinitions/test.sd | 36 +- .../RankingExpressionWithOnnxModelTestCase.java | 76 +--- .../RankingExpressionWithOnnxTestCase.java | 36 +- fat-model-dependencies/pom.xml | 5 - 24 files changed, 149 insertions(+), 1293 deletions(-) delete mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java delete mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java delete mode 100644 config-model/src/main/protobuf/onnx.proto delete mode 100755 config-model/src/test/integration/onnx-model/files/create_dynamic_model.py delete mode 100755 config-model/src/test/integration/onnx-model/files/create_model.py delete mode 100644 config-model/src/test/integration/onnx-model/files/dynamic_model.onnx delete mode 100644 config-model/src/test/integration/onnx-model/files/model.onnx delete mode 100644 config-model/src/test/integration/onnx-model/files/summary_model.onnx diff --git a/config-model/pom.xml b/config-model/pom.xml index c0751431d03..95e79fd09fb 100644 --- a/config-model/pom.xml +++ b/config-model/pom.xml @@ -45,11 +45,6 @@ guava-testlib test - - com.google.protobuf - protobuf-java - ${protobuf.version} - com.google.guava guava @@ -503,10 +498,6 @@ true - - com.github.os72 - protoc-jar-maven-plugin - diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index b153ff62e7d..4011ce43841 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -159,12 +158,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); } - // A reference to an ONNX model? - Optional onnxFeatureType = onnxFeatureType(reference); - if (onnxFeatureType.isPresent()) { - return onnxFeatureType.get(); - } - // A reference to a feature which returns a tensor? Optional featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -217,26 +210,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(function); } - private Optional onnxFeatureType(Reference reference) { - if ( ! reference.name().equals("onnxModel")) - return Optional.empty(); - - if ( ! featureTypes.containsKey(reference)) { - String configOrFileName = reference.arguments().expressions().get(0).toString(); - - // Look up standardized format as added in RankProfile - String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); - String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); - - reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); - if ( ! featureTypes.containsKey(reference)) { - throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); - } - } - - return Optional.of(featureTypes.get(reference)); - } - /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 58213186f78..c2fb2107604 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,22 +2,14 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; -import onnx.Onnx; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; +import java.util.List; import java.util.Objects; -import java.util.Optional; -import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -29,16 +21,16 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; - private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private String defaultOutput = null; - private Map inputMap = new HashMap<>(); - private Map outputMap = new HashMap<>(); + private List inputMap = new ArrayList<>(); + private List outputMap = new ArrayList<>(); - private Map inputTypes = new HashMap<>(); - private Map outputTypes = new HashMap<>(); - private Map vespaTypes = new HashMap<>(); + public PathType getPathType() { + return pathType; + } + + private PathType pathType = PathType.FILE; public OnnxModel(String name) { this.name = name; @@ -57,52 +49,21 @@ public class OnnxModel { } public void setUri(String uri) { - throw new IllegalArgumentException("URI for ONNX models are not currently supported"); - } - - public PathType getPathType() { - return pathType; - } - - public void setDefaultOutput(String onnxName) { - Objects.requireNonNull(onnxName, "Name cannot be null"); - this.defaultOutput = onnxName; + Objects.requireNonNull(uri, "uri cannot be null"); + this.path = uri; + this.pathType = PathType.URI; } public void addInputNameMapping(String onnxName, String vespaName) { - addInputNameMapping(onnxName, vespaName, true); - } - - public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - if (overwrite || ! inputMap.containsKey(onnxName)) { - inputMap.put(onnxName, vespaName); - } + this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); } public void addOutputNameMapping(String onnxName, String vespaName) { - addOutputNameMapping(onnxName, vespaName, true); - } - - public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - if (overwrite || ! outputMap.containsKey(onnxName)) { - outputMap.put(onnxName, vespaName); - } - } - - public void addInputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - inputTypes.put(onnxName, type); - } - - public void addOutputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - outputTypes.put(onnxName, type); + this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); } /** Initiate sending of this constant to some services over file distribution */ @@ -115,16 +76,11 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } - public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public Map getInputMap() { return Collections.unmodifiableMap(inputMap); } - public Map getOutputMap() { return Collections.unmodifiableMap(outputMap); } - - public String getDefaultOutput() { - return defaultOutput; - } + public List getInputMap() { return Collections.unmodifiableList(inputMap); } + public List getOutputMap() { return Collections.unmodifiableList(outputMap); } public void validate() { if (path == null || path.isEmpty()) @@ -134,142 +90,23 @@ public class OnnxModel { public String toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - /** - * Return the tensor type for an ONNX model for the given context. - * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output - * type depends on the input types for the given context (rank profile). - */ - public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { - Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); - if (onnxOutputType == null) { - throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); - } - if (containsSymbolicDimensionSizes(onnxOutputType)) { - return getTensorTypeWithSymbolicDimensions(onnxOutputType, context); - } - return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); - } - - private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { - Map symbolicSizes = resolveSymbolicDimensionSizes(context); - if (symbolicSizes.isEmpty()) { - return TensorType.empty; // Context is probably a rank profile not using this ONNX model - } - return typeFrom(onnxOutputType, symbolicSizes); - } - - private Map resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) { - Map symbolicSizes = new HashMap<>(); - for (String onnxInputName : inputTypes.keySet()) { - - Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); - if ( ! containsSymbolicDimensionSizes(onnxType)) { - continue; - } - - Optional vespaType = resolveInputType(onnxInputName, context); - if (vespaType.isEmpty()) { - return Collections.emptyMap(); - } - - var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); - var vespaDimensions = vespaType.get().dimensions(); - if (vespaDimensions.size() != onnxDimensions.size()) { - return Collections.emptyMap(); - } - - for (int i = 0; i < vespaDimensions.size(); ++i) { - if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) { - continue; - } - String symbolicName = onnxDimensions.get(i).getDimParam(); - Long size = vespaDimensions.get(i).size().get(); - if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { - throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " + - "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); - } - symbolicSizes.put(symbolicName, size); - } - } - return symbolicSizes; - } - - private Optional resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { - String source = inputMap.get(onnxInputName); - if (source != null) { - // Source is either a simple reference (query/attribute/constant)... - Optional reference = Reference.simple(source); - if (reference.isPresent()) { - return Optional.of(context.getType(reference.get())); - } - // ... or a function - ExpressionFunction func = context.getFunction(source); - if (func != null) { - return Optional.of(func.getBody().type(context)); - } - } - return Optional.empty(); // if this context does not contain this input - } - - private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) { - return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue()); - } - - private static TensorType typeFrom(Onnx.TypeProto type) { - return typeFrom(type, null); - } - - private static TensorType typeFrom(Onnx.TypeProto type, Map symbolicSizes) { - String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... - Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - long onnxDimensionSize = onnxDimension.getDimValue(); - if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { - onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); - } - if (onnxDimensionSize == 0 && symbolicSizes != null) { - // This is for the case where all symbolic dimensions have - // different names, but can be resolved to a single dimension size. - Set unknownSizes = new HashSet<>(symbolicSizes.values()); - if (unknownSizes.size() == 1) { - onnxDimensionSize = unknownSizes.iterator().next(); - } - } - if (onnxDimensionSize <= 0) { - throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + - "ONNX type: " + type + " to Vespa tensor type."); - } - builder.indexed(dimensionName, onnxDimensionSize); - } - return builder.build(); - } + public static class OnnxNameMapping { + private String onnxName; + private String vespaName; - private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT8: return TensorType.Value.FLOAT; - case INT16: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.FLOAT; - case UINT8: return TensorType.Value.FLOAT; - case UINT16: return TensorType.Value.FLOAT; - case UINT32: return TensorType.Value.FLOAT; - case UINT64: return TensorType.Value.FLOAT; - default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); + private OnnxNameMapping(String onnxName, String vespaName) { + this.onnxName = onnxName; + this.vespaName = vespaName; } + public String getOnnxName() { return onnxName; } + public String getVespaName() { return vespaName; } + public void setVespaName(String vespaName) { this.vespaName = vespaName; } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 96c043bdb34..d309f48d6df 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -18,7 +18,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; @@ -159,10 +158,6 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } - private Map onnxModels() { - return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); - } - private Stream allFields() { if (search == null) return Stream.empty(); if (allFieldsList == null) { @@ -826,20 +821,6 @@ public class RankProfile implements Cloneable { } } - // Add output types for ONNX models - for (Map.Entry entry : onnxModels().entrySet()) { - String modelName = entry.getKey(); - OnnxModel model = entry.getValue(); - Arguments args = new Arguments(new ReferenceNode(modelName)); - - TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context); - context.setType(new Reference("onnxModel", args, null), defaultOutputType); - - for (Map.Entry mapping : model.getOutputMap().entrySet()) { - TensorType type = model.getTensorType(mapping.getKey(), context); - context.setType(new Reference("onnxModel", args, mapping.getValue()), type); - } - } return context; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 22a32c8fd65..84442fedc48 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); modelBuilder.name(model.getName()); modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); - model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); + model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 56a5d539906..87eaaf0387a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { - for (Map.Entry mapping : onnxModel.getInputMap().entrySet()) { - String source = mapping.getValue(); + for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { + String source = mapping.getVespaName(); if (functionNames.contains(source)) { - onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); + mapping.setVespaName("rankingExpression(" + source + ")"); } } } @@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set replacedSummaryFeatures = new HashSet<>(); for (Iterator i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); - ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile); + ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); if (referenceNode != replacedNode) { replacedSummaryFeatures.add(replacedNode); i.remove(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index d23a8376e7a..ec517768ea9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer 3) - throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments."); - - // Check that the model configuration "onnx-model" exists. If not defined, it should have been added - // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find - // the actual ONNX file, which can happen if we are restarting or upgrading an application using an - // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. - - String modelConfigName = getModelConfigName(feature.reference()); - OnnxModel onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { + throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + + "onnx-model config or a ONNX file."); + if (arguments.expressions().size() > 2) + throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); + + // Validation that the file actually exists is handled when the file is added to file distribution. + // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. + + String modelConfigName; + OnnxModel onnxModel; + if (arguments.expressions().get(0) instanceof ReferenceNode) { + modelConfigName = arguments.expressions().get(0).toString(); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); + } + } else if (arguments.expressions().get(0) instanceof ConstantNode) { String path = asString(arguments.expressions().get(0)); - ModelName modelName = new ModelName(null, Path.fromString(path), true); - ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile); - FeatureArguments featureArguments = new FeatureArguments(arguments); - return convertedModel.expression(featureArguments, null); - } - - String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()); - String output = getModelOutput(feature.reference(), defaultOutput); - if (! onnxModel.getOutputMap().containsValue(output)) { - throw new IllegalArgumentException(featureName + " argument '" + output + - "' output not found in model '" + onnxModel.getFileName() + "'"); + modelConfigName = asValidIdentifier(path); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } else { + throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); } - return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); - } - public static String getModelConfigName(Reference reference) { - if (reference.arguments().size() > 0) { - ExpressionNode expr = reference.arguments().expressions().get(0); - if (expr instanceof ReferenceNode) { // refers to onnx-model config - return expr.toString(); + String output = null; + if (feature.getOutput() != null) { + output = feature.getOutput(); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(output, output); } - if (expr instanceof ConstantNode) { // refers to an file path - return asValidIdentifier(expr); + } else if (arguments.expressions().size() > 1) { + String name = asString(arguments.expressions().get(1)); + output = asValidIdentifier(name); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(name, output); } } - return null; - } - public static String getModelOutput(Reference reference, String defaultOutput) { - if (reference.output() != null) { - return reference.output(); - } else if (reference.arguments().expressions().size() == 2) { - return asValidIdentifier(reference.arguments().expressions().get(1)); - } else if (reference.arguments().expressions().size() > 2) { - return asValidIdentifier(reference.arguments().expressions().get(2)); - } - return defaultOutput; + // Replace feature with name of config + ExpressionNode argument = new ReferenceNode(modelConfigName); + return new ReferenceNode("onnxModel", List.of(argument), output); + } - public static String stripQuotes(String s) { - if (isNotQuoteSign(s.codePointAt(0))) return s; - if (isNotQuoteSign(s.codePointAt(s.length() - 1))) - throw new IllegalArgumentException("argument [" + s + "] is missing end quote"); - return s.substring(1, s.length()-1); + private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { + return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); } - public static String asValidIdentifier(String str) { - return str.replaceAll("[^\\w\\d\\$@_]", "_"); + private static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } - private static String asValidIdentifier(ExpressionNode node) { - return asValidIdentifier(asString(node)); + private static String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); } - private static boolean isNotQuoteSign(int c) { - return c != '\'' && c != '"'; + private static boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; } - public static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); + private static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java deleted file mode 100644 index afba88c135d..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchdefinition.processing; - -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.application.api.DeployLogger; -import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.Search; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.vespa.model.container.search.QueryProfiles; - -import java.util.Map; - -/** - * Processes ONNX ranking features of the form: - * - * onnx("files/model.onnx", "path/to/output:1") - * - * And generates an "onnx-model" configuration as if it was defined in the schema: - * - * onnx-model files_model_onnx { - * file: "files/model.onnx" - * } - * - * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be - * processed after this. - * - * @author lesters - */ -public class OnnxModelConfigGenerator extends Processor { - - public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { - super(search, deployLogger, rankProfileRegistry, queryProfiles); - } - - @Override - public void process(boolean validate, boolean documentsOnly) { - if (documentsOnly) return; - for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { - if (profile.getFirstPhaseRanking() != null) { - process(profile.getFirstPhaseRanking().getRoot()); - } - if (profile.getSecondPhaseRanking() != null) { - process(profile.getSecondPhaseRanking().getRoot()); - } - for (Map.Entry function : profile.getFunctions().entrySet()) { - process(function.getValue().function().getBody().getRoot()); - } - for (ReferenceNode feature : profile.getSummaryFeatures()) { - process(feature); - } - } - } - - private void process(ExpressionNode node) { - if (node instanceof ReferenceNode) { - process((ReferenceNode)node); - } else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode) node).children()) { - process(child); - } - } - } - - private void process(ReferenceNode feature) { - if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) { - if (feature.getArguments().size() > 0) { - if (feature.getArguments().expressions().get(0) instanceof ConstantNode) { - ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0); - String path = OnnxModelTransformer.stripQuotes(node.sourceString()); - String modelConfigName = OnnxModelTransformer.asValidIdentifier(path); - - // Only add the configuration if the model can actually be found. - if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { - path = ApplicationPackage.MODELS_DIR.append(path).toString(); - if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { - return; - } - } - - OnnxModel onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - search.onnxModels().add(onnxModel); - } - } - } - } - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java deleted file mode 100644 index bead2e7e7c9..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchdefinition.processing; - -import com.yahoo.cloud.config.ConfigserverConfig; -import com.yahoo.component.Version; -import com.yahoo.config.FileReference; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.application.api.DeployLogger; -import com.yahoo.config.application.api.FileRegistry; -import com.yahoo.path.Path; -import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.Search; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; -import com.yahoo.vespa.defaults.Defaults; -import com.yahoo.vespa.model.container.search.QueryProfiles; -import onnx.Onnx; - -import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Paths; -import java.util.Map; -import java.util.Optional; - -/** - * Processes every "onnx-model" element in the schema. Parses the model file, - * adds missing input and output mappings (assigning default names), and - * adds tensor types to all model inputs and outputs. - * - * Must be processed before RankingExpressingTypeResolver. - * - * @author lesters - */ -public class OnnxModelTypeResolver extends Processor { - - public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { - super(search, deployLogger, rankProfileRegistry, queryProfiles); - } - - @Override - public void process(boolean validate, boolean documentsOnly) { - if (documentsOnly) return; - - for (Map.Entry entry : search.onnxModels().asMap().entrySet()) { - OnnxModel modelConfig = entry.getValue(); - try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) { - Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - - // Model inputs - if not defined, assumes a function is provided with a valid name - for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { - String onnxInputName = valueInfo.getName(); - String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName); - modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false); - modelConfig.addInputType(onnxInputName, valueInfo.getType()); - } - - // Model outputs - for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { - String onnxOutputName = valueInfo.getName(); - String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName); - modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false); - modelConfig.addOutputType(onnxOutputName, valueInfo.getType()); - } - - // Set the first output as default - if ( ! model.getGraph().getOutputList().isEmpty()) { - modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName()); - } - - } catch (IOException e) { - throw new IllegalArgumentException("Unable to parse ONNX model", e); - } - } - } - - static boolean modelFileExists(String path, ApplicationPackage app) { - Path pathInApplicationPackage = Path.fromString(path); - if (getFile(pathInApplicationPackage, app).exists()) { - return true; - } - if (getFileReference(pathInApplicationPackage, app).isPresent()) { - return true; - } - return false; - } - - private InputStream openModelFile(Path path) throws FileNotFoundException { - ApplicationFile file; - Optional reference; - Path modelsPath = ApplicationPackage.MODELS_DIR.append(path); - - if ((file = getFile(path)).exists()) { - return file.createInputStream(); - } - if ((file = getFile(modelsPath)).exists()) { - return file.createInputStream(); - } - if ((reference = getFileReference(path)).isPresent()) { - return openFromFileRepository(path, reference.get()); - } - if ((reference = getFileReference(modelsPath)).isPresent()) { - return openFromFileRepository(modelsPath, reference.get()); - } - - throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " + - "in application package or file repository."); - } - - private ApplicationFile getFile(Path path) { - return getFile(path, search.applicationPackage()); - } - - private static ApplicationFile getFile(Path path, ApplicationPackage app) { - return app.getFile(path); - } - - private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException { - return new FileInputStream(new File(getFileRepositoryPath(path, reference.value()))); - } - - public static String getFileRepositoryPath(Path path, String fileReference) { - ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults - String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); - return Paths.get(fileRefDir, fileReference, path.getName()).toString(); - } - - private Optional getFileReference(Path path) { - return getFileReference(path, search.applicationPackage()); - } - - private static Optional getFileReference(Path path, ApplicationPackage app) { - Optional fileRegistry = getLatestFileRegistry(app); - if (fileRegistry.isPresent()) { - for (FileRegistry.Entry file : fileRegistry.get().export()) { - if (file.relativePath.equals(path.toString())) { - return Optional.of(file.reference); - } - } - } - return Optional.empty(); - } - - private static Optional getLatestFileRegistry(ApplicationPackage app) { - if (app == null) return Optional.empty(); - Optional latest = app.getFileRegistries().keySet().stream().max(Version::compareTo); - return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get())); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index 1a3ef9e54b4..e8594c2a87f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -74,8 +74,6 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, - OnnxModelConfigGenerator::new, - OnnxModelTypeResolver::new, RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index d5c5183b01f..c6c7969e466 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -1,19 +1,20 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; +import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; import java.util.logging.Level; import com.yahoo.log.LogMessage; import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.collections.Pair; import com.yahoo.config.ConfigInstance; +import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.config.search.ImportedFieldsConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -30,6 +31,7 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.io.File; import java.io.IOException; import java.nio.file.Files; +import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; @@ -150,9 +152,12 @@ public class RankSetupValidator extends Validator { // Assist verify-ranksetup in finding the actual ONNX model files Map models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); if (models.values().size() > 0) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); List config = new ArrayList<>(models.values().size() * 2); for (OnnxModel model : models.values()) { - String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference()); + String modelFilename = Paths.get(model.getFileName()).getFileName().toString(); + String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString(); config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 5ee6ed02e61..943fcbf6c1d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -150,7 +150,7 @@ public class ConvertedModel { */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { ExpressionFunction expression = selectExpression(arguments); - if (sourceModel.isPresent() && context != null) // we should verify + if (sourceModel.isPresent()) // we should verify verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); return expression.getBody().getRoot(); } diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto deleted file mode 100644 index dc6542867e0..00000000000 --- a/config-model/src/main/protobuf/onnx.proto +++ /dev/null @@ -1,464 +0,0 @@ -// -// WARNING: This file is automatically generated! Please edit onnx.in.proto. -// - - -// Copyright (c) Facebook Inc. and Microsoft Corporation. -// Licensed under the MIT license. - -syntax = "proto2"; - -package onnx; - -// Overview -// -// ONNX is an open specification that is comprised of the following components: -// -// 1) A definition of an extensible computation graph model. -// 2) Definitions of standard data types. -// 3) Definitions of built-in operators. -// -// This document describes the syntax of models and their computation graphs, -// as well as the standard data types. Together, they are referred to as the ONNX -// Intermediate Representation, or 'IR' for short. -// -// The normative semantic specification of the ONNX IR is found in docs/IR.md. -// Definitions of the built-in neural network operators may be found in docs/Operators.md. - -// Notes -// -// Release -// -// We are still in the very early stage of defining ONNX. The current -// version of ONNX is a starting point. While we are actively working -// towards a complete spec, we would like to get the community involved -// by sharing our working version of ONNX. -// -// Protobuf compatibility -// -// To simplify framework compatibility, ONNX is defined using the subset of protobuf -// that is compatible with both protobuf v2 and v3. This means that we do not use any -// protobuf features that are only available in one of the two versions. -// -// Here are the most notable contortions we have to carry out to work around -// these limitations: -// -// - No 'map' (added protobuf 3.0). We instead represent mappings as lists -// of key-value pairs, where order does not matter and duplicates -// are not allowed. - - -// Versioning -// -// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md -// -// To be compatible with both proto2 and proto3, we will use a version number -// that is not defined by the default value but an explicit enum number. -enum Version { - // proto3 requires the first enum value to be zero. - // We add this just to appease the compiler. - _START_VERSION = 0; - // The version field is always serialized and we will use it to store the - // version that the graph is generated from. This helps us set up version - // control. We should use version as - // xx(major) - xx(minor) - xxxx(bugfix) - // and we are starting with 0x00000001 (0.0.1), which was the - // version we published on Oct 10, 2017. - IR_VERSION_2017_10_10 = 0x00000001; - - // IR_VERSION 0.0.2 published on Oct 30, 2017 - // - Added type discriminator to AttributeProto to support proto3 users - IR_VERSION_2017_10_30 = 0x00000002; - - // IR VERSION 0.0.3 published on Nov 3, 2017 - // - For operator versioning: - // - Added new message OperatorSetIdProto - // - Added opset_import in ModelProto - // - For vendor extensions, added domain in NodeProto - IR_VERSION = 0x00000003; -} - -// Attributes -// -// A named attribute containing either singular float, integer, string, graph, -// and tensor values, or repeated float, integer, string, graph, and tensor values. -// An AttributeProto MUST contain the name field, and *only one* of the -// following content fields, effectively enforcing a C/C++ union equivalent. -message AttributeProto { - - // Note: this enum is structurally identical to the OpSchema::AttrType - // enum defined in schema.h. If you rev one, you likely need to rev the other. - enum AttributeType { - UNDEFINED = 0; - FLOAT = 1; - INT = 2; - STRING = 3; - TENSOR = 4; - GRAPH = 5; - - FLOATS = 6; - INTS = 7; - STRINGS = 8; - TENSORS = 9; - GRAPHS = 10; - } - - // The name field MUST be present for this version of the IR. - optional string name = 1; // namespace Attribute - - // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. - // In this case, this AttributeProto does not contain data, and it's a reference of attribute - // in parent scope. - // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. - optional string ref_attr_name = 21; - - // A human-readable documentation for this attribute. Markdown is allowed. - optional string doc_string = 13; - - // The type field MUST be present for this version of the IR. - // For 0.0.1 versions of the IR, this field was not defined, and - // implementations needed to use has_field hueristics to determine - // which value field was in use. For IR_VERSION 0.0.2 or later, this - // field MUST be set and match the f|i|s|t|... field in use. This - // change was made to accomodate proto3 implementations. - optional AttributeType type = 20; // discriminator that indicates which field below is in use - - // Exactly ONE of the following fields must be present for this version of the IR - optional float f = 2; // float - optional int64 i = 3; // int - optional bytes s = 4; // UTF-8 string - optional TensorProto t = 5; // tensor value - optional GraphProto g = 6; // graph - // Do not use field below, it's deprecated. - // optional ValueProto v = 12; // value - subsumes everything but graph - - repeated float floats = 7; // list of floats - repeated int64 ints = 8; // list of ints - repeated bytes strings = 9; // list of UTF-8 strings - repeated TensorProto tensors = 10; // list of tensors - repeated GraphProto graphs = 11; // list of graph -} - -// Defines information on value, including the name, the type, and -// the shape of the value. -message ValueInfoProto { - // This field MUST be present in this version of the IR. - optional string name = 1; // namespace Value - // This field MUST be present in this version of the IR. - optional TypeProto type = 2; - // A human-readable documentation for this value. Markdown is allowed. - optional string doc_string = 3; -} - -// Nodes -// -// Computation graphs are made up of a DAG of nodes, which represent what is -// commonly called a "layer" or "pipeline stage" in machine learning frameworks. -// -// For example, it can be a node of type "Conv" that takes in an image, a filter -// tensor and a bias tensor, and produces the convolved output. -message NodeProto { - repeated string input = 1; // namespace Value - repeated string output = 2; // namespace Value - - // An optional identifier for this node in a graph. - // This field MAY be absent in ths version of the IR. - optional string name = 3; // namespace Node - - // The symbolic identifier of the Operator to execute. - optional string op_type = 4; // namespace Operator - // The domain of the OperatorSet that specifies the operator named by op_type. - optional string domain = 7; // namespace Domain - - // Additional named attributes. - repeated AttributeProto attribute = 5; - - // A human-readable documentation for this node. Markdown is allowed. - optional string doc_string = 6; -} - -// Models -// -// ModelProto is a top-level file/container format for bundling a ML model and -// associating its computation graph with metadata. -// -// The semantics of the model are described by the associated GraphProto. -message ModelProto { - // The version of the IR this model targets. See Version enum above. - // This field MUST be present. - optional int64 ir_version = 1; - - // The OperatorSets this model relies on. - // All ModelProtos MUST have at least one entry that - // specifies which version of the ONNX OperatorSet is - // being imported. - // - // All nodes in the ModelProto's graph will bind against the operator - // with the same-domain/same-op_type operator with the HIGHEST version - // in the referenced operator sets. - repeated OperatorSetIdProto opset_import = 8; - - // The name of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - optional string producer_name = 2; - - // The version of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - optional string producer_version = 3; - - // Domain name of the model. - // We use reverse domain names as name space indicators. For example: - // `com.facebook.fair` or `com.microsoft.cognitiveservices` - // - // Together with `model_version` and GraphProto.name, this forms the unique identity of - // the graph. - optional string domain = 4; - - // The version of the graph encoded. See Version enum below. - optional int64 model_version = 5; - - // A human-readable documentation for this model. Markdown is allowed. - optional string doc_string = 6; - - // The parameterized graph that is evaluated to execute the model. - optional GraphProto graph = 7; - - // Named metadata values; keys should be distinct. - repeated StringStringEntryProto metadata_props = 14; -}; - -// StringStringEntryProto follows the pattern for cross-proto-version maps. -// See https://developers.google.com/protocol-buffers/docs/proto3#maps -message StringStringEntryProto { - optional string key = 1; - optional string value= 2; -}; - -// Graphs -// -// A graph defines the computational logic of a model and is comprised of a parameterized -// list of nodes that form a directed acyclic graph based on their inputs and outputs. -// This is the equivalent of the "network" or "graph" in many deep learning -// frameworks. -message GraphProto { - // The nodes in the graph, sorted topologically. - repeated NodeProto node = 1; - - // The name of the graph. - optional string name = 2; // namespace Graph - - // A list of named tensor values, used to specify constant inputs of the graph. - // Each TensorProto entry must have a distinct name (within the list) that - // also appears in the input list. - repeated TensorProto initializer = 5; - - // A human-readable documentation for this graph. Markdown is allowed. - optional string doc_string = 10; - - // The inputs and outputs of the graph. - repeated ValueInfoProto input = 11; - repeated ValueInfoProto output = 12; - - // Information for the values in the graph. The ValueInfoProto.name's - // must be distinct. It is optional for a value to appear in value_info list. - repeated ValueInfoProto value_info = 13; - - // DO NOT USE the following fields, they were deprecated from earlier versions. - // repeated string input = 3; - // repeated string output = 4; - // optional int64 ir_version = 6; - // optional int64 producer_version = 7; - // optional string producer_tag = 8; - // optional string domain = 9; -} - -// Tensors -// -// A serialized tensor value. -message TensorProto { - enum DataType { - UNDEFINED = 0; - // Basic types. - FLOAT = 1; // float - UINT8 = 2; // uint8_t - INT8 = 3; // int8_t - UINT16 = 4; // uint16_t - INT16 = 5; // int16_t - INT32 = 6; // int32_t - INT64 = 7; // int64_t - STRING = 8; // string - BOOL = 9; // bool - - // Advanced types - FLOAT16 = 10; - DOUBLE = 11; - UINT32 = 12; - UINT64 = 13; - COMPLEX64 = 14; // complex with float32 real and imaginary components - COMPLEX128 = 15; // complex with float64 real and imaginary components - // Future extensions go here. - } - - // The shape of the tensor. - repeated int64 dims = 1; - - // The data type of the tensor. - optional DataType data_type = 2; - - // For very large tensors, we may want to store them in chunks, in which - // case the following fields will specify the segment that is stored in - // the current TensorProto. - message Segment { - optional int64 begin = 1; - optional int64 end = 2; - } - optional Segment segment = 3; - - // Tensor content must be organized in row-major order. - // - // Depending on the data_type field, exactly one of the fields below with - // name ending in _data is used to store the elements of the tensor. - - // For float and complex64 values - // Complex64 tensors are encoded as a single array of floats, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. - repeated float float_data = 4 [packed = true]; - - // For int32, uint8, int8, uint16, int16, bool, and float16 values - // float16 values must be bit-wise converted to an uint16_t prior - // to writing to the buffer. - // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 - repeated int32 int32_data = 5 [packed = true]; - - // For strings. - // Each element of string_data is a UTF-8 encoded Unicode - // string. No trailing null, no leading BOM. The protobuf "string" - // scalar type is not used to match ML community conventions. - // When this field is present, the data_type field MUST be STRING - repeated bytes string_data = 6; - - // For int64. - // When this field is present, the data_type field MUST be INT64 - repeated int64 int64_data = 7 [packed = true]; - - // Optionally, a name for the tensor. - optional string name = 8; // namespace Value - - // A human-readable documentation for this tensor. Markdown is allowed. - optional string doc_string = 12; - - // Serializations can either use one of the fields above, or use this - // raw bytes field. The only exception is the string case, where one is - // required to store the content in the repeated bytes string_data field. - // - // When this raw_data field is used to store tensor value, elements MUST - // be stored in as fixed-width, little-endian order. - // Floating-point data types MUST be stored in IEEE 754 format. - // Complex64 elements must be written as two consecutive FLOAT values, real component first. - // Complex128 elements must be written as two consecutive DOUBLE values, real component first. - // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). - // - // Note: the advantage of specific field rather than the raw_data field is - // that in some cases (e.g. int data), protobuf does a better packing via - // variable length storage, and may lead to smaller binary footprint. - // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED - optional bytes raw_data = 9; - - // For double - // Complex64 tensors are encoded as a single array of doubles, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 - repeated double double_data = 10 [packed = true]; - - // For uint64 and uint32 values - // When this field is present, the data_type field MUST be - // UINT32 or UINT64 - repeated uint64 uint64_data = 11 [packed = true]; -} - -// Defines a tensor shape. A dimension can be either an integer value -// or a symbolic variable. A symbolic variable represents an unknown -// dimension. -message TensorShapeProto { - message Dimension { - oneof value { - int64 dim_value = 1; - string dim_param = 2; // namespace Shape - }; - // Standard denotation can optionally be used to denote tensor - // dimensions with standard semantic descriptions to ensure - // that operations are applied to the correct axis of a tensor. - optional string denotation = 3; - }; - repeated Dimension dim = 1; -} - -// A set of pre-defined constants to be used as values for -// the standard denotation field in TensorShapeProto.Dimension -// for semantic description of the tensor dimension. -message DenotationConstProto { - // Describe a batch number dimension. - optional string DATA_BATCH = 1 [default = "DATA_BATCH"]; - // Describe a channel dimension. - optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"]; - // Describe a time dimension. - optional string DATA_TIME = 3 [default = "DATA_TIME"]; - // Describe a feature dimension. This is typically a feature - // dimension in RNN and/or spatial dimension in CNN. - optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"]; - // Describe a filter in-channel dimension. This is the dimension - // that is identical (in size) to the channel dimension of the input - // image feature maps. - optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"]; - // Describe a filter out channel dimension. This is the dimension - // that is identical (int size) to the channel dimension of the output - // image feature maps. - optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"]; - // Describe a filter spatial dimension. - optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"]; -} - -// Types -// -// The standard ONNX data types. -message TypeProto { - - message Tensor { - // This field MUST NOT have the value of UNDEFINED - // This field MUST be present for this version of the IR. - optional TensorProto.DataType elem_type = 1; - optional TensorShapeProto shape = 2; - } - - - oneof value { - // The type of a tensor. - Tensor tensor_type = 1; - - } -} - -// Operator Sets -// -// OperatorSets are uniquely identified by a (domain, opset_version) pair. -message OperatorSetIdProto { - // The domain of the operator set being identified. - // The empty string ("") or absence of this field implies the operator - // set that is defined as part of the ONNX specification. - // This field MUST be present in this version of the IR when referring to any other operator set. - optional string domain = 1; - - // The version of the operator set being identified. - // This field MUST be present in this version of the IR. - optional int64 version = 2; -} \ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd index cc73f2daff5..e9575af6010 100644 --- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd +++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd @@ -18,7 +18,7 @@ search test { } function mnist_softmax_onnx() { - expression: onnx_vespa("mnist_softmax") + expression: onnx("mnist_softmax") } function my_xgboost() { diff --git a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py b/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py deleted file mode 100755 index 55df3a557e9..00000000000 --- a/config-model/src/test/integration/onnx-model/files/create_dynamic_model.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import onnx -from onnx import helper, TensorProto - -INPUT = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "sequence"]) -OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "sequence"]) - -nodes = [helper.make_node('Identity', ['input'], ['output'])] -graph_def = helper.make_graph( nodes, 'simple_scoring', [INPUT], [OUTPUT]) -model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py') -onnx.save(model_def, 'dynamic_model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/create_model.py b/config-model/src/test/integration/onnx-model/files/create_model.py deleted file mode 100755 index 10ff92c2eda..00000000000 --- a/config-model/src/test/integration/onnx-model/files/create_model.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import onnx -from onnx import helper, TensorProto - -INPUT_1 = helper.make_tensor_value_info('first_input', TensorProto.FLOAT, [2]) -INPUT_2 = helper.make_tensor_value_info('second/input:0', TensorProto.FLOAT, [2]) -INPUT_3 = helper.make_tensor_value_info('third_input', TensorProto.FLOAT, [2]) -OUTPUT_1 = helper.make_tensor_value_info('path/to/output:0', TensorProto.FLOAT, [2]) -OUTPUT_2 = helper.make_tensor_value_info('path/to/output:1', TensorProto.FLOAT, [2]) -OUTPUT_3 = helper.make_tensor_value_info('path/to/output:2', TensorProto.FLOAT, [2]) - -nodes = [ - helper.make_node( - 'Add', - ['first_input', 'second/input:0'], - ['path/to/output:0'], - ), - helper.make_node( - 'Add', - ['third_input', 'second/input:0'], - ['path/to/output:1'] - ), - helper.make_node( - 'Add', - ['path/to/output:0', 'path/to/output:1'], - ['path/to/output:2'] - ), -] -graph_def = helper.make_graph( - nodes, - 'simple_scoring', - [INPUT_1, INPUT_2, INPUT_3], - [OUTPUT_1, OUTPUT_2, OUTPUT_3] -) -model_def = helper.make_model(graph_def, producer_name='create_model.py') -onnx.save(model_def, 'model.onnx') diff --git a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx b/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx deleted file mode 100644 index 6bbdad2d76e..00000000000 --- a/config-model/src/test/integration/onnx-model/files/dynamic_model.onnx +++ /dev/null @@ -1,13 +0,0 @@ -create_dynamic_model.py:x - -inputoutput"Identitysimple_scoringZ$ -input - -batch - -sequenceb% -output - -batch - -sequenceB \ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/model.onnx b/config-model/src/test/integration/onnx-model/files/model.onnx deleted file mode 100644 index f3898205c6a..00000000000 --- a/config-model/src/test/integration/onnx-model/files/model.onnx +++ /dev/null @@ -1,34 +0,0 @@ -create_model.py:í -4 - first_input -second/input:0path/to/output:0"Add -4 - third_input -second/input:0path/to/output:1"Add -; -path/to/output:0 -path/to/output:1path/to/output:2"Addsimple_scoringZ - first_input - - -Z -second/input:0 - - -Z - third_input - - -b -path/to/output:0 - - -b -path/to/output:1 - - -b -path/to/output:2 - - -B \ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/files/summary_model.onnx b/config-model/src/test/integration/onnx-model/files/summary_model.onnx deleted file mode 100644 index f3898205c6a..00000000000 --- a/config-model/src/test/integration/onnx-model/files/summary_model.onnx +++ /dev/null @@ -1,34 +0,0 @@ -create_model.py:í -4 - first_input -second/input:0path/to/output:0"Add -4 - third_input -second/input:0path/to/output:1"Add -; -path/to/output:0 -path/to/output:1path/to/output:2"Addsimple_scoringZ - first_input - - -Z -second/input:0 - - -Z - third_input - - -b -path/to/output:0 - - -b -path/to/output:1 - - -b -path/to/output:2 - - -B \ No newline at end of file diff --git a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd index 6e9ba356293..0f0fa694e6f 100644 --- a/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd +++ b/config-model/src/test/integration/onnx-model/searchdefinitions/test.sd @@ -14,7 +14,7 @@ search test { } onnx-model my_model { - file: files/model.onnx + file: files/ranking_model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": my_function @@ -22,25 +22,19 @@ search test { } onnx-model another_model { - file: files/model.onnx + file: files/ranking_model.onnx input first_input: attribute(document_field) input "second/input:0": constant(my_constant) input "third_input": another_function output "path/to/output:2": out } - onnx-model dynamic_model { - file: files/dynamic_model.onnx - input input: my_function - output output: my_output - } - rank-profile test_model_config { function my_function() { expression: tensor(d0[2])(1) } first-phase { - expression: onnxModel(my_model).out{d0:1} + expression: onnxModel(my_model).out } } @@ -55,7 +49,7 @@ search test { expression: my_function() } first-phase { - expression: onnxModel("files/model.onnx", "path/to/output:1"){d0:1} + expression: onnxModel("files/ranking_model.onnx", "path/to/output:1") } } @@ -68,29 +62,9 @@ search test { } summary-features { onnxModel(another_model).out - onnxModel("files/summary_model.onnx", "path/to/output:2") - } - } - - rank-profile test_dynamic_model { - function my_function() { - expression: tensor(d0[1],d1[2])(d1) - } - first-phase { - expression: onnxModel(dynamic_model){d0:0,d1:1} + onnxModel("files/ranking_model.onnx", "path/to/output:2") } - } - rank-profile test_dynamic_model_2 { - function my_function_2() { - expression: tensor(d0[1],d1[3])(d1) - } - function my_function() { - expression: my_function_2() - } - first-phase { - expression: onnxModel(dynamic_model){d0:0,d1:2} - } } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index 5060aafb55f..d9b0c70dfdd 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -25,61 +25,43 @@ public class RankingExpressionWithOnnxModelTestCase { OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); - assertEquals(5, config.model().size()); + assertEquals(3, config.model().size()); - assertEquals("my_model", config.model(0).name()); - assertEquals(3, config.model(0).input().size()); - assertEquals("second/input:0", config.model(0).input(0).name()); - assertEquals("constant(my_constant)", config.model(0).input(0).source()); - assertEquals("first_input", config.model(0).input(1).name()); - assertEquals("attribute(document_field)", config.model(0).input(1).source()); - assertEquals("third_input", config.model(0).input(2).name()); - assertEquals("rankingExpression(my_function)", config.model(0).input(2).source()); - assertEquals(3, config.model(0).output().size()); - assertEquals("path/to/output:0", config.model(0).output(0).name()); - assertEquals("out", config.model(0).output(0).as()); - assertEquals("path/to/output:1", config.model(0).output(1).name()); - assertEquals("path_to_output_1", config.model(0).output(1).as()); - assertEquals("path/to/output:2", config.model(0).output(2).name()); - assertEquals("path_to_output_2", config.model(0).output(2).as()); - - assertEquals("files_model_onnx", config.model(1).name()); + assertEquals("my_model", config.model(1).name()); assertEquals(3, config.model(1).input().size()); - assertEquals(3, config.model(1).output().size()); + assertEquals("first_input", config.model(1).input(0).name()); + assertEquals("attribute(document_field)", config.model(1).input(0).source()); + assertEquals("second/input:0", config.model(1).input(1).name()); + assertEquals("constant(my_constant)", config.model(1).input(1).source()); + assertEquals("third_input", config.model(1).input(2).name()); + assertEquals("rankingExpression(my_function)", config.model(1).input(2).source()); + assertEquals(1, config.model(1).output().size()); assertEquals("path/to/output:0", config.model(1).output(0).name()); - assertEquals("path_to_output_0", config.model(1).output(0).as()); - assertEquals("path/to/output:1", config.model(1).output(1).name()); - assertEquals("path_to_output_1", config.model(1).output(1).as()); - assertEquals("path/to/output:2", config.model(1).output(2).name()); - assertEquals("path_to_output_2", config.model(1).output(2).as()); - assertEquals("files_model_onnx", config.model(1).name()); + assertEquals("out", config.model(1).output(0).as()); + + assertEquals("files_ranking_model_onnx", config.model(0).name()); + assertEquals(0, config.model(0).input().size()); + assertEquals(2, config.model(0).output().size()); + assertEquals("path/to/output:1", config.model(0).output(0).name()); + assertEquals("path_to_output_1", config.model(0).output(0).as()); + assertEquals("path/to/output:2", config.model(0).output(1).name()); + assertEquals("path_to_output_2", config.model(0).output(1).as()); assertEquals("another_model", config.model(2).name()); assertEquals("third_input", config.model(2).input(2).name()); assertEquals("rankingExpression(another_function)", config.model(2).input(2).source()); - - assertEquals("files_summary_model_onnx", config.model(3).name()); - assertEquals(3, config.model(3).input().size()); - assertEquals(3, config.model(3).output().size()); - - assertEquals("dynamic_model", config.model(4).name()); - assertEquals(1, config.model(4).input().size()); - assertEquals(1, config.model(4).output().size()); - assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); } private void assertTransformedFeature(DocumentDatabase db) { RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); ((RankProfilesConfig.Producer) db).getConfig(builder); RankProfilesConfig config = new RankProfilesConfig(builder); - assertEquals(7, config.rankprofile().size()); + assertEquals(5, config.rankprofile().size()); assertEquals("test_model_config", config.rankprofile(2).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(2).fef().property(0).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); - assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value()); - assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name()); - assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); + assertEquals("onnxModel(my_model).out", config.rankprofile(2).fef().property(2).value()); assertEquals("test_generated_model_config", config.rankprofile(3).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); @@ -87,28 +69,16 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("rankingExpression(second_input).rankingScript", config.rankprofile(3).fef().property(4).name()); assertEquals("rankingExpression(third_input).rankingScript", config.rankprofile(3).fef().property(6).name()); assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); - assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value()); - assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name()); - assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); + assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_1", config.rankprofile(3).fef().property(8).value()); assertEquals("test_summary_features", config.rankprofile(4).name()); assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); assertEquals("1", config.rankprofile(4).fef().property(3).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); - assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value()); + assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(4).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); - assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value()); - - assertEquals("test_dynamic_model", config.rankprofile(5).name()); - assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name()); - assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name()); - assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); - - assertEquals("test_dynamic_model_2", config.rankprofile(6).name()); - assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); - assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); - + assertEquals("onnxModel(files_ranking_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 40bf970a313..6bf69907609 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -70,7 +70,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", - "onnx_vespa('mnist_softmax.onnx')", + "onnx('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor(d0[1],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -87,7 +87,7 @@ public class RankingExpressionWithOnnxTestCase { queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("query(mytensor)", - "onnx_vespa('mnist_softmax.onnx')", + "onnx('mnist_softmax.onnx')", null, null, "Placeholder", @@ -99,7 +99,7 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", - "onnx_vespa('mnist_softmax.onnx')", + "onnx('mnist_softmax.onnx')", null, "field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -117,7 +117,7 @@ public class RankingExpressionWithOnnxTestCase { ""; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", - "onnx_vespa('mnist_softmax.onnx')", + "onnx('mnist_softmax.onnx')", "constant mytensor { file: ignored\ntype: tensor(d0[1],d1[784]) }", "field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }", "Placeholder", @@ -129,21 +129,21 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testNestedOnnxReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "5 + sum(onnx_vespa('mnist_softmax.onnx'))"); + "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutput() { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx', 'layer_add')"); + "onnx('mnist_softmax.onnx', 'layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testOnnxReferenceWithSpecifiedOutputAndSignature() { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')"); + "onnx('mnist_softmax.onnx', 'default.layer_add')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -155,7 +155,7 @@ public class RankingExpressionWithOnnxTestCase { new QueryProfileRegistry(), " rank-profile my_profile {\n" + " first-phase {\n" + - " expression: onnx_vespa('mnist_softmax.onnx')" + + " expression: onnx('mnist_softmax.onnx')" + " }\n" + " }"); search.compileRankProfile("my_profile", applicationDir.append("models")); @@ -164,7 +164,7 @@ public class RankingExpressionWithOnnxTestCase { } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx_vespa('mnist_softmax.onnx'): " + + "onnx('mnist_softmax.onnx'): " + "Model refers input 'Placeholder' of type tensor(d0[1],d1[784]) but this function is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); @@ -175,13 +175,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithWrongFunctionType() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d5[10])(0.0)", - "onnx_vespa('mnist_softmax.onnx')"); + "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx_vespa('mnist_softmax.onnx'): " + + "onnx('mnist_softmax.onnx'): " + "Model refers input 'Placeholder'. The required type of this is tensor(d0[1],d1[784]), " + "but this function returns tensor(d0[1],d5[10])", Exceptions.toMessageString(expected)); @@ -192,13 +192,13 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceSpecifyingNonExistingOutput() { try { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx', 'y')"); + "onnx('mnist_softmax.onnx', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + - "onnx_vespa('mnist_softmax.onnx','y'): " + + "onnx('mnist_softmax.onnx','y'): " + "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: default.layer_add", Exceptions.toMessageString(expected)); } @@ -207,7 +207,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx')"); + "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -218,7 +218,7 @@ public class RankingExpressionWithOnnxTestCase { storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx')", + "onnx('mnist_softmax.onnx')", null, null, "Placeholder", @@ -243,7 +243,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor(d1[10],d2[784])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx_vespa('mnist_softmax.onnx')" + + " expression: onnx('mnist_softmax.onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + @@ -288,7 +288,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx_vespa('" + name + ".onnx')" + + " expression: onnx('" + name + ".onnx')" + " }\n" + " }"; final String functionName = "imported_ml_function_" + name + "_exp_output"; @@ -310,7 +310,7 @@ public class RankingExpressionWithOnnxTestCase { " expression: tensor(d0[3])(0.0)\n" + " }\n" + " first-phase {\n" + - " expression: onnx_vespa('" + name + ".onnx')" + + " expression: onnx('" + name + ".onnx')" + " }\n" + " }" + " rank-profile my_profile_child inherits my_profile {\n" + diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index 181ef6dffbd..4beaf6086a6 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -221,10 +221,5 @@ jdisc_http_service ${project.version} - - com.google.protobuf - protobuf-java - ${protobuf.version} - -- cgit v1.2.3