summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-05-21 14:31:44 +0200
committerLester Solbakken <lesters@oath.com>2021-05-21 14:31:44 +0200
commit4feb552579339c370db15c1c433f9269336506b8 (patch)
tree21e1116b6e2a51c02f0e632b3dba30b1c65ce2b7
parent6448742f804482946a7bf2d17723dca6b4100b73 (diff)
Add model path to stored model info
-rw-r--r--application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnxbin31758 -> 31694 bytes
-rw-r--r--application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java82
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java19
4 files changed, 59 insertions, 46 deletions
diff --git a/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx b/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx
index a86019bf53a..1733eca2141 100644
--- a/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx
+++ b/application/src/test/app-packages/model-evaluation/models/onnx/mnist_softmax.onnx
Binary files differ
diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
index 369b2cbe42b..4b51b244d2d 100644
--- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
+++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
@@ -60,8 +60,8 @@ public class ContainerModelEvaluationTest {
}
{
- String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.36716532707214355}]}";
- assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default.output/eval?input=" + inputTensor(), expected, jdisc);
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.3671652674674988}]}";
+ assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/output/eval?input=" + inputTensor(), expected, jdisc);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 5098796e409..f4cf5e500dd 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -281,46 +281,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
serviceClusters.addAll(builder.getClusters(deployState, this));
}
- private OnnxModels onnxModelInfoFromSource(ImportedMlModel model) {
- OnnxModels onnxModels = new OnnxModels();
- if (model.modelType().equals(ImportedMlModel.ModelType.ONNX)) {
- String path = model.source();
- String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString();
- if (path.startsWith(applicationPath)) {
- path = path.substring(applicationPath.length() + 1);
- }
- loadModelInfo(onnxModels, model.name(), path);
- }
- return onnxModels;
- }
-
- private OnnxModels onnxModelInfoFromStore(String modelName) {
- OnnxModels onnxModels = new OnnxModels();
- String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString();
- loadModelInfo(onnxModels, modelName, path);
- return onnxModels;
- }
-
- private void loadModelInfo(OnnxModels onnModels, String name, String path) {
- boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
- if ( ! modelExists) {
- path = ApplicationPackage.MODELS_DIR.append(path).toString();
- modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
- }
- if (modelExists) {
- OnnxModel onnxModel = new OnnxModel(name, path);
- OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(onnxModel.getFileName(), this.applicationPackage);
- for (String onnxName : onnxModelInfo.getInputs()) {
- onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
- }
- for (String onnxName : onnxModelInfo.getOutputs()) {
- onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
- }
- onnxModel.setModelInfo(onnxModelInfo);
- onnModels.add(onnxModel);
- }
- }
-
/**
* Creates a rank profile not attached to any search definition, for each imported model in the application package,
* and adds it to the given rank profile registry.
@@ -353,6 +313,48 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
new Processing().processRankProfiles(deployLogger, rankProfileRegistry, queryProfiles, true, false);
}
+ private OnnxModels onnxModelInfoFromSource(ImportedMlModel model) {
+ OnnxModels onnxModels = new OnnxModels();
+ if (model.modelType().equals(ImportedMlModel.ModelType.ONNX)) {
+ String path = model.source();
+ String applicationPath = this.applicationPackage.getFileReference(Path.fromString("")).toString();
+ if (path.startsWith(applicationPath)) {
+ path = path.substring(applicationPath.length() + 1);
+ }
+ loadModelInfo(onnxModels, model.name(), path);
+ }
+ return onnxModels;
+ }
+
+ private OnnxModels onnxModelInfoFromStore(String modelName) {
+ OnnxModels onnxModels = new OnnxModels();
+ String path = ApplicationPackage.MODELS_DIR.append(modelName + ".onnx").toString();
+ loadModelInfo(onnxModels, modelName, path);
+ return onnxModels;
+ }
+
+ private void loadModelInfo(OnnxModels onnModels, String name, String path) {
+ boolean modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
+ if ( ! modelExists) {
+ path = ApplicationPackage.MODELS_DIR.append(path).toString();
+ modelExists = OnnxModelInfo.modelExists(path, this.applicationPackage);
+ }
+ if (modelExists) {
+ OnnxModelInfo onnxModelInfo = OnnxModelInfo.load(path, this.applicationPackage);
+ if (onnxModelInfo.getModelPath() != null) {
+ OnnxModel onnxModel = new OnnxModel(name, onnxModelInfo.getModelPath());
+ for (String onnxName : onnxModelInfo.getInputs()) {
+ onnxModel.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
+ for (String onnxName : onnxModelInfo.getOutputs()) {
+ onnxModel.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
+ }
+ onnxModel.setModelInfo(onnxModelInfo);
+ onnModels.add(onnxModel);
+ }
+ }
+ }
+
/** Returns the global rank profiles as a rank profile list */
public RankProfileList rankProfileList() { return rankProfileList; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 0a838e5d915..1fe7bdfe284 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -36,17 +36,23 @@ import java.util.stream.Collectors;
*/
public class OnnxModelInfo {
+ private final String modelPath;
private final String defaultOutput;
private final Map<String, OnnxTypeInfo> inputs;
private final Map<String, OnnxTypeInfo> outputs;
private final Map<String, TensorType> vespaTypes = new HashMap<>();
- private OnnxModelInfo(Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ private OnnxModelInfo(String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ this.modelPath = path;
this.inputs = Collections.unmodifiableMap(inputs);
this.outputs = Collections.unmodifiableMap(outputs);
this.defaultOutput = defaultOutput;
}
+ public String getModelPath() {
+ return modelPath;
+ }
+
public Set<String> getInputs() {
return inputs.keySet();
}
@@ -147,7 +153,7 @@ public class OnnxModelInfo {
static private OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) {
try (InputStream inputStream = app.getFile(path).createInputStream()) {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- String json = onnxModelToJson(model);
+ String json = onnxModelToJson(model, path);
storeGeneratedInfo(json, path, app);
return jsonToModelInfo(json);
} catch (IOException e) {
@@ -178,11 +184,12 @@ public class OnnxModelInfo {
return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
}
- static private String onnxModelToJson(Onnx.ModelProto model) throws IOException {
+ static private String onnxModelToJson(Onnx.ModelProto model, Path path) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
g.writeStartObject();
+ g.writeStringField("path", path.toString());
g.writeArrayFieldStart("inputs");
for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
onnxTypeToJson(g, valueInfo);
@@ -207,6 +214,10 @@ public class OnnxModelInfo {
Map<String, OnnxTypeInfo> outputs = new HashMap<>();
String defaultOutput = "";
+ String path = null;
+ if (root.has("path")) {
+ path = root.get("path").textValue();
+ }
for (JsonNode input : root.get("inputs")) {
inputs.put(input.get("name").textValue(), jsonToTypeInfo(input));
}
@@ -216,7 +227,7 @@ public class OnnxModelInfo {
if (root.get("outputs").has(0)) {
defaultOutput = root.get("outputs").get(0).get("name").textValue();
}
- return new OnnxModelInfo(inputs, outputs, defaultOutput);
+ return new OnnxModelInfo(path, inputs, outputs, defaultOutput);
}
static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {