diff options
Diffstat (limited to 'eval/src/apps')
-rw-r--r-- | eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp | 62 |
1 files changed, 55 insertions, 7 deletions
diff --git a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp index 506073ae8b3..868e9d036f1 100644 --- a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp +++ b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp @@ -20,6 +20,10 @@ using vespalib::FilePointer; using namespace vespalib::eval; using namespace vespalib::eval::test; +struct MyError { + vespalib::string msg; +}; + bool read_line(FilePointer &file, vespalib::string &line) { char line_buffer[1024]; char *res = fgets(line_buffer, sizeof(line_buffer), file.fp()); @@ -139,16 +143,50 @@ struct MakeInputType { } }; +vespalib::string make_bound_str(const std::map<vespalib::string,size_t> &bound) { + vespalib::string result; + if (!bound.empty()) { + for (const auto &[name, size]: bound) { + if (result.empty()) { + result.append(" ("); + } else { + result.append(","); + } + result.append(fmt("%s=%zu", name.c_str(), size)); + } + result.append(")"); + } + return result; +} + +void bind_input(Onnx::WirePlanner &planner, const Onnx::TensorInfo &input, const ValueType &type) { + auto bound = planner.get_bound_sizes(input); + if (!planner.bind_input_type(type, input)) { + auto bound_str = make_bound_str(bound); + throw MyError{fmt("incompatible type for input '%s': %s -> %s%s", + input.name.c_str(), type.to_spec().c_str(), input.type_as_string().c_str(), bound_str.c_str())}; + } +} + +ValueType make_output(const Onnx::WirePlanner &planner, const Onnx::TensorInfo &output) { + auto type = planner.make_output_type(output); + if (type.is_error()) { + throw MyError{fmt("unable to make compatible type for output '%s': %s -> error", + output.name.c_str(), output.type_as_string().c_str())}; + } + return type; +} + Onnx::WireInfo make_plan(Options &opts, const Onnx &model) { Onnx::WirePlanner planner; MakeInputType make_input_type(opts); for (const auto &input: model.inputs()) { auto type = make_input_type(input); - REQUIRE(planner.bind_input_type(type, input)); + bind_input(planner, input, type); } planner.prepare_output_types(model); for (const auto &output: model.outputs()) { - REQUIRE(!planner.make_output_type(output).is_error()); + make_output(planner, output); } return planner.get_wire_info(model); } @@ -189,7 +227,7 @@ int probe_types() { StdOut std_out; Slime params; if (!JsonFormat::decode(std_in, params)) { - return 3; + throw MyError{"invalid json"}; } Slime result; auto &root = result.setObject(); @@ -199,13 +237,20 @@ int probe_types() { for (size_t i = 0; i < model.inputs().size(); ++i) { auto spec = params["inputs"][model.inputs()[i].name].asString().make_string(); auto input_type = ValueType::from_spec(spec); - REQUIRE(!input_type.is_error()); - REQUIRE(planner.bind_input_type(input_type, model.inputs()[i])); + if (input_type.is_error()) { + if (!params["inputs"][model.inputs()[i].name].valid()) { + throw MyError{fmt("missing type for model input '%s'", + model.inputs()[i].name.c_str())}; + } else { + throw MyError{fmt("invalid type for model input '%s': '%s'", + model.inputs()[i].name.c_str(), spec.c_str())}; + } + } + bind_input(planner, model.inputs()[i], input_type); } planner.prepare_output_types(model); for (const auto &output: model.outputs()) { - auto output_type = planner.make_output_type(output); - REQUIRE(!output_type.is_error()); + auto output_type = make_output(planner, output); types.setString(output.name, output_type.to_spec()); } write_compact(result, std_out); @@ -253,6 +298,9 @@ int my_main(int argc, char **argv) { int main(int argc, char **argv) { try { return my_main(argc, argv); + } catch (const MyError &err) { + fprintf(stderr, "error: %s\n", err.msg.c_str()); + return 3; } catch (const std::exception &ex) { fprintf(stderr, "got exception: %s\n", ex.what()); return 2; |