summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java33
1 files changed, 33 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index b7b18887dd8..c2fb2107604 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -5,7 +5,10 @@ import com.yahoo.config.FileReference;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
+import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
import java.util.Objects;
/**
@@ -20,6 +23,8 @@ public class OnnxModel {
private final String name;
private String path = null;
private String fileReference = "";
+ private List<OnnxNameMapping> inputMap = new ArrayList<>();
+ private List<OnnxNameMapping> outputMap = new ArrayList<>();
public PathType getPathType() {
return pathType;
@@ -49,6 +54,18 @@ public class OnnxModel {
this.pathType = PathType.URI;
}
+ public void addInputNameMapping(String onnxName, String vespaName) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ this.inputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ }
+
+ public void addOutputNameMapping(String onnxName, String vespaName) {
+ Objects.requireNonNull(onnxName, "Onnx name cannot be null");
+ Objects.requireNonNull(vespaName, "Vespa name cannot be null");
+ this.outputMap.add(new OnnxNameMapping(onnxName, vespaName));
+ }
+
/** Initiate sending of this constant to some services over file distribution */
public void sendTo(Collection<? extends AbstractService> services) {
FileReference reference = (pathType == OnnxModel.PathType.FILE)
@@ -62,6 +79,9 @@ public class OnnxModel {
public String getUri() { return path; }
public String getFileReference() { return fileReference; }
+ public List<OnnxNameMapping> getInputMap() { return Collections.unmodifiableList(inputMap); }
+ public List<OnnxNameMapping> getOutputMap() { return Collections.unmodifiableList(outputMap); }
+
public void validate() {
if (path == null || path.isEmpty())
throw new IllegalArgumentException("ONNX models must have a file or uri.");
@@ -76,4 +96,17 @@ public class OnnxModel {
return b.toString();
}
+ public static class OnnxNameMapping {
+ private String onnxName;
+ private String vespaName;
+
+ private OnnxNameMapping(String onnxName, String vespaName) {
+ this.onnxName = onnxName;
+ this.vespaName = vespaName;
+ }
+ public String getOnnxName() { return onnxName; }
+ public String getVespaName() { return vespaName; }
+ public void setVespaName(String vespaName) { this.vespaName = vespaName; }
+ }
+
}