diff options
Diffstat (limited to 'eval/src/apps/analyze_onnx_model')
-rw-r--r-- | eval/src/apps/analyze_onnx_model/CMakeLists.txt | 3 | ||||
-rw-r--r-- | eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp | 44 |
2 files changed, 45 insertions, 2 deletions
diff --git a/eval/src/apps/analyze_onnx_model/CMakeLists.txt b/eval/src/apps/analyze_onnx_model/CMakeLists.txt index e2ed64cd8cc..dc89213f9eb 100644 --- a/eval/src/apps/analyze_onnx_model/CMakeLists.txt +++ b/eval/src/apps/analyze_onnx_model/CMakeLists.txt @@ -1,7 +1,8 @@ # Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -vespa_add_executable(vespa-analyze-onnx-model +vespa_add_executable(eval_analyze_onnx_model_app SOURCES analyze_onnx_model.cpp + OUTPUT_NAME vespa-analyze-onnx-model INSTALL bin DEPENDS vespaeval diff --git a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp index 2f22f903f2e..506073ae8b3 100644 --- a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp +++ b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp @@ -4,6 +4,7 @@ #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/eval/value_codec.h> #include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/test/test_io.h> #include <vespa/vespalib/util/benchmark_timer.h> #include <vespa/vespalib/util/require.h> #include <vespa/vespalib/util/guard.h> @@ -11,8 +12,13 @@ using vespalib::make_string_short::fmt; +using vespalib::Slime; +using vespalib::slime::JsonFormat; +using vespalib::slime::Inspector; +using vespalib::slime::Cursor; using vespalib::FilePointer; using namespace vespalib::eval; +using namespace vespalib::eval::test; bool read_line(FilePointer &file, vespalib::string &line) { char line_buffer[1024]; @@ -169,14 +175,50 @@ int usage(const char *self) { fprintf(stderr, " load onnx model and report memory usage\n"); fprintf(stderr, " options are used to specify unknown values, like dimension sizes\n"); fprintf(stderr, " options are accepted in the order in which they are needed\n"); - fprintf(stderr, " tip: run without options first, to see which you need\n"); + fprintf(stderr, " tip: run without options first, to see which you need\n\n"); + fprintf(stderr, "usage: %s --probe-types\n", self); + fprintf(stderr, " use onnx model to infer/probe output types based on input types\n"); + fprintf(stderr, " parameters are read from stdin and results are written to stdout\n"); + fprintf(stderr, " input format (json): {model:<filename>, inputs:{<name>:vespa-type-string}}\n"); + fprintf(stderr, " output format (json): {outputs:{<name>:vespa-type-string}}\n"); return 1; } +int probe_types() { + StdIn std_in; + StdOut std_out; + Slime params; + if (!JsonFormat::decode(std_in, params)) { + return 3; + } + Slime result; + auto &root = result.setObject(); + auto &types = root.setObject("outputs"); + Onnx model(params["model"].asString().make_string(), Onnx::Optimize::DISABLE); + Onnx::WirePlanner planner; + for (size_t i = 0; i < model.inputs().size(); ++i) { + auto spec = params["inputs"][model.inputs()[i].name].asString().make_string(); + auto input_type = ValueType::from_spec(spec); + REQUIRE(!input_type.is_error()); + REQUIRE(planner.bind_input_type(input_type, model.inputs()[i])); + } + planner.prepare_output_types(model); + for (const auto &output: model.outputs()) { + auto output_type = planner.make_output_type(output); + REQUIRE(!output_type.is_error()); + types.setString(output.name, output_type.to_spec()); + } + write_compact(result, std_out); + return 0; +} + int my_main(int argc, char **argv) { if (argc < 2) { return usage(argv[0]); } + if ((argc == 2) && (vespalib::string(argv[1]) == "--probe-types")) { + return probe_types(); + } Options opts; for (int i = 2; i < argc; ++i) { opts.add_option(argv[i]); |