diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2020-10-27 15:23:17 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-27 15:23:17 +0100 |
commit | de5e15528fca545b2a9ccbb1386c5590e11fa383 (patch) | |
tree | 7d4670986c2a8a8c097a29b575a4bb8d3bad4f87 /config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | |
parent | 4d64cbe4e531cb7be1061f1f54809d1d0a1b0061 (diff) |
Revert "Lesters/revert revert resolve onnx model types"
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 228 |
1 files changed, 28 insertions, 200 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 5e8b8579ee6..c2fb2107604 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -2,22 +2,14 @@ package com.yahoo.searchdefinition; import com.yahoo.config.FileReference; -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; -import onnx.Onnx; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; +import java.util.List; import java.util.Objects; -import java.util.Optional; -import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -29,16 +21,16 @@ public class OnnxModel { public enum PathType {FILE, URI}; private final String name; - private PathType pathType = PathType.FILE; private String path = null; private String fileReference = ""; - private String defaultOutput = null; - private Map<String, String> inputMap = new HashMap<>(); - private Map<String, String> outputMap = new HashMap<>(); + private List<OnnxNameMapping> inputMap = new ArrayList<>(); + private List<OnnxNameMapping> outputMap = new ArrayList<>(); + + public PathType getPathType() { + return pathType; + } - private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>(); - private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>(); - private Map<String, TensorType> vespaTypes = new HashMap<>(); + private PathType pathType = PathType.FILE; public OnnxModel(String name) { this.name = name; @@ -57,52 +49,21 @@ public class OnnxModel { } public void setUri(String uri) { - throw new IllegalArgumentException("URI for ONNX models are not currently supported"); - } - - public PathType getPathType() { - return pathType; - } - - public void setDefaultOutput(String onnxName) { - Objects.requireNonNull(onnxName, "Name cannot be null"); - this.defaultOutput = onnxName; + Objects.requireNonNull(uri, "uri cannot be null"); + this.path = uri; + this.pathType = PathType.URI; } public void addInputNameMapping(String onnxName, String vespaName) { - addInputNameMapping(onnxName, vespaName, true); - } - - public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - if (overwrite || ! inputMap.containsKey(onnxName)) { - inputMap.put(onnxName, vespaName); - } + this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); } public void addOutputNameMapping(String onnxName, String vespaName) { - addOutputNameMapping(onnxName, vespaName, true); - } - - public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) { Objects.requireNonNull(onnxName, "Onnx name cannot be null"); Objects.requireNonNull(vespaName, "Vespa name cannot be null"); - if (overwrite || ! outputMap.containsKey(onnxName)) { - outputMap.put(onnxName, vespaName); - } - } - - public void addInputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - inputTypes.put(onnxName, type); - } - - public void addOutputType(String onnxName, Onnx.TypeProto type) { - Objects.requireNonNull(onnxName, "Onnx name cannot be null"); - Objects.requireNonNull(type, "Tensor type cannot be null"); - outputTypes.put(onnxName, type); + this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); } /** Initiate sending of this constant to some services over file distribution */ @@ -115,16 +76,11 @@ public class OnnxModel { public String getName() { return name; } public String getFileName() { return path; } - public Path getFilePath() { return Path.fromString(path); } public String getUri() { return path; } public String getFileReference() { return fileReference; } - public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } - public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } - - public String getDefaultOutput() { - return defaultOutput; - } + public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } + public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } public void validate() { if (path == null || path.isEmpty()) @@ -134,151 +90,23 @@ public class OnnxModel { public String toString() { StringBuilder b = new StringBuilder(); b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); + .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) + .append("' with ref '").append(fileReference) + .append("'"); return b.toString(); } - /** - * Return the tensor type for an ONNX model for the given context. - * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output - * type depends on the input types for the given context (rank profile). - */ - public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) { - Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName); - if (onnxOutputType == null) { - throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'"); - } - if (allDimensionSizesAreKnown(onnxOutputType)) { - return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType)); - } - return getTensorTypeWithUnknownDimensions(onnxOutputType, context); - } - - private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) { - return type.getTensorType().getShape().getDimList().stream().noneMatch(d -> - (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1); - } - - private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) { - long unboundSize = 0; - Map<String, Long> symbolicSizes = new HashMap<>(); - - for (String onnxInputName : inputTypes.keySet()) { - Onnx.TypeProto onnxType = inputTypes.get(onnxInputName); - if (allDimensionSizesAreKnown(onnxType)) { - continue; - } - - Optional<TensorType> vespaType = resolveInputType(onnxInputName, context); - if (vespaType.isEmpty()) { - return TensorType.empty; - } - - var onnxDimensions = onnxType.getTensorType().getShape().getDimList(); - var vespaDimensions = vespaType.get().dimensions(); - if (vespaDimensions.size() != onnxDimensions.size()) { - return TensorType.empty; - } - - for (int i = 0; i < vespaDimensions.size(); ++i) { - if (vespaDimensions.get(i).size().isEmpty()) { - continue; - } - Long size = vespaDimensions.get(i).size().get(); - - // Handle dimensions with size -1 - typically batch dimensions - if (onnxDimensions.get(i).getDimValue() == -1) { - if (unboundSize != 0 && unboundSize != size) { - throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " + - "for type '" + onnxOutputType + "' in ONNX model '" + name + "'"); - } - unboundSize = size; - - // Handle dimensions with symbolic names - } else if (onnxDimensions.get(i).hasDimParam()) { - String symbolicName = onnxDimensions.get(i).getDimParam(); - if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) { - throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + - symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'"); - } - symbolicSizes.put(symbolicName, size); - } - } - } - return typeFrom(onnxOutputType, symbolicSizes, unboundSize); - } - - private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) { - String source = inputMap.get(onnxInputName); - if (source != null) { - // Source is either a simple reference (query/attribute/constant)... - Optional<Reference> reference = Reference.simple(source); - if (reference.isPresent()) { - return Optional.of(context.getType(reference.get())); - } - // ... or a function - ExpressionFunction func = context.getFunction(source); - if (func != null) { - return Optional.of(func.getBody().type(context)); - } - } - return Optional.empty(); // if this context does not contain this input - } - - private static TensorType typeFrom(Onnx.TypeProto type) { - return typeFrom(type, null, 0); - } - - private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) { - String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... - Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType())); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); - long onnxDimensionSize = onnxDimension.getDimValue(); - if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) { - onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam()); - } - if (onnxDimensionSize == 0 && symbolicSizes != null) { - // This is for the case where all symbolic dimensions have - // different names, but can be resolved to a single dimension size. - Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values()); - if (unknownSizes.size() == 1) { - onnxDimensionSize = unknownSizes.iterator().next(); - } - } - if (onnxDimensionSize < 0) { - onnxDimensionSize = unboundSize; - } - if (onnxDimensionSize <= 0) { - throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " + - "ONNX type: " + type + " to Vespa tensor type."); - } - builder.indexed(dimensionName, onnxDimensionSize); - } - return builder.build(); - } + public static class OnnxNameMapping { + private String onnxName; + private String vespaName; - private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT8: return TensorType.Value.FLOAT; - case INT16: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.FLOAT; - case UINT8: return TensorType.Value.FLOAT; - case UINT16: return TensorType.Value.FLOAT; - case UINT32: return TensorType.Value.FLOAT; - case UINT64: return TensorType.Value.FLOAT; - default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); + private OnnxNameMapping(String onnxName, String vespaName) { + this.onnxName = onnxName; + this.vespaName = vespaName; } + public String getOnnxName() { return onnxName; } + public String getVespaName() { return vespaName; } + public void setVespaName(String vespaName) { this.vespaName = vespaName; } } } |