diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index b7b18887dd8..c2fb2107604 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -5,7 +5,10 @@ import com.yahoo.config.FileReference; import com.yahoo.vespa.model.AbstractService; import com.yahoo.vespa.model.utils.FileSender; +import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.List; import java.util.Objects; /** @@ -20,6 +23,8 @@ public class OnnxModel { private final String name; private String path = null; private String fileReference = ""; + private List<OnnxNameMapping> inputMap = new ArrayList<>(); + private List<OnnxNameMapping> outputMap = new ArrayList<>(); public PathType getPathType() { return pathType; @@ -49,6 +54,18 @@ public class OnnxModel { this.pathType = PathType.URI; } + public void addInputNameMapping(String onnxName, String vespaName) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + this.inputMap.add(new OnnxNameMapping(onnxName, vespaName)); + } + + public void addOutputNameMapping(String onnxName, String vespaName) { + Objects.requireNonNull(onnxName, "Onnx name cannot be null"); + Objects.requireNonNull(vespaName, "Vespa name cannot be null"); + this.outputMap.add(new OnnxNameMapping(onnxName, vespaName)); + } + /** Initiate sending of this constant to some services over file distribution */ public void sendTo(Collection<? extends AbstractService> services) { FileReference reference = (pathType == OnnxModel.PathType.FILE) @@ -62,6 +79,9 @@ public class OnnxModel { public String getUri() { return path; } public String getFileReference() { return fileReference; } + public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); } + public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); } + public void validate() { if (path == null || path.isEmpty()) throw new IllegalArgumentException("ONNX models must have a file or uri."); @@ -76,4 +96,17 @@ public class OnnxModel { return b.toString(); } + public static class OnnxNameMapping { + private String onnxName; + private String vespaName; + + private OnnxNameMapping(String onnxName, String vespaName) { + this.onnxName = onnxName; + this.vespaName = vespaName; + } + public String getOnnxName() { return onnxName; } + public String getVespaName() { return vespaName; } + public void setVespaName(String vespaName) { this.vespaName = vespaName; } + } + } |