diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 219 |
1 files changed, 191 insertions, 28 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index c2fb2107604..58213186f78 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,14 +2,22 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; +import onnx.Onnx; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -21,16 +29,16 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; + private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private List<OnnxNameMapping> inputMap = new ArrayList<>(); - private List<OnnxNameMapping> outputMap = new ArrayList<>(); + private String defaultOutput = null; + private Map<String, String> inputMap = new HashMap<>(); + private Map<String, String> outputMap = new HashMap<>(); - public PathType getPathType() { - return pathType; - } - - private PathType pathType = PathType.FILE; + private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>(); + private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>(); + private Map<String, TensorType> vespaTypes = new HashMap<>(); public OnnxModel(String name) { this.name = name; @@ -49,21 +57,52 @@ public class OnnxModel { } public void setUri(String uri) { - Objects.requireNonNull(uri, "uri cannot be null"); - this.path = uri; - this.pathType = PathType.URI; + throw new IllegalArgumentException("URI for ONNX models are not currently supported"); + } + + public PathType getPathType() { + return pathType; + } + + public void setDefaultOutput(String onnxName) { + Objects.requireNonNull(onnxName, "Name cannot be null"); + this.defaultOutput = onnxName; } public void addInputNameMapping(String onnxName, String vespaName) { + addInputNameMapping(onnxName, vespaName, true); + } + + public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! inputMap.containsKey(onnxName)) { + inputMap.put(onnxName, vespaName); + } } public void addOutputNameMapping(String onnxName, String vespaName) { + addOutputNameMapping(onnxName, vespaName, true); + } + + public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + if (overwrite || ! outputMap.containsKey(onnxName)) { + outputMap.put(onnxName, vespaName); + } + } + + public void addInputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + inputTypes.put(onnxName, type); + } + + public void addOutputType(String onnxName, Onnx.TypeProto type) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(type, "Tensor type cannot be null"); + outputTypes.put(onnxName, type); } /** Initiate sending of this constant to some services over file distribution */ @@ -76,11 +115,16 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } + public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } - public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } + public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + + public String getDefaultOutput() { + return defaultOutput; + } public void validate() { if (path == null || path.isEmpty()) @@ -90,23 +134,142 @@ public class OnnxModel { public String toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - public static class OnnxNameMapping { - private String onnxName; - private String vespaName; + /** + * Return the tensor type for an ONNX model for the given context. + * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output + * type depends on the input types for the given context (rank profile). + */ + public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { + Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); + if (onnxOutputType == null) { + throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); + } + if (containsSymbolicDimensionSizes(onnxOutputType)) { + return getTensorTypeWithSymbolicDimensions(onnxOutputType, context); + } + return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); + } + + private TensorType getTensorTypeWithSymbolicDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { + Map<String, Long> symbolicSizes = resolveSymbolicDimensionSizes(context); + if (symbolicSizes.isEmpty()) { + return TensorType.empty; // Context is probably a rank profile not using this ONNX model + } + return typeFrom(onnxOutputType, symbolicSizes); + } + + private Map<String, Long> resolveSymbolicDimensionSizes(MapEvaluationTypeContext context) { + Map<String, Long> symbolicSizes = new HashMap<>(); + for (String onnxInputName : inputTypes.keySet()) { + + Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); + if ( ! containsSymbolicDimensionSizes(onnxType)) { + continue; + } + + Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); + if (vespaType.isEmpty()) { + return Collections.emptyMap(); + } + + var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); + var vespaDimensions = vespaType.get().dimensions(); + if (vespaDimensions.size() != onnxDimensions.size()) { + return Collections.emptyMap(); + } + + for (int i = 0; i < vespaDimensions.size(); ++i) { + if (vespaDimensions.get(i).size().isEmpty() || ! onnxDimensions.get(i).hasDimParam()) { + continue; + } + String symbolicName = onnxDimensions.get(i).getDimParam(); + Long size = vespaDimensions.get(i).size().get(); + if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { + throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension " + + "'" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); + } + symbolicSizes.put(symbolicName, size); + } + } + return symbolicSizes; + } + + private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { + String source = inputMap.get(onnxInputName); + if (source != null) { + // Source is either a simple reference (query/attribute/constant)... + Optional<Reference> reference = Reference.simple(source); + if (reference.isPresent()) { + return Optional.of(context.getType(reference.get())); + } + // ... or a function + ExpressionFunction func = context.getFunction(source); + if (func != null) { + return Optional.of(func.getBody().type(context)); + } + } + return Optional.empty(); // if this context does not contain this input + } + + private static boolean containsSymbolicDimensionSizes(Onnx.TypeProto type) { + return type.getTensorType().getShape().getDimList().stream().anyMatch(d -> d.hasDimParam() && ! d.hasDimValue()); + } + + private static TensorType typeFrom(Onnx.TypeProto type) { + return typeFrom(type, null); + } + + private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... + Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + long onnxDimensionSize = onnxDimension.getDimValue(); + if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { + onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); + } + if (onnxDimensionSize == 0 && symbolicSizes != null) { + // This is for the case where all symbolic dimensions have + // different names, but can be resolved to a single dimension size. + Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); + if (unknownSizes.size() == 1) { + onnxDimensionSize = unknownSizes.iterator().next(); + } + } + if (onnxDimensionSize <= 0) { + throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + + "ONNX type: " + type + " to Vespa tensor type."); + } + builder.indexed(dimensionName, onnxDimensionSize); + } + return builder.build(); + } - private OnnxNameMapping(String onnxName, String vespaName) { - this.onnxName = onnxName; - this.vespaName = vespaName; + private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.FLOAT; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); } - public String getOnnxName() { return onnxName; } - public String getVespaName() { return vespaName; } - public void setVespaName(String vespaName) { this.vespaName = vespaName; } } } |