aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-05-27 10:36:21 +0200
committerLester Solbakken <lesters@oath.com>2021-05-27 10:36:21 +0200
commit92efe91ec3d7be1902e7ca9c0e290c7859d535af (patch)
tree2eacff24a84b4816b9fb9ddfa50cecb34edbf1bc /config-model
parent8bb0b79db9d224e632a2f49c77418415a5b6b0d4 (diff)
Move adding os inputs and outputs of ONNX models into OnnxModel
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java15
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java14
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java7
4 files changed, 11 insertions, 31 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
index 64338e24a8d..3e5726d6d94 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java
@@ -81,6 +81,12 @@ public class OnnxModel {
public void setModelInfo(OnnxModelInfo modelInfo) {
Objects.requireNonNull(modelInfo, "Onnx model info cannot be null");
+ for (String onnxName : modelInfo.getInputs()) {
+ addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
+ for (String onnxName : modelInfo.getOutputs()) {
+ addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
this.modelInfo = modelInfo;
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
index 5c37b345edf..4d8fba8c603 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java
@@ -28,23 +28,8 @@ public class OnnxModelTypeResolver extends Processor {
@Override
public void process(boolean validate, boolean documentsOnly) {
if (documentsOnly) return;
-
for (OnnxModel onnxModel : search.onnxModels().asMap().values()) {
OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), search.applicationPackage());
-
- // Temporary, to disregard type information when model info is not available
- if (onnxModelInfo == null) {
- continue;
- }
-
- // Add any missing input and output fields that were not specified in the onnx-model configuration
- for (String onnxName : onnxModelInfo.getInputs()) {
- onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
- }
- for (String onnxName : onnxModelInfo.getOutputs()) {
- onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
- }
-
onnxModel.setModelInfo(onnxModelInfo);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index f4cf5e500dd..4dc38c09ab1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -321,7 +321,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
if (path.startsWith(applicationPath)) {
path = path.substring(applicationPath.length() + 1);
}
- loadModelInfo(onnxModels, model.name(), path);
+ loadOnnxModelInfo(onnxModels, model.name(), path);
}
return onnxModels;
}
@@ -329,11 +329,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
private OnnxModels onnxModelInfoFromStore(String modelName) {
OnnxModels onnxModels = new OnnxModels();
String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString();
- loadModelInfo(onnxModels, modelName, path);
+ loadOnnxModelInfo(onnxModels, modelName, path);
return onnxModels;
}
- private void loadModelInfo(OnnxModels onnModels, String name, String path) {
+ private void loadOnnxModelInfo(OnnxModels onnxModels, String name, String path) {
boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
if ( ! modelExists) {
path = ApplicationPackage.MODELS_DIR.append(path).toString();
@@ -343,14 +343,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(path, this.applicationPackage);
if (onnxModelInfo.getModelPath() != null) {
OnnxModel onnxModel = new OnnxModel(name, onnxModelInfo.getModelPath());
- for (String onnxName : onnxModelInfo.getInputs()) {
- onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
- }
- for (String onnxName : onnxModelInfo.getOutputs()) {
- onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
- }
onnxModel.setModelInfo(onnxModelInfo);
- onnModels.add(onnxModel);
+ onnxModels.add(onnxModel);
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 1fe7bdfe284..58381fe5c3c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -131,12 +131,7 @@ public class OnnxModelInfo {
if (app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists()) {
return loadFromGeneratedInfo(pathInApplicationPackage, app);
}
-
- // Temporary:
- return null;
-
- // This is the correct behaviour after we've gotten applications through.
- // throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file");
+ throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file");
}
static public boolean modelExists(String path, ApplicationPackage app) {