diff options
-rwxr-xr-x | model-integration/src/main/python/vespa-convert-tf2onnx.py | 127 |
1 files changed, 102 insertions, 25 deletions
diff --git a/model-integration/src/main/python/vespa-convert-tf2onnx.py b/model-integration/src/main/python/vespa-convert-tf2onnx.py index e34610f6eb4..9c8cb89ad5c 100755 --- a/model-integration/src/main/python/vespa-convert-tf2onnx.py +++ b/model-integration/src/main/python/vespa-convert-tf2onnx.py @@ -2,53 +2,130 @@ # Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -import sys import onnx - from tf2onnx import convert from tensorflow.python.tools import saved_model_utils -def find(nodes, test): - return next((x for x in nodes if test(x)), None) +def find(seq, test, default=None): + return next((x for x in seq if test(x)), default) + + +def index_of(seq, elem): + for i in range(len(seq)): + if seq[i] == elem: + return i + + +def has_initializer(onnx_model, name): + return find(onnx_model.graph.initializer, lambda i: i.name == name) != None + + +def has_equivalent_shape(onnx_tensor_shape, tensorflow_tensor_shape): + onnx_dims = onnx_tensor_shape.dim + tf_dims = tensorflow_tensor_shape.dim + if len(onnx_dims) != len(tf_dims): + return False + for i in range(len(onnx_dims)): + onnx_dim_size = onnx_dims[i].dim_value + tf_dim_size = tf_dims[i].size + if onnx_dim_size == 0 and tf_dim_size == -1: + continue + if onnx_dim_size != tf_dim_size: + return False + return True + + +def find_by_type(seq, tensor): + return [ item for item in seq if has_equivalent_shape(item.type.tensor_type.shape, tensor.tensor_shape) ] + + +def rename_input(onnx_model, signature_name, signature_tensor): + signature_node_name = signature_tensor.name + graph_input = find(onnx_model.graph.input, lambda input: input.name == signature_node_name) + if graph_input is None: + print("TensorFlow signature input '{}' references node '{}' which was not found. Trying to find equivalent input.".format(signature_name, signature_node_name)) + candidates = find_by_type(onnx_model.graph.input, signature_tensor) + if len(candidates) == 0: + print("Could not find equivalent input for '{}'. Unable to rename this input.".format(signature_name)) + return + if len(candidates) > 1: + print("Found multiple equivalent inputs '{}'. Unable to rename.".format(",".join([ o.name for o in candidates ]))) + return + graph_input = candidates[0] + print("Found equivalent input '{}'. Assuming this is correct.".format(graph_input.name)) -def make_alias(onnx_model, alias, output_name): - output = find(onnx_model.graph.output, lambda node: node.name == output_name) - if output is None: - print("Could not find output '{}' to alias from '{}'".format(output_name, alias)) + if graph_input.name == signature_name: + print("Signature input '{}' already exists. Skipping.".format(signature_name)) return - output_tensor = onnx.helper.make_empty_tensor_value_info("") - output_tensor.CopyFrom(output) - output_tensor.name = alias - onnx_model.graph.output.append(output_tensor) - onnx_model.graph.node.append(onnx.helper.make_node("Identity", [output_name], [alias])) + input_node = find(onnx_model.graph.node, lambda node: graph_input.name in node.input) + if input_node is None: + print("Node using graph input '{}' was not found. Unable to rename.".format(graph_input.name)) + return + + print("Renamed input from '{}' to '{}'".format(graph_input.name, signature_name)) + input_node.input[index_of(input_node.input, graph_input.name)] = signature_name + graph_input.name = signature_name + + +def rename_output(onnx_model, signature_name, signature_tensor): + signature_node_name = signature_tensor.name + + graph_output = find(onnx_model.graph.output, lambda output: output.name == signature_node_name) + if graph_output is None: + print("TensorFlow signature output '{}' references node '{}' which was not found. Trying to find equivalent output.".format(signature_name, signature_node_name)) + candidates = find_by_type(onnx_model.graph.output, signature_tensor) + if len(candidates) == 0: + print("Could not find equivalent output for '{}'. Unable to rename this output.".format(signature_name)) + return + if len(candidates) > 1: + print("Found multiple equivalent outputs '{}'. Unable to rename.".format(",".join([ o.name for o in candidates ]))) + return + graph_output = candidates[0] + print("Found equivalent output '{}'. Assuming this is correct.".format(graph_output.name)) + + if graph_output.name == signature_name: + print("Signature output '{}' already exists. Skipping.".format(signature_name)) + return + + output_node = find(onnx_model.graph.node, lambda node: graph_output.name in node.output) + if output_node is None: + print("Node generating graph output '{}' was not found. Unable to rename.".format(graph_output.name)) + return + + print("Renamed output from '{}' to '{}'".format(graph_output.name, signature_name)) + output_node.output[index_of(output_node.output, graph_output.name)] = signature_name + graph_output.name = signature_name -def verify_outputs(args, onnx_model): + +def verify_inputs_and_outputs(args, onnx_model): tag_sets = saved_model_utils.get_saved_model_tag_sets(args.saved_model) - for tag_set in sorted(tag_sets): + for tag_set in tag_sets: tag_set = ','.join(tag_set) meta_graph_def = saved_model_utils.get_meta_graph_def(args.saved_model, tag_set) signature_def_map = meta_graph_def.signature_def - for signature_def_key in sorted(signature_def_map.keys()): + for signature_def_key in signature_def_map.keys(): + + inputs_tensor_info = signature_def_map[signature_def_key].inputs + for input_key, input_tensor in inputs_tensor_info.items(): + rename_input(onnx_model, input_key, input_tensor) + outputs_tensor_info = signature_def_map[signature_def_key].outputs - for output_key, output_tensor in sorted(outputs_tensor_info.items()): - output_key_exists_as_output = find(onnx_model.graph.output, lambda node: node.name == output_key) - if output_key_exists_as_output: - continue - make_alias(onnx_model, output_key, output_tensor.name) + for output_key, output_tensor in outputs_tensor_info.items(): + rename_output(onnx_model, output_key, output_tensor) - output_names = [ "'{}'".format(o.name) for o in onnx_model.graph.output ] - print("Outputs in model: {}".format(", ".join(output_names))) + print("Inputs in model: {}".format(", ".join(["'{}'".format(o.name) for o in onnx_model.graph.input if not has_initializer(onnx_model, o.name)]))) + print("Outputs in model: {}".format(", ".join(["'{}'".format(o.name) for o in onnx_model.graph.output]))) def main(): + args = convert.get_args() convert.main() - args = convert.get_args() onnx_model = onnx.load(args.output) - verify_outputs(args, onnx_model) + verify_inputs_and_outputs(args, onnx_model) onnx.save(onnx_model, args.output) |