From 2175474ab0929b81e3a47a558e8cb6c86a4b667f Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 19 Aug 2020 11:26:27 +0200 Subject: Wrap tf2onnx to add output aliases --- model-integration/CMakeLists.txt | 2 + .../importer/tensorflow/TensorFlowImporter.java | 2 +- .../src/main/python/vespa-convert-tf2onnx.py | 60 ++++++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) create mode 100755 model-integration/src/main/python/vespa-convert-tf2onnx.py (limited to 'model-integration') diff --git a/model-integration/CMakeLists.txt b/model-integration/CMakeLists.txt index 26d5b3d1bbc..f8aa1c552a6 100644 --- a/model-integration/CMakeLists.txt +++ b/model-integration/CMakeLists.txt @@ -1,4 +1,6 @@ # Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. install_fat_java_artifact(model-integration) +vespa_install_script(src/main/python/vespa-convert-tf2onnx.py vespa-convert-tf2onnx bin) + install(FILES src/main/config/model-integration.xml DESTINATION conf/configserver-app) \ No newline at end of file diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index 96ea58edc61..34b9c847a12 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -98,7 +98,7 @@ public class TensorFlowImporter extends ModelImporter { private Pair convertToOnnx(String savedModel, String output, int opset) throws IOException { ProcessExecuter executer = new ProcessExecuter(); - String job = "python3 -m tf2onnx.convert --saved-model " + savedModel + " --output " + output + " --opset " + opset; + String job = "vespa-convert-tf2onnx --saved-model " + savedModel + " --output " + output + " --opset " + opset; return executer.exec(job); } diff --git a/model-integration/src/main/python/vespa-convert-tf2onnx.py b/model-integration/src/main/python/vespa-convert-tf2onnx.py new file mode 100755 index 00000000000..e34610f6eb4 --- /dev/null +++ b/model-integration/src/main/python/vespa-convert-tf2onnx.py @@ -0,0 +1,60 @@ +#! /usr/bin/env python3 + +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import sys +import onnx + +from tf2onnx import convert +from tensorflow.python.tools import saved_model_utils + + +def find(nodes, test): + return next((x for x in nodes if test(x)), None) + + +def make_alias(onnx_model, alias, output_name): + output = find(onnx_model.graph.output, lambda node: node.name == output_name) + if output is None: + print("Could not find output '{}' to alias from '{}'".format(output_name, alias)) + return + output_tensor = onnx.helper.make_empty_tensor_value_info("") + output_tensor.CopyFrom(output) + output_tensor.name = alias + onnx_model.graph.output.append(output_tensor) + onnx_model.graph.node.append(onnx.helper.make_node("Identity", [output_name], [alias])) + + +def verify_outputs(args, onnx_model): + tag_sets = saved_model_utils.get_saved_model_tag_sets(args.saved_model) + for tag_set in sorted(tag_sets): + tag_set = ','.join(tag_set) + meta_graph_def = saved_model_utils.get_meta_graph_def(args.saved_model, tag_set) + signature_def_map = meta_graph_def.signature_def + for signature_def_key in sorted(signature_def_map.keys()): + outputs_tensor_info = signature_def_map[signature_def_key].outputs + for output_key, output_tensor in sorted(outputs_tensor_info.items()): + output_key_exists_as_output = find(onnx_model.graph.output, lambda node: node.name == output_key) + if output_key_exists_as_output: + continue + make_alias(onnx_model, output_key, output_tensor.name) + + output_names = [ "'{}'".format(o.name) for o in onnx_model.graph.output ] + print("Outputs in model: {}".format(", ".join(output_names))) + + +def main(): + convert.main() + + args = convert.get_args() + onnx_model = onnx.load(args.output) + verify_outputs(args, onnx_model) + onnx.save(onnx_model, args.output) + + +if __name__ == "__main__": + main() + + + + -- cgit v1.2.3