diff options
author | Jon Bratseth <bratseth@oath.com> | 2020-08-20 18:35:50 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-20 18:35:50 +0200 |
commit | 683a3309292109ad03e49d09629384d6528bb7c3 (patch) | |
tree | 512a16f04fa5f32c4647d2625e6c8fc6487820c4 | |
parent | 37d7096694fc91e370d14b9fb8a69825b9307e5f (diff) | |
parent | 8d9037cc4e7d0d15d4932b116fd2fbebc9e3152a (diff) |
Merge pull request #14119 from vespa-engine/lesters/tf2onnx-only-rename-outputs
Only rename outputs for tf -> onnx conversion
-rwxr-xr-x | model-integration/src/main/python/vespa-convert-tf2onnx.py | 39 |
1 files changed, 2 insertions, 37 deletions
diff --git a/model-integration/src/main/python/vespa-convert-tf2onnx.py b/model-integration/src/main/python/vespa-convert-tf2onnx.py index 9c8cb89ad5c..1862072aebc 100755 --- a/model-integration/src/main/python/vespa-convert-tf2onnx.py +++ b/model-integration/src/main/python/vespa-convert-tf2onnx.py @@ -40,36 +40,6 @@ def find_by_type(seq, tensor): return [ item for item in seq if has_equivalent_shape(item.type.tensor_type.shape, tensor.tensor_shape) ] -def rename_input(onnx_model, signature_name, signature_tensor): - signature_node_name = signature_tensor.name - - graph_input = find(onnx_model.graph.input, lambda input: input.name == signature_node_name) - if graph_input is None: - print("TensorFlow signature input '{}' references node '{}' which was not found. Trying to find equivalent input.".format(signature_name, signature_node_name)) - candidates = find_by_type(onnx_model.graph.input, signature_tensor) - if len(candidates) == 0: - print("Could not find equivalent input for '{}'. Unable to rename this input.".format(signature_name)) - return - if len(candidates) > 1: - print("Found multiple equivalent inputs '{}'. Unable to rename.".format(",".join([ o.name for o in candidates ]))) - return - graph_input = candidates[0] - print("Found equivalent input '{}'. Assuming this is correct.".format(graph_input.name)) - - if graph_input.name == signature_name: - print("Signature input '{}' already exists. Skipping.".format(signature_name)) - return - - input_node = find(onnx_model.graph.node, lambda node: graph_input.name in node.input) - if input_node is None: - print("Node using graph input '{}' was not found. Unable to rename.".format(graph_input.name)) - return - - print("Renamed input from '{}' to '{}'".format(graph_input.name, signature_name)) - input_node.input[index_of(input_node.input, graph_input.name)] = signature_name - graph_input.name = signature_name - - def rename_output(onnx_model, signature_name, signature_tensor): signature_node_name = signature_tensor.name @@ -100,18 +70,13 @@ def rename_output(onnx_model, signature_name, signature_tensor): graph_output.name = signature_name -def verify_inputs_and_outputs(args, onnx_model): +def verify_outputs(args, onnx_model): tag_sets = saved_model_utils.get_saved_model_tag_sets(args.saved_model) for tag_set in tag_sets: tag_set = ','.join(tag_set) meta_graph_def = saved_model_utils.get_meta_graph_def(args.saved_model, tag_set) signature_def_map = meta_graph_def.signature_def for signature_def_key in signature_def_map.keys(): - - inputs_tensor_info = signature_def_map[signature_def_key].inputs - for input_key, input_tensor in inputs_tensor_info.items(): - rename_input(onnx_model, input_key, input_tensor) - outputs_tensor_info = signature_def_map[signature_def_key].outputs for output_key, output_tensor in outputs_tensor_info.items(): rename_output(onnx_model, output_key, output_tensor) @@ -125,7 +90,7 @@ def main(): convert.main() onnx_model = onnx.load(args.output) - verify_inputs_and_outputs(args, onnx_model) + verify_outputs(args, onnx_model) onnx.save(onnx_model, args.output) |