diff options
Diffstat (limited to 'config-model/src/main/java/com')
12 files changed, 102 insertions, 582 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index b153ff62e7d..4011ce43841 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -159,12 +158,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); } - // A reference to an ONNX model? - Optional<TensorType> onnxFeatureType = onnxFeatureType(reference); - if (onnxFeatureType.isPresent()) { - return onnxFeatureType.get(); - } - // A reference to a feature which returns a tensor? Optional<TensorType> featureTensorType = tensorFeatureType(reference); if (featureTensorType.isPresent()) { @@ -217,26 +210,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement return Optional.of(function); } - private Optional<TensorType> onnxFeatureType(Reference reference) { - if ( ! reference.name().equals("onnxModel")) - return Optional.empty(); - - if ( ! featureTypes.containsKey(reference)) { - String configOrFileName = reference.arguments().expressions().get(0).toString(); - - // Look up standardized format as added in RankProfile - String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); - String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); - - reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); - if ( ! featureTypes.containsKey(reference)) { - throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); - } - } - - return Optional.of(featureTypes.get(reference)); - } - /** * There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet. * This returns the type of those features if this is a reference to either of them, or empty otherwise. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 58213186f78..c2fb2107604 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,22 +2,14 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; -import onnx.Onnx; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; +import java.util.List; import java.util.Objects; -import java.util.Optional; -import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -29,16 +21,16 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; - private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private String defaultOutput = null; - private Map<String, String> inputMap = new HashMap<>(); - private Map<String, String> outputMap = new HashMap<>(); + private List<OnnxNameMapping> inputMap = new ArrayList<>(); + private List<OnnxNameMapping> outputMap = new ArrayList<>(); - private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>(); - private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>(); - private Map<String, TensorType> vespaTypes = new HashMap<>(); + public PathType getPathType() { + return pathType; + } + + private PathType pathType = PathType.FILE; public OnnxModel(String name) { this.name = name; @@ -57,52 +49,21 @@ public class OnnxModel { } public void setUri(String uri) { - throw new IllegalArgumentException("URI for ONNX models are not currently supported"); - } - - public PathType getPathType() { - return pathType; - } - - public void setDefaultOutput(String onnxName) { - Objects.requireNonNull(onnxName, "Name cannot be null"); - this.defaultOutput = onnxName; + Objects.requireNonNull(uri, "uri cannot be null"); + this.path = uri; + this.pathType = PathType.URI; } public void addInputNameMapping(String onnxName, String vespaName) { - addInputNameMapping(onnxName, vespaName, true); - } - - public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - if (overwrite || ! inputMap.containsKey(onnxName)) { - inputMap.put(onnxName, vespaName); - } + this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); } public void addOutputNameMapping(String onnxName, String vespaName) { - addOutputNameMapping(onnxName, vespaName, true); - } - - public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - if (overwrite || ! outputMap.containsKey(onnxName)) { - outputMap.put(onnxName, vespaName); - } - } - - public void addInputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - inputTypes.put(onnxName, type); - } - - public void addOutputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - outputTypes.put(onnxName, type); + this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); } /** Initiate sending of this constant to some services over file distribution */ @@ -115,16 +76,11 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } - public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } - public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } - - public String getDefaultOutput() { - return defaultOutput; - } + public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } + public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } public void validate() { if (path == null || path.isEmpty()) @@ -134,142 +90,23 @@ public class OnnxModel { public String toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - /** - * Return the tensor type for an ONNX model for the given context. - * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output - * type depends on the input types for the given context (rank profile). - */ - public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { - Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); - if (onnxOutputType == null) { - throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); - } - if (containsSymbolicDimensionSizes(onnxOutputType)) { - return getTensorTypeWithSymbolicDimensions(onnxOutputType, context); - } - return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); - } - - private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { - Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context); - if (symbolicSizes.isEmpty()) { - return TensorType.empty; // Context is probably a rank profile not using this ONNX model - } - return typeFrom(onnxOutputType, symbolicSizes); - } - - private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) { - Map<String, Long> symbolicSizes = new HashMap<>(); - for (String onnxInputName : inputTypes.keySet()) { - - Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); - if ( ! containsSymbolicDimensionSizes(onnxType)) { - continue; - } - - Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); - if (vespaType.isEmpty()) { - return Collections.emptyMap(); - } - - var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); - var vespaDimensions = vespaType.get().dimensions(); - if (vespaDimensions.size() != onnxDimensions.size()) { - return Collections.emptyMap(); - } - - for (int i = 0; i < vespaDimensions.size(); ++i) { - if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) { - continue; - } - String symbolicName = onnxDimensions.get(i).getDimParam(); - Long size = vespaDimensions.get(i).size().get(); - if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { - throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " + - "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); - } - symbolicSizes.put(symbolicName, size); - } - } - return symbolicSizes; - } - - private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { - String source = inputMap.get(onnxInputName); - if (source != null) { - // Source is either a simple reference (query/attribute/constant)... - Optional<Reference> reference = Reference.simple(source); - if (reference.isPresent()) { - return Optional.of(context.getType(reference.get())); - } - // ... or a function - ExpressionFunction func = context.getFunction(source); - if (func != null) { - return Optional.of(func.getBody().type(context)); - } - } - return Optional.empty(); // if this context does not contain this input - } - - private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) { - return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue()); - } - - private static TensorType typeFrom(Onnx.TypeProto type) { - return typeFrom(type, null); - } - - private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) { - String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... - Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - long onnxDimensionSize = onnxDimension.getDimValue(); - if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { - onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); - } - if (onnxDimensionSize == 0 && symbolicSizes != null) { - // This is for the case where all symbolic dimensions have - // different names, but can be resolved to a single dimension size. - Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); - if (unknownSizes.size() == 1) { - onnxDimensionSize = unknownSizes.iterator().next(); - } - } - if (onnxDimensionSize <= 0) { - throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + - "ONNX type: " + type + " to Vespa tensor type."); - } - builder.indexed(dimensionName, onnxDimensionSize); - } - return builder.build(); - } + public static class OnnxNameMapping { + private String onnxName; + private String vespaName; - private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT8: return TensorType.Value.FLOAT; - case INT16: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.FLOAT; - case UINT8: return TensorType.Value.FLOAT; - case UINT16: return TensorType.Value.FLOAT; - case UINT32: return TensorType.Value.FLOAT; - case UINT64: return TensorType.Value.FLOAT; - default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); + private OnnxNameMapping(String onnxName, String vespaName) { + this.onnxName = onnxName; + this.vespaName = vespaName; } + public String getOnnxName() { return onnxName; } + public String getVespaName() { return vespaName; } + public void setVespaName(String vespaName) { this.vespaName = vespaName; } } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 96c043bdb34..d309f48d6df 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -18,7 +18,6 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; @@ -159,10 +158,6 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } - private Map<String, OnnxModel> onnxModels() { - return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); - } - private Stream<ImmutableSDField> allFields() { if (search == null) return Stream.empty(); if (allFieldsList == null) { @@ -826,20 +821,6 @@ public class RankProfile implements Cloneable { } } - // Add output types for ONNX models - for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { - String modelName = entry.getKey(); - OnnxModel model = entry.getValue(); - Arguments args = new Arguments(new ReferenceNode(modelName)); - - TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context); - context.setType(new Reference("onnxModel", args, null), defaultOutputType); - - for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { - TensorType type = model.getTensorType(mapping.getKey(), context); - context.setType(new Reference("onnxModel", args, mapping.getValue()), type); - } - } return context; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 22a32c8fd65..84442fedc48 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -126,8 +126,8 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); modelBuilder.name(model.getName()); modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); - model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + model.getInputMap().forEach(mapper -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(mapper.getOnnxName()).source(mapper.getVespaName()))); + model.getOutputMap().forEach(mapper -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(mapper.getOnnxName()).as(mapper.getVespaName()))); builder.model(modelBuilder); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 56a5d539906..87eaaf0387a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -448,10 +448,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<String> functionNames = rankProfile.getFunctions().keySet(); if (functionNames.isEmpty()) return; for (OnnxModel onnxModel: rankProfile.getSearch().onnxModels().asMap().values()) { - for (Map.Entry<String, String> mapping : onnxModel.getInputMap().entrySet()) { - String source = mapping.getValue(); + for (OnnxModel.OnnxNameMapping mapping : onnxModel.getInputMap()) { + String source = mapping.getVespaName(); if (functionNames.contains(source)) { - onnxModel.addInputNameMapping(mapping.getKey(), "rankingExpression(" + source + ")"); + mapping.setVespaName("rankingExpression(" + source + ")"); } } } @@ -462,7 +462,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Set<ReferenceNode> replacedSummaryFeatures = new HashSet<>(); for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); - ReferenceNode replacedNode = (ReferenceNode) OnnxModelTransformer.transformFeature(referenceNode, rankProfile); + ReferenceNode replacedNode = OnnxModelTransformer.transformFeature(referenceNode, rankProfile.getSearch()); if (referenceNode != replacedNode) { replacedSummaryFeatures.add(replacedNode); i.remove(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index d23a8376e7a..ec517768ea9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -38,7 +38,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if ( ! feature.getName().equals("onnx") && ! feature.getName().equals("onnx_vespa")) return feature; + if ( ! feature.getName().equals("onnx")) return feature; try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java index 69cdae10e47..e1ad003e5bd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java @@ -1,36 +1,20 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; -import com.yahoo.path.Path; import com.yahoo.searchdefinition.ImmutableSearch; import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.vespa.model.ml.ConvertedModel; -import com.yahoo.vespa.model.ml.FeatureArguments; -import com.yahoo.vespa.model.ml.ModelName; import java.util.List; /** - * Transforms ONNX model features of the forms: - * - * onnxModel(config_name) - * onnxModel(config_name).output - * onnxModel("path/to/model") - * onnxModel("path/to/model").output - * onnxModel("path/to/model", "path/to/output") - * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused - * - * To the format expected by the backend: - * - * onnxModel(config_name).output + * Transforms instances of the onnxModel ranking feature and generates + * ONNX configuration if necessary. * * @author lesters */ @@ -49,92 +33,85 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { if (context.rankProfile() == null) return feature; if (context.rankProfile().getSearch() == null) return feature; - return transformFeature(feature, context.rankProfile()); + return transformFeature(feature, context.rankProfile().getSearch()); } - public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile rankProfile) { - ImmutableSearch search = rankProfile.getSearch(); - final String featureName = feature.getName(); - if ( ! featureName.equals("onnxModel")) return feature; + public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) { + if (!feature.getName().equals("onnxModel")) return feature; Arguments arguments = feature.getArguments(); if (arguments.isEmpty()) - throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a " + - "onnx-model config or an ONNX file."); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments."); - - // Check that the model configuration "onnx-model" exists. If not defined, it should have been added - // by the "OnnxModelConfigGenerator" processor. If it still doesn't exist, it is because we can't find - // the actual ONNX file, which can happen if we are restarting or upgrading an application using an - // ONNX file that was transformed to Vespa ranking expressions. We then assume it is in the model store. - - String modelConfigName = getModelConfigName(feature.reference()); - OnnxModel onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { + throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a " + + "onnx-model config or a ONNX file."); + if (arguments.expressions().size() > 2) + throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments."); + + // Validation that the file actually exists is handled when the file is added to file distribution. + // Validation of inputs, outputs and corresponding types are currently handled by RankSetupValidator. + + String modelConfigName; + OnnxModel onnxModel; + if (arguments.expressions().get(0) instanceof ReferenceNode) { + modelConfigName = arguments.expressions().get(0).toString(); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found"); + } + } else if (arguments.expressions().get(0) instanceof ConstantNode) { String path = asString(arguments.expressions().get(0)); - ModelName modelName = new ModelName(null, Path.fromString(path), true); - ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, path, rankProfile); - FeatureArguments featureArguments = new FeatureArguments(arguments); - return convertedModel.expression(featureArguments, null); - } - - String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput()); - String output = getModelOutput(feature.reference(), defaultOutput); - if (! onnxModel.getOutputMap().containsValue(output)) { - throw new IllegalArgumentException(featureName + " argument '" + output + - "' output not found in model '" + onnxModel.getFileName() + "'"); + modelConfigName = asValidIdentifier(path); + onnxModel = search.onnxModels().get(modelConfigName); + if (onnxModel == null) { + onnxModel = new OnnxModel(modelConfigName, path); + search.onnxModels().add(onnxModel); + } + } else { + throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'"); } - return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); - } - public static String getModelConfigName(Reference reference) { - if (reference.arguments().size() > 0) { - ExpressionNode expr = reference.arguments().expressions().get(0); - if (expr instanceof ReferenceNode) { // refers to onnx-model config - return expr.toString(); + String output = null; + if (feature.getOutput() != null) { + output = feature.getOutput(); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(output, output); } - if (expr instanceof ConstantNode) { // refers to an file path - return asValidIdentifier(expr); + } else if (arguments.expressions().size() > 1) { + String name = asString(arguments.expressions().get(1)); + output = asValidIdentifier(name); + if ( ! hasOutputMapping(onnxModel, output)) { + onnxModel.addOutputNameMapping(name, output); } } - return null; - } - public static String getModelOutput(Reference reference, String defaultOutput) { - if (reference.output() != null) { - return reference.output(); - } else if (reference.arguments().expressions().size() == 2) { - return asValidIdentifier(reference.arguments().expressions().get(1)); - } else if (reference.arguments().expressions().size() > 2) { - return asValidIdentifier(reference.arguments().expressions().get(2)); - } - return defaultOutput; + // Replace feature with name of config + ExpressionNode argument = new ReferenceNode(modelConfigName); + return new ReferenceNode("onnxModel", List.of(argument), output); + } - public static String stripQuotes(String s) { - if (isNotQuoteSign(s.codePointAt(0))) return s; - if (isNotQuoteSign(s.codePointAt(s.length() - 1))) - throw new IllegalArgumentException("argument [" + s + "] is missing end quote"); - return s.substring(1, s.length()-1); + private static boolean hasOutputMapping(OnnxModel onnxModel, String as) { + return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as)); } - public static String asValidIdentifier(String str) { - return str.replaceAll("[^\\w\\d\\$@_]", "_"); + private static String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } - private static String asValidIdentifier(ExpressionNode node) { - return asValidIdentifier(asString(node)); + private static String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); } - private static boolean isNotQuoteSign(int c) { - return c != '\'' && c != '"'; + private static boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; } - public static String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); + private static String asValidIdentifier(String str) { + return str.replaceAll("[^\\w\\d\\$@_]", "_"); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java deleted file mode 100644 index afba88c135d..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchdefinition.processing; - -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.application.api.DeployLogger; -import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.Search; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; -import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.vespa.model.container.search.QueryProfiles; - -import java.util.Map; - -/** - * Processes ONNX ranking features of the form: - * - * onnx("files/model.onnx", "path/to/output:1") - * - * And generates an "onnx-model" configuration as if it was defined in the schema: - * - * onnx-model files_model_onnx { - * file: "files/model.onnx" - * } - * - * Inputs and outputs are resolved in OnnxModelTypeResolver, which must be - * processed after this. - * - * @author lesters - */ -public class OnnxModelConfigGenerator extends Processor { - - public OnnxModelConfigGenerator(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { - super(search, deployLogger, rankProfileRegistry, queryProfiles); - } - - @Override - public void process(boolean validate, boolean documentsOnly) { - if (documentsOnly) return; - for (RankProfile profile : rankProfileRegistry.rankProfilesOf(search)) { - if (profile.getFirstPhaseRanking() != null) { - process(profile.getFirstPhaseRanking().getRoot()); - } - if (profile.getSecondPhaseRanking() != null) { - process(profile.getSecondPhaseRanking().getRoot()); - } - for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { - process(function.getValue().function().getBody().getRoot()); - } - for (ReferenceNode feature : profile.getSummaryFeatures()) { - process(feature); - } - } - } - - private void process(ExpressionNode node) { - if (node instanceof ReferenceNode) { - process((ReferenceNode)node); - } else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode) node).children()) { - process(child); - } - } - } - - private void process(ReferenceNode feature) { - if (feature.getName().equals("onnxModel") || feature.getName().equals("onnx")) { - if (feature.getArguments().size() > 0) { - if (feature.getArguments().expressions().get(0) instanceof ConstantNode) { - ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0); - String path = OnnxModelTransformer.stripQuotes(node.sourceString()); - String modelConfigName = OnnxModelTransformer.asValidIdentifier(path); - - // Only add the configuration if the model can actually be found. - if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { - path = ApplicationPackage.MODELS_DIR.append(path).toString(); - if ( ! OnnxModelTypeResolver.modelFileExists(path, search.applicationPackage())) { - return; - } - } - - OnnxModel onnxModel = search.onnxModels().get(modelConfigName); - if (onnxModel == null) { - onnxModel = new OnnxModel(modelConfigName, path); - search.onnxModels().add(onnxModel); - } - } - } - } - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java deleted file mode 100644 index bead2e7e7c9..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchdefinition.processing; - -import com.yahoo.cloud.config.ConfigserverConfig; -import com.yahoo.component.Version; -import com.yahoo.config.FileReference; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.application.api.DeployLogger; -import com.yahoo.config.application.api.FileRegistry; -import com.yahoo.path.Path; -import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.RankProfileRegistry; -import com.yahoo.searchdefinition.Search; -import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer; -import com.yahoo.vespa.defaults.Defaults; -import com.yahoo.vespa.model.container.search.QueryProfiles; -import onnx.Onnx; - -import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Paths; -import java.util.Map; -import java.util.Optional; - -/** - * Processes every "onnx-model" element in the schema. Parses the model file, - * adds missing input and output mappings (assigning default names), and - * adds tensor types to all model inputs and outputs. - * - * Must be processed before RankingExpressingTypeResolver. - * - * @author lesters - */ -public class OnnxModelTypeResolver extends Processor { - - public OnnxModelTypeResolver(Search search, DeployLogger deployLogger, RankProfileRegistry rankProfileRegistry, QueryProfiles queryProfiles) { - super(search, deployLogger, rankProfileRegistry, queryProfiles); - } - - @Override - public void process(boolean validate, boolean documentsOnly) { - if (documentsOnly) return; - - for (Map.Entry<String, OnnxModel> entry : search.onnxModels().asMap().entrySet()) { - OnnxModel modelConfig = entry.getValue(); - try (InputStream inputStream = openModelFile(modelConfig.getFilePath())) { - Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - - // Model inputs - if not defined, assumes a function is provided with a valid name - for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { - String onnxInputName = valueInfo.getName(); - String vespaInputName = OnnxModelTransformer.asValidIdentifier(onnxInputName); - modelConfig.addInputNameMapping(onnxInputName, vespaInputName, false); - modelConfig.addInputType(onnxInputName, valueInfo.getType()); - } - - // Model outputs - for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) { - String onnxOutputName = valueInfo.getName(); - String vespaOutputName = OnnxModelTransformer.asValidIdentifier(onnxOutputName); - modelConfig.addOutputNameMapping(onnxOutputName, vespaOutputName, false); - modelConfig.addOutputType(onnxOutputName, valueInfo.getType()); - } - - // Set the first output as default - if ( ! model.getGraph().getOutputList().isEmpty()) { - modelConfig.setDefaultOutput(model.getGraph().getOutput(0).getName()); - } - - } catch (IOException e) { - throw new IllegalArgumentException("Unable to parse ONNX model", e); - } - } - } - - static boolean modelFileExists(String path, ApplicationPackage app) { - Path pathInApplicationPackage = Path.fromString(path); - if (getFile(pathInApplicationPackage, app).exists()) { - return true; - } - if (getFileReference(pathInApplicationPackage, app).isPresent()) { - return true; - } - return false; - } - - private InputStream openModelFile(Path path) throws FileNotFoundException { - ApplicationFile file; - Optional<FileReference> reference; - Path modelsPath = ApplicationPackage.MODELS_DIR.append(path); - - if ((file = getFile(path)).exists()) { - return file.createInputStream(); - } - if ((file = getFile(modelsPath)).exists()) { - return file.createInputStream(); - } - if ((reference = getFileReference(path)).isPresent()) { - return openFromFileRepository(path, reference.get()); - } - if ((reference = getFileReference(modelsPath)).isPresent()) { - return openFromFileRepository(modelsPath, reference.get()); - } - - throw new IllegalArgumentException("Unable to find ONNX model file \"" + path + "\" " + - "in application package or file repository."); - } - - private ApplicationFile getFile(Path path) { - return getFile(path, search.applicationPackage()); - } - - private static ApplicationFile getFile(Path path, ApplicationPackage app) { - return app.getFile(path); - } - - private static InputStream openFromFileRepository(Path path, FileReference reference) throws FileNotFoundException { - return new FileInputStream(new File(getFileRepositoryPath(path, reference.value()))); - } - - public static String getFileRepositoryPath(Path path, String fileReference) { - ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults - String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); - return Paths.get(fileRefDir, fileReference, path.getName()).toString(); - } - - private Optional<FileReference> getFileReference(Path path) { - return getFileReference(path, search.applicationPackage()); - } - - private static Optional<FileReference> getFileReference(Path path, ApplicationPackage app) { - Optional<FileRegistry> fileRegistry = getLatestFileRegistry(app); - if (fileRegistry.isPresent()) { - for (FileRegistry.Entry file : fileRegistry.get().export()) { - if (file.relativePath.equals(path.toString())) { - return Optional.of(file.reference); - } - } - } - return Optional.empty(); - } - - private static Optional<FileRegistry> getLatestFileRegistry(ApplicationPackage app) { - if (app == null) return Optional.empty(); - Optional<Version> latest = app.getFileRegistries().keySet().stream().max(Version::compareTo); - return latest.isEmpty() ? Optional.empty() : Optional.of(app.getFileRegistries().get(latest.get())); - } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index 1a3ef9e54b4..e8594c2a87f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -74,8 +74,6 @@ public class Processing { ReferenceFieldsProcessor::new, FastAccessValidator::new, ReservedFunctionNames::new, - OnnxModelConfigGenerator::new, - OnnxModelTypeResolver::new, RankingExpressionTypeResolver::new, // These should be last: IndexingValidation::new, diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index d5c5183b01f..c6c7969e466 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -1,19 +1,20 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation; +import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.IOUtils; import com.yahoo.log.InvalidLogFormatException; import java.util.logging.Level; import com.yahoo.log.LogMessage; import com.yahoo.searchdefinition.OnnxModel; -import com.yahoo.searchdefinition.processing.OnnxModelTypeResolver; import com.yahoo.yolean.Exceptions; import com.yahoo.system.ProcessExecuter; import com.yahoo.text.StringUtilities; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.collections.Pair; import com.yahoo.config.ConfigInstance; +import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.config.search.ImportedFieldsConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -30,6 +31,7 @@ import com.yahoo.vespa.model.search.SearchCluster; import java.io.File; import java.io.IOException; import java.nio.file.Files; +import java.nio.file.Paths; import java.time.Duration; import java.time.Instant; import java.util.logging.Logger; @@ -150,9 +152,12 @@ public class RankSetupValidator extends Validator { // Assist verify-ranksetup in finding the actual ONNX model files Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); if (models.values().size() > 0) { + ConfigserverConfig cfg = new ConfigserverConfig(new ConfigserverConfig.Builder()); // assume defaults + String fileRefDir = Defaults.getDefaults().underVespaHome(cfg.fileReferencesDir()); List<String> config = new ArrayList<>(models.values().size() * 2); for (OnnxModel model : models.values()) { - String modelPath = OnnxModelTypeResolver.getFileRepositoryPath(model.getFilePath(), model.getFileReference()); + String modelFilename = Paths.get(model.getFileName()).getFileName().toString(); + String modelPath = Paths.get(fileRefDir, model.getFileReference(), modelFilename).toString(); config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 5ee6ed02e61..943fcbf6c1d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -150,7 +150,7 @@ public class ConvertedModel { */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { ExpressionFunction expression = selectExpression(arguments); - if (sourceModel.isPresent() && context != null) // we should verify + if (sourceModel.isPresent()) // we should verify verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); return expression.getBody().getRoot(); } |