summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2020-10-27 15:23:17 +0100
committerGitHub <noreply@github.com>2020-10-27 15:23:17 +0100
commitde5e15528fca545b2a9ccbb1386c5590e11fa383 (patch)
tree7d4670986c2a8a8c097a29b575a4bb8d3bad4f87 /config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
parent4d64cbe4e531cb7be1061f1f54809d1d0a1b0061 (diff)
Revert "Lesters/revert revert resolve onnx model types"
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java228
1 files changed, 28 insertions, 200 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 5e8b8579ee6..c2fb2107604 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -2,22 +2,14 @@
package com.yahoo.searchdefinition;
import com.yahoo.config.FileReference;
-import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
-import onnx.Onnx;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
+import java.util.List;
import java.util.Objects;
-import java.util.Optional;
-import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -29,16 +21,16 @@ public class OnnxModel {
public enum PathType {FILE, URI};
private final String name;
- private PathType pathType = PathType.FILE;
private String path = null;
private String fileReference = "";
- private String defaultOutput = null;
- private Map<String, String> inputMap = new HashMap<>();
- private Map<String, String> outputMap = new HashMap<>();
+ private List<OnnxNameMapping> inputMap = new ArrayList<>();
+ private List<OnnxNameMapping> outputMap = new ArrayList<>();
+
+ public PathType getPathType() {
+ return pathType;
+ }
- private Map<String, Onnx.TypeProto> inputTypes = new HashMap<>();
- private Map<String, Onnx.TypeProto> outputTypes = new HashMap<>();
- private Map<String, TensorType> vespaTypes = new HashMap<>();
+ private PathType pathType = PathType.FILE;
public OnnxModel(String name) {
this.name = name;
@@ -57,52 +49,21 @@ public class OnnxModel {
}
public void setUri(String uri) {
- throw new IllegalArgumentException("URI for ONNX models are not currently supported");
- }
-
- public PathType getPathType() {
- return pathType;
- }
-
- public void setDefaultOutput(String onnxName) {
- Objects.requireNonNull(onnxName, "Name cannot be null");
- this.defaultOutput = onnxName;
+ Objects.requireNonNull(uri, "uri cannot be null");
+ this.path = uri;
+ this.pathType = PathType.URI;
}
public void addInputNameMapping(String onnxName, String vespaName) {
- addInputNameMapping(onnxName, vespaName, true);
- }
-
- public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! inputMap.containsKey(onnxName)) {
- inputMap.put(onnxName, vespaName);
- }
+ this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
public void addOutputNameMapping(String onnxName, String vespaName) {
- addOutputNameMapping(onnxName, vespaName, true);
- }
-
- public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
Objects.requireNonNull(onnxName, "Onnx name cannot be null");
Objects.requireNonNull(vespaName, "Vespa name cannot be null");
- if (overwrite || ! outputMap.containsKey(onnxName)) {
- outputMap.put(onnxName, vespaName);
- }
- }
-
- public void addInputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- inputTypes.put(onnxName, type);
- }
-
- public void addOutputType(String onnxName, Onnx.TypeProto type) {
- Objects.requireNonNull(onnxName, "Onnx name cannot be null");
- Objects.requireNonNull(type, "Tensor type cannot be null");
- outputTypes.put(onnxName, type);
+ this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
}
/** Initiate sending of this constant to some services over file distribution */
@@ -115,16 +76,11 @@ public class OnnxModel {
public String getName() { return name; }
public String getFileName() { return path; }
- public Path getFilePath() { return Path.fromString(path); }
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
- public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
- public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
-
- public String getDefaultOutput() {
- return defaultOutput;
- }
+ public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
+ public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
public void validate() {
if (path == null || path.isEmpty())
@@ -134,151 +90,23 @@ public class OnnxModel {
public String toString() {
StringBuilder b = new StringBuilder();
b.append("onnx-model '").append(name)
- .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
- .append("' with ref '").append(fileReference)
- .append("'");
+ .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path)
+ .append("' with ref '").append(fileReference)
+ .append("'");
return b.toString();
}
- /**
- * Return the tensor type for an ONNX model for the given context.
- * An ONNX model can have dynamic/symbolic dimension sizes. If so, the output
- * type depends on the input types for the given context (rank profile).
- */
- public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
- Onnx.TypeProto onnxOutputType = outputTypes.get(onnxName);
- if (onnxOutputType == null) {
- throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' " + "in '" + name + "'");
- }
- if (allDimensionSizesAreKnown(onnxOutputType)) {
- return vespaTypes.computeIfAbsent(onnxName, v -> typeFrom(onnxOutputType));
- }
- return getTensorTypeWithUnknownDimensions(onnxOutputType, context);
- }
-
- private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) {
- return type.getTensorType().getShape().getDimList().stream().noneMatch(d ->
- (d.hasDimParam() && ! d.hasDimValue()) || d.getDimValue() == -1);
- }
-
- private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
- long unboundSize = 0;
- Map<String, Long> symbolicSizes = new HashMap<>();
-
- for (String onnxInputName : inputTypes.keySet()) {
- Onnx.TypeProto onnxType = inputTypes.get(onnxInputName);
- if (allDimensionSizesAreKnown(onnxType)) {
- continue;
- }
-
- Optional<TensorType> vespaType = resolveInputType(onnxInputName, context);
- if (vespaType.isEmpty()) {
- return TensorType.empty;
- }
-
- var onnxDimensions = onnxType.getTensorType().getShape().getDimList();
- var vespaDimensions = vespaType.get().dimensions();
- if (vespaDimensions.size() != onnxDimensions.size()) {
- return TensorType.empty;
- }
-
- for (int i = 0; i < vespaDimensions.size(); ++i) {
- if (vespaDimensions.get(i).size().isEmpty()) {
- continue;
- }
- Long size = vespaDimensions.get(i).size().get();
-
- // Handle dimensions with size -1 - typically batch dimensions
- if (onnxDimensions.get(i).getDimValue() == -1) {
- if (unboundSize != 0 && unboundSize != size) {
- throw new IllegalArgumentException("Found conflicting sizes for unbound dimension " +
- "for type '" + onnxOutputType + "' in ONNX model '" + name + "'");
- }
- unboundSize = size;
-
- // Handle dimensions with symbolic names
- } else if (onnxDimensions.get(i).hasDimParam()) {
- String symbolicName = onnxDimensions.get(i).getDimParam();
- if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
- throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" +
- symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + name + "'");
- }
- symbolicSizes.put(symbolicName, size);
- }
- }
- }
- return typeFrom(onnxOutputType, symbolicSizes, unboundSize);
- }
-
- private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
- String source = inputMap.get(onnxInputName);
- if (source != null) {
- // Source is either a simple reference (query/attribute/constant)...
- Optional<Reference> reference = Reference.simple(source);
- if (reference.isPresent()) {
- return Optional.of(context.getType(reference.get()));
- }
- // ... or a function
- ExpressionFunction func = context.getFunction(source);
- if (func != null) {
- return Optional.of(func.getBody().type(context));
- }
- }
- return Optional.empty(); // if this context does not contain this input
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type) {
- return typeFrom(type, null, 0);
- }
-
- private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) {
- String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
- Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- TensorType.Builder builder = new TensorType.Builder(toValueType(type.getTensorType().getElemType()));
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
- long onnxDimensionSize = onnxDimension.getDimValue();
- if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
- onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
- }
- if (onnxDimensionSize == 0 && symbolicSizes != null) {
- // This is for the case where all symbolic dimensions have
- // different names, but can be resolved to a single dimension size.
- Set<Long> unknownSizes = new HashSet<>(symbolicSizes.values());
- if (unknownSizes.size() == 1) {
- onnxDimensionSize = unknownSizes.iterator().next();
- }
- }
- if (onnxDimensionSize < 0) {
- onnxDimensionSize = unboundSize;
- }
- if (onnxDimensionSize <= 0) {
- throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from " +
- "ONNX type: " + type + " to Vespa tensor type.");
- }
- builder.indexed(dimensionName, onnxDimensionSize);
- }
- return builder.build();
- }
+ public static class OnnxNameMapping {
+ private String onnxName;
+ private String vespaName;
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT8: return TensorType.Value.FLOAT;
- case INT16: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.FLOAT;
- case UINT8: return TensorType.Value.FLOAT;
- case UINT16: return TensorType.Value.FLOAT;
- case UINT32: return TensorType.Value.FLOAT;
- case UINT64: return TensorType.Value.FLOAT;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
+ private OnnxNameMapping(String onnxName, String vespaName) {
+ this.onnxName = onnxName;
+ this.vespaName = vespaName;
}
+ public String getOnnxName() { return onnxName; }
+ public String getVespaName() { return vespaName; }
+ public void setVespaName(String vespaName) { this.vespaName = vespaName; }
}
}