diff options
author | Lester Solbakken <lesters@oath.com> | 2021-05-27 10:36:21 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-05-27 10:36:21 +0200 |
commit | 92efe91ec3d7be1902e7ca9c0e290c7859d535af (patch) | |
tree | 2eacff24a84b4816b9fb9ddfa50cecb34edbf1bc /config-model | |
parent | 8bb0b79db9d224e632a2f49c77418415a5b6b0d4 (diff) |
Move adding os inputs and outputs of ONNX models into OnnxModel
Diffstat (limited to 'config-model')
4 files changed, 11 insertions, 31 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java index 64338e24a8d..3e5726d6d94 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModel.java @@ -81,6 +81,12 @@ public class OnnxModel { public void setModelInfo(OnnxModelInfo modelInfo) { Objects.requireNonNull(modelInfo, "Onnx model info cannot be null"); + for (String onnxName : modelInfo.getInputs()) { + addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } + for (String onnxName : modelInfo.getOutputs()) { + addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } this.modelInfo = modelInfo; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java index 5c37b345edf..4d8fba8c603 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelTypeResolver.java @@ -28,23 +28,8 @@ public class OnnxModelTypeResolver extends Processor { @Override public void process(boolean validate, boolean documentsOnly) { if (documentsOnly) return; - for (OnnxModel onnxModel : search.onnxModels().asMap().values()) { OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), search.applicationPackage()); - - // Temporary, to disregard type information when model info is not available - if (onnxModelInfo == null) { - continue; - } - - // Add any missing input and output fields that were not specified in the onnx-model configuration - for (String onnxName : onnxModelInfo.getInputs()) { - onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); - } - for (String onnxName : onnxModelInfo.getOutputs()) { - onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); - } - onnxModel.setModelInfo(onnxModelInfo); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index f4cf5e500dd..4dc38c09ab1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -321,7 +321,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri if (path.startsWith(applicationPath)) { path = path.substring(applicationPath.length() + 1); } - loadModelInfo(onnxModels, model.name(), path); + loadOnnxModelInfo(onnxModels, model.name(), path); } return onnxModels; } @@ -329,11 +329,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri private OnnxModels onnxModelInfoFromStore(String modelName) { OnnxModels onnxModels = new OnnxModels(); String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString(); - loadModelInfo(onnxModels, modelName, path); + loadOnnxModelInfo(onnxModels, modelName, path); return onnxModels; } - private void loadModelInfo(OnnxModels onnModels, String name, String path) { + private void loadOnnxModelInfo(OnnxModels onnxModels, String name, String path) { boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); if ( ! modelExists) { path = ApplicationPackage.MODELS_DIR.append(path).toString(); @@ -343,14 +343,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(path, this.applicationPackage); if (onnxModelInfo.getModelPath() != null) { OnnxModel onnxModel = new OnnxModel(name, onnxModelInfo.getModelPath()); - for (String onnxName : onnxModelInfo.getInputs()) { - onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); - } - for (String onnxName : onnxModelInfo.getOutputs()) { - onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); - } onnxModel.setModelInfo(onnxModelInfo); - onnModels.add(onnxModel); + onnxModels.add(onnxModel); } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 1fe7bdfe284..58381fe5c3c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -131,12 +131,7 @@ public class OnnxModelInfo { if (app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists()) { return loadFromGeneratedInfo(pathInApplicationPackage, app); } - - // Temporary: - return null; - - // This is the correct behaviour after we've gotten applications through. - // throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file"); + throw new IllegalArgumentException("Unable to find ONNX model file or generated ONNX info file"); } static public boolean modelExists(String path, ApplicationPackage app) { |