summary refs log tree commit diff stats
diff options
context:
space:
mode:
author	Jon Bratseth <bratseth@oath.com>	2020-08-19 11:40:02 +0200
committer	GitHub <noreply@github.com>	2020-08-19 11:40:02 +0200
commit	4fa221b7c20f492de0433b9b812f1cc509a18d97 (patch)
tree	e0cb14b33159c8f7e9ec4ee7eb712b896062c964
parent	5fa4c0b53645b86537fd4441144f30f86c15a930 (diff)
parent	2175474ab0929b81e3a47a558e8cb6c86a4b667f (diff)
Merge pull request #14100 from vespa-engine/lesters/tf2onnx-wrapper
Wrap tf2onnx to add output aliases
-rw-r--r--	model-integration/CMakeLists.txt	2
-rw-r--r--	model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java	2
-rwxr-xr-x	model-integration/src/main/python/vespa-convert-tf2onnx.py	60
3 files changed, 63 insertions, 1 deletions
diff --git a/model-integration/CMakeLists.txt b/model-integration/CMakeLists.txt
index 26d5b3d1bbc..f8aa1c552a6 100644
--- a/model-integration/CMakeLists.txt
+++ b/model-integration/CMakeLists.txt
@@ -1,4 +1,6 @@
# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
install_fat_java_artifact(model-integration)
+vespa_install_script(src/main/python/vespa-convert-tf2onnx.py vespa-convert-tf2onnx bin)
+
install(FILES src/main/config/model-integration.xml DESTINATION conf/configserver-app)
\ No newline at end of file
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index 96ea58edc61..34b9c847a12 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -98,7 +98,7 @@ public class TensorFlowImporter extends ModelImporter {
private Pair<Integer, String> convertToOnnx(String savedModel, String output, int opset) throws IOException {
ProcessExecuter executer = new ProcessExecuter();
- String job = "python3 -m tf2onnx.convert --saved-model " + savedModel + " --output " + output + " --opset " + opset;
+ String job = "vespa-convert-tf2onnx --saved-model " + savedModel + " --output " + output + " --opset " + opset;
return executer.exec(job);
}
diff --git a/model-integration/src/main/python/vespa-convert-tf2onnx.py b/model-integration/src/main/python/vespa-convert-tf2onnx.py
new file mode 100755
index 00000000000..e34610f6eb4
--- /dev/null
+++ b/model-integration/src/main/python/vespa-convert-tf2onnx.py
@@ -0,0 +1,60 @@
+#! /usr/bin/env python3
+
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import sys
+import onnx
+
+from tf2onnx import convert
+from tensorflow.python.tools import saved_model_utils
+
+
+def find(nodes, test):
+ return next((x for x in nodes if test(x)), None)
+
+
+def make_alias(onnx_model, alias, output_name):
+ output = find(onnx_model.graph.output, lambda node: node.name == output_name)
+ if output is None:
+ print("Could not find output '{}' to alias from '{}'".format(output_name, alias))
+ return
+ output_tensor = onnx.helper.make_empty_tensor_value_info("")
+ output_tensor.CopyFrom(output)
+ output_tensor.name = alias
+ onnx_model.graph.output.append(output_tensor)
+ onnx_model.graph.node.append(onnx.helper.make_node("Identity", [output_name], [alias]))
+
+
+def verify_outputs(args, onnx_model):
+ tag_sets = saved_model_utils.get_saved_model_tag_sets(args.saved_model)
+ for tag_set in sorted(tag_sets):
+ tag_set = ','.join(tag_set)
+ meta_graph_def = saved_model_utils.get_meta_graph_def(args.saved_model, tag_set)
+ signature_def_map = meta_graph_def.signature_def
+ for signature_def_key in sorted(signature_def_map.keys()):
+ outputs_tensor_info = signature_def_map[signature_def_key].outputs
+ for output_key, output_tensor in sorted(outputs_tensor_info.items()):
+ output_key_exists_as_output = find(onnx_model.graph.output, lambda node: node.name == output_key)
+ if output_key_exists_as_output:
+ continue
+ make_alias(onnx_model, output_key, output_tensor.name)
+
+ output_names = [ "'{}'".format(o.name) for o in onnx_model.graph.output ]
+ print("Outputs in model: {}".format(", ".join(output_names)))
+
+
+def main():
+ convert.main()
+
+ args = convert.get_args()
+ onnx_model = onnx.load(args.output)
+ verify_outputs(args, onnx_model)
+ onnx.save(onnx_model, args.output)
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+