aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/apps
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/apps')
-rw-r--r--eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp62
1 files changed, 55 insertions, 7 deletions
diff --git a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp
index 506073ae8b3..868e9d036f1 100644
--- a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp
+++ b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp
@@ -20,6 +20,10 @@ using vespalib::FilePointer;
using namespace vespalib::eval;
using namespace vespalib::eval::test;
+struct MyError {
+ vespalib::string msg;
+};
+
bool read_line(FilePointer &file, vespalib::string &line) {
char line_buffer[1024];
char *res = fgets(line_buffer, sizeof(line_buffer), file.fp());
@@ -139,16 +143,50 @@ struct MakeInputType {
}
};
+vespalib::string make_bound_str(const std::map<vespalib::string,size_t> &bound) {
+ vespalib::string result;
+ if (!bound.empty()) {
+ for (const auto &[name, size]: bound) {
+ if (result.empty()) {
+ result.append(" (");
+ } else {
+ result.append(",");
+ }
+ result.append(fmt("%s=%zu", name.c_str(), size));
+ }
+ result.append(")");
+ }
+ return result;
+}
+
+void bind_input(Onnx::WirePlanner &planner, const Onnx::TensorInfo &input, const ValueType &type) {
+ auto bound = planner.get_bound_sizes(input);
+ if (!planner.bind_input_type(type, input)) {
+ auto bound_str = make_bound_str(bound);
+ throw MyError{fmt("incompatible type for input '%s': %s -> %s%s",
+ input.name.c_str(), type.to_spec().c_str(), input.type_as_string().c_str(), bound_str.c_str())};
+ }
+}
+
+ValueType make_output(const Onnx::WirePlanner &planner, const Onnx::TensorInfo &output) {
+ auto type = planner.make_output_type(output);
+ if (type.is_error()) {
+ throw MyError{fmt("unable to make compatible type for output '%s': %s -> error",
+ output.name.c_str(), output.type_as_string().c_str())};
+ }
+ return type;
+}
+
Onnx::WireInfo make_plan(Options &opts, const Onnx &model) {
Onnx::WirePlanner planner;
MakeInputType make_input_type(opts);
for (const auto &input: model.inputs()) {
auto type = make_input_type(input);
- REQUIRE(planner.bind_input_type(type, input));
+ bind_input(planner, input, type);
}
planner.prepare_output_types(model);
for (const auto &output: model.outputs()) {
- REQUIRE(!planner.make_output_type(output).is_error());
+ make_output(planner, output);
}
return planner.get_wire_info(model);
}
@@ -189,7 +227,7 @@ int probe_types() {
StdOut std_out;
Slime params;
if (!JsonFormat::decode(std_in, params)) {
- return 3;
+ throw MyError{"invalid json"};
}
Slime result;
auto &root = result.setObject();
@@ -199,13 +237,20 @@ int probe_types() {
for (size_t i = 0; i < model.inputs().size(); ++i) {
auto spec = params["inputs"][model.inputs()[i].name].asString().make_string();
auto input_type = ValueType::from_spec(spec);
- REQUIRE(!input_type.is_error());
- REQUIRE(planner.bind_input_type(input_type, model.inputs()[i]));
+ if (input_type.is_error()) {
+ if (!params["inputs"][model.inputs()[i].name].valid()) {
+ throw MyError{fmt("missing type for model input '%s'",
+ model.inputs()[i].name.c_str())};
+ } else {
+ throw MyError{fmt("invalid type for model input '%s': '%s'",
+ model.inputs()[i].name.c_str(), spec.c_str())};
+ }
+ }
+ bind_input(planner, model.inputs()[i], input_type);
}
planner.prepare_output_types(model);
for (const auto &output: model.outputs()) {
- auto output_type = planner.make_output_type(output);
- REQUIRE(!output_type.is_error());
+ auto output_type = make_output(planner, output);
types.setString(output.name, output_type.to_spec());
}
write_compact(result, std_out);
@@ -253,6 +298,9 @@ int my_main(int argc, char **argv) {
int main(int argc, char **argv) {
try {
return my_main(argc, argv);
+ } catch (const MyError &err) {
+ fprintf(stderr, "error: %s\n", err.msg.c_str());
+ return 3;
} catch (const std::exception &ex) {
fprintf(stderr, "got exception: %s\n", ex.what());
return 2;