diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2021-05-27 10:41:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-27 10:41:31 +0200 |
commit | e3d4dbac364216f8d93493d4a5f34835a268fbcf (patch) | |
tree | 90bc2cf28e08123a55854c2db1217f556d349a2e /config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | |
parent | 92efe91ec3d7be1902e7ca9c0e290c7859d535af (diff) | |
parent | 6b6e59869ab5259a8cd2e382cd2b5164a963a293 (diff) |
Merge branch 'master' into lesters/wire-in-stateless-onnx-rt
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 58 |
1 files changed, 4 insertions, 54 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 3e5726d6d94..3c42987512b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -1,14 +1,9 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; -import com.yahoo.config.FileReference; -import com.yahoo.path.Path; import com.yahoo.tensor.TensorType; -import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.ml.OnnxModelInfo; -import com.yahoo.vespa.model.utils.FileSender; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -19,42 +14,26 @@ import java.util.Objects; * * @author lesters */ -public class OnnxModel { +public class OnnxModel extends DistributableResource { - public enum PathType {FILE, URI}; - - private final String name; - private PathType pathType = PathType.FILE; - private String path = null; - private String fileReference = ""; private OnnxModelInfo modelInfo = null; private Map<String, String> inputMap = new HashMap<>(); private Map<String, String> outputMap = new HashMap<>(); public OnnxModel(String name) { - this.name = name; + super(name); } public OnnxModel(String name, String fileName) { - this(name); - this.path = fileName; + super(name, fileName); validate(); } - public void setFileName(String fileName) { - Objects.requireNonNull(fileName, "Filename cannot be null"); - this.path = fileName; - this.pathType = PathType.FILE; - } - + @Override public void setUri(String uri) { throw new IllegalArgumentException("URI for ONNX models are not currently supported"); } - public PathType getPathType() { - return pathType; - } - public void addInputNameMapping(String onnxName, String vespaName) { addInputNameMapping(onnxName, vespaName, true); } @@ -90,20 +69,6 @@ public class OnnxModel { this.modelInfo = modelInfo; } - /** Initiate sending of this constant to some services over file distribution */ - public void sendTo(Collection<? extends AbstractService> services) { - FileReference reference = (pathType == OnnxModel.PathType.FILE) - ? FileSender.sendFileToServices(path, services) - : FileSender.sendUriToServices(path, services); - this.fileReference = reference.value(); - } - - public String getName() { return name; } - public String getFileName() { return path; } - public Path getFilePath() { return Path.fromString(path); } - public String getUri() { return path; } - public String getFileReference() { return fileReference; } - public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } @@ -114,19 +79,4 @@ public class OnnxModel { TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) { return modelInfo != null ? modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty; } - - public void validate() { - if (path == null || path.isEmpty()) - throw new IllegalArgumentException("ONNX models must have a file or uri."); - } - - public String toString() { - StringBuilder b = new StringBuilder(); - b.append("onnx-model '").append(name) - .append(pathType == PathType.FILE ? "' from file '" : " from uri ").append(path) - .append("' with ref '").append(fileReference) - .append("'"); - return b.toString(); - } - } |