aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-08-20 18:31:44 +0200
committerLester Solbakken <lesters@oath.com>2020-08-20 18:31:44 +0200
commit8d9037cc4e7d0d15d4932b116fd2fbebc9e3152a (patch)
tree762e30d90fa647c8a35b9de1297694b5fe976ccb /model-integration
parentb214393ad98132794d6c7a09f00599541c6ed372 (diff)
Only rename outputs for tf -> onnx conversion
Diffstat (limited to 'model-integration')
-rwxr-xr-xmodel-integration/src/main/python/vespa-convert-tf2onnx.py39
1 files changed, 2 insertions, 37 deletions
diff --git a/model-integration/src/main/python/vespa-convert-tf2onnx.py b/model-integration/src/main/python/vespa-convert-tf2onnx.py
index 9c8cb89ad5c..1862072aebc 100755
--- a/model-integration/src/main/python/vespa-convert-tf2onnx.py
+++ b/model-integration/src/main/python/vespa-convert-tf2onnx.py
@@ -40,36 +40,6 @@ def find_by_type(seq, tensor):
return [ item for item in seq if has_equivalent_shape(item.type.tensor_type.shape, tensor.tensor_shape) ]
-def rename_input(onnx_model, signature_name, signature_tensor):
- signature_node_name = signature_tensor.name
-
- graph_input = find(onnx_model.graph.input, lambda input: input.name == signature_node_name)
- if graph_input is None:
- print("TensorFlow signature input '{}' references node '{}' which was not found. Trying to find equivalent input.".format(signature_name, signature_node_name))
- candidates = find_by_type(onnx_model.graph.input, signature_tensor)
- if len(candidates) == 0:
- print("Could not find equivalent input for '{}'. Unable to rename this input.".format(signature_name))
- return
- if len(candidates) > 1:
- print("Found multiple equivalent inputs '{}'. Unable to rename.".format(",".join([ o.name for o in candidates ])))
- return
- graph_input = candidates[0]
- print("Found equivalent input '{}'. Assuming this is correct.".format(graph_input.name))
-
- if graph_input.name == signature_name:
- print("Signature input '{}' already exists. Skipping.".format(signature_name))
- return
-
- input_node = find(onnx_model.graph.node, lambda node: graph_input.name in node.input)
- if input_node is None:
- print("Node using graph input '{}' was not found. Unable to rename.".format(graph_input.name))
- return
-
- print("Renamed input from '{}' to '{}'".format(graph_input.name, signature_name))
- input_node.input[index_of(input_node.input, graph_input.name)] = signature_name
- graph_input.name = signature_name
-
-
def rename_output(onnx_model, signature_name, signature_tensor):
signature_node_name = signature_tensor.name
@@ -100,18 +70,13 @@ def rename_output(onnx_model, signature_name, signature_tensor):
graph_output.name = signature_name
-def verify_inputs_and_outputs(args, onnx_model):
+def verify_outputs(args, onnx_model):
tag_sets = saved_model_utils.get_saved_model_tag_sets(args.saved_model)
for tag_set in tag_sets:
tag_set = ','.join(tag_set)
meta_graph_def = saved_model_utils.get_meta_graph_def(args.saved_model, tag_set)
signature_def_map = meta_graph_def.signature_def
for signature_def_key in signature_def_map.keys():
-
- inputs_tensor_info = signature_def_map[signature_def_key].inputs
- for input_key, input_tensor in inputs_tensor_info.items():
- rename_input(onnx_model, input_key, input_tensor)
-
outputs_tensor_info = signature_def_map[signature_def_key].outputs
for output_key, output_tensor in outputs_tensor_info.items():
rename_output(onnx_model, output_key, output_tensor)
@@ -125,7 +90,7 @@ def main():
convert.main()
onnx_model = onnx.load(args.output)
- verify_inputs_and_outputs(args, onnx_model)
+ verify_outputs(args, onnx_model)
onnx.save(onnx_model, args.output)