diff options
4 files changed, 59 insertions, 46 deletions
diff --git a/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx b/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx Binary files differindex a86019bf53a..1733eca2141 100644 --- a/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx +++ b/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java index 369b2cbe42b..4b51b244d2d 100644 --- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java +++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java @@ -60,8 +60,8 @@ public class ContainerModelEvaluationTest { } { - String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.36716532707214355}]}"; - assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default.output/eval?input=" + inputTensor(), expected, jdisc); + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.3671652674674988}]}"; + assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/output/eval?input=" + inputTensor(), expected, jdisc); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 5098796e409..f4cf5e500dd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -281,46 +281,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri serviceClusters.addAll(builder.getClusters(deployState, this)); } - private OnnxModels onnxModelInfoFromSource(ImportedMlModel model) { - OnnxModels onnxModels = new OnnxModels(); - if (model.modelType().equals(ImportedMlModel.ModelType.ONNX)) { - String path = model.source(); - String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString(); - if (path.startsWith(applicationPath)) { - path = path.substring(applicationPath.length() + 1); - } - loadModelInfo(onnxModels, model.name(), path); - } - return onnxModels; - } - - private OnnxModels onnxModelInfoFromStore(String modelName) { - OnnxModels onnxModels = new OnnxModels(); - String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString(); - loadModelInfo(onnxModels, modelName, path); - return onnxModels; - } - - private void loadModelInfo(OnnxModels onnModels, String name, String path) { - boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); - if ( ! modelExists) { - path = ApplicationPackage.MODELS_DIR.append(path).toString(); - modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); - } - if (modelExists) { - OnnxModel onnxModel = new OnnxModel(name, path); - OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), this.applicationPackage); - for (String onnxName : onnxModelInfo.getInputs()) { - onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); - } - for (String onnxName : onnxModelInfo.getOutputs()) { - onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); - } - onnxModel.setModelInfo(onnxModelInfo); - onnModels.add(onnxModel); - } - } - /** * Creates a rank profile not attached to any search definition, for each imported model in the application package, * and adds it to the given rank profile registry. @@ -353,6 +313,48 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false); } + private OnnxModels onnxModelInfoFromSource(ImportedMlModel model) { + OnnxModels onnxModels = new OnnxModels(); + if (model.modelType().equals(ImportedMlModel.ModelType.ONNX)) { + String path = model.source(); + String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString(); + if (path.startsWith(applicationPath)) { + path = path.substring(applicationPath.length() + 1); + } + loadModelInfo(onnxModels, model.name(), path); + } + return onnxModels; + } + + private OnnxModels onnxModelInfoFromStore(String modelName) { + OnnxModels onnxModels = new OnnxModels(); + String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString(); + loadModelInfo(onnxModels, modelName, path); + return onnxModels; + } + + private void loadModelInfo(OnnxModels onnModels, String name, String path) { + boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); + if ( ! modelExists) { + path = ApplicationPackage.MODELS_DIR.append(path).toString(); + modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage); + } + if (modelExists) { + OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(path, this.applicationPackage); + if (onnxModelInfo.getModelPath() != null) { + OnnxModel onnxModel = new OnnxModel(name, onnxModelInfo.getModelPath()); + for (String onnxName : onnxModelInfo.getInputs()) { + onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } + for (String onnxName : onnxModelInfo.getOutputs()) { + onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); + } + onnxModel.setModelInfo(onnxModelInfo); + onnModels.add(onnxModel); + } + } + } + /** Returns the global rank profiles as a rank profile list */ public RankProfileList rankProfileList() { return rankProfileList; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 0a838e5d915..1fe7bdfe284 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -36,17 +36,23 @@ import java.util.stream.Collectors; */ public class OnnxModelInfo { + private final String modelPath; private final String defaultOutput; private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); - private OnnxModelInfo(Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + private OnnxModelInfo(String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); this.defaultOutput = defaultOutput; } + public String getModelPath() { + return modelPath; + } + public Set<String> getInputs() { return inputs.keySet(); } @@ -147,7 +153,7 @@ public class OnnxModelInfo { static private OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) { try (InputStream inputStream = app.getFile(path).createInputStream()) { Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); - String json = onnxModelToJson(model); + String json = onnxModelToJson(model, path); storeGeneratedInfo(json, path, app); return jsonToModelInfo(json); } catch (IOException e) { @@ -178,11 +184,12 @@ public class OnnxModelInfo { return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); } - static private String onnxModelToJson(Onnx.ModelProto model) throws IOException { + static private String onnxModelToJson(Onnx.ModelProto model, Path path) throws IOException { ByteArrayOutputStream out = new ByteArrayOutputStream(); JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); g.writeStartObject(); + g.writeStringField("path", path.toString()); g.writeArrayFieldStart("inputs"); for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) { onnxTypeToJson(g, valueInfo); @@ -207,6 +214,10 @@ public class OnnxModelInfo { Map<String, OnnxTypeInfo> outputs = new HashMap<>(); String defaultOutput = ""; + String path = null; + if (root.has("path")) { + path = root.get("path").textValue(); + } for (JsonNode input : root.get("inputs")) { inputs.put(input.get("name").textValue(), jsonToTypeInfo(input)); } @@ -216,7 +227,7 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(inputs, outputs, defaultOutput); + return new OnnxModelInfo(path, inputs, outputs, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { |