From d920b9b84b2e9abfa053e6215b6bd55045dab140 Mon Sep 17 00:00:00 2001
From: Håvard Pettersen
Date: Fri, 11 Feb 2022 12:03:06 +0000
Subject: enable probing output types

also added some testing
---
 eval/CMakeLists.txt                                |   1 +
 eval/src/apps/analyze_onnx_model/CMakeLists.txt    |   3 +-
 .../apps/analyze_onnx_model/analyze_onnx_model.cpp |  44 ++++++-
 .../tests/apps/analyze_onnx_model/CMakeLists.txt   |   9 ++
 .../analyze_onnx_model/analyze_onnx_model_test.cpp | 137 +++++++++++++++++++++
 5 files changed, 192 insertions(+), 2 deletions(-)
 create mode 100644 eval/src/tests/apps/analyze_onnx_model/CMakeLists.txt
 create mode 100644 eval/src/tests/apps/analyze_onnx_model/analyze_onnx_model_test.cpp

diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index eed4fa5ce66..e6669e3fde8 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -12,6 +12,7 @@ vespa_define_module(
 
     TESTS
     src/tests/ann
+    src/tests/apps/analyze_onnx_model
     src/tests/apps/eval_expr
    src/tests/eval/addr_to_symbol
     src/tests/eval/aggr
diff --git a/eval/src/apps/analyze_onnx_model/CMakeLists.txt b/eval/src/apps/analyze_onnx_model/CMakeLists.txt
index e2ed64cd8cc..dc89213f9eb 100644
--- a/eval/src/apps/analyze_onnx_model/CMakeLists.txt
+++ b/eval/src/apps/analyze_onnx_model/CMakeLists.txt
@@ -1,7 +1,8 @@
 # Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-vespa_add_executable(vespa-analyze-onnx-model
+vespa_add_executable(eval_analyze_onnx_model_app
     SOURCES
     analyze_onnx_model.cpp
+    OUTPUT_NAME vespa-analyze-onnx-model
     INSTALL bin
     DEPENDS
     vespaeval
diff --git a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp
index 2f22f903f2e..506073ae8b3 100644
--- a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp
+++ b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -11,8 +12,13 @@
 
 using vespalib::make_string_short::fmt;
+using vespalib::Slime;
+using vespalib::slime::JsonFormat;
+using vespalib::slime::Inspector;
+using vespalib::slime::Cursor;
 using vespalib::FilePointer;
 using namespace vespalib::eval;
+using namespace vespalib::eval::test;
 
 bool read_line(FilePointer &file, vespalib::string &line) {
     char line_buffer[1024];
@@ -169,14 +175,50 @@ int usage(const char *self) {
     fprintf(stderr, " load onnx model and report memory usage\n");
     fprintf(stderr, " options are used to specify unknown values, like dimension sizes\n");
     fprintf(stderr, " options are accepted in the order in which they are needed\n");
-    fprintf(stderr, " tip: run without options first, to see which you need\n");
+    fprintf(stderr, " tip: run without options first, to see which you need\n\n");
+    fprintf(stderr, "usage: %s --probe-types\n", self);
+    fprintf(stderr, " use onnx model to infer/probe output types based on input types\n");
+    fprintf(stderr, " parameters are read from stdin and results are written to stdout\n");
+    fprintf(stderr, " input format (json): {model:, inputs:{:vespa-type-string}}\n");
+    fprintf(stderr, " output format (json): {outputs:{:vespa-type-string}}\n");
     return 1;
 }
 
+int probe_types() {
+    StdIn std_in;
+    StdOut std_out;
+    Slime params;
+    if (!JsonFormat::decode(std_in, params)) {
+        return 3;
+    }
+    Slime result;
+    auto &root = result.setObject();
+    auto &types = root.setObject("outputs");
+    Onnx model(params["model"].asString().make_string(), Onnx::Optimize::DISABLE);
+    Onnx::WirePlanner planner;
+    for (size_t i = 0; i < model.inputs().size(); ++i) {
+        auto spec = params["inputs"][model.inputs()[i].name].asString().make_string();
+        auto input_type = ValueType::from_spec(spec);
+        REQUIRE(!input_type.is_error());
+        REQUIRE(planner.bind_input_type(input_type, model.inputs()[i]));
+    }
+    planner.prepare_output_types(model);
+    for (const auto &output: model.outputs()) {
+        auto output_type = planner.make_output_type(output);
+        REQUIRE(!output_type.is_error());
+        types.setString(output.name, output_type.to_spec());
+    }
+    write_compact(result, std_out);
+    return 0;
+}
+
 int my_main(int argc, char **argv) {
     if (argc < 2) {
         return usage(argv[0]);
     }
+    if ((argc == 2) && (vespalib::string(argv[1]) == "--probe-types")) {
+        return probe_types();
+    }
     Options opts;
     for (int i = 2; i < argc; ++i) {
         opts.add_option(argv[i]);
diff --git a/eval/src/tests/apps/analyze_onnx_model/CMakeLists.txt b/eval/src/tests/apps/analyze_onnx_model/CMakeLists.txt
new file mode 100644
index 00000000000..7b70360a622
--- /dev/null
+++ b/eval/src/tests/apps/analyze_onnx_model/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_analyze_onnx_model_test_app TEST
+    SOURCES
+    analyze_onnx_model_test.cpp
+    DEPENDS
+    vespaeval
+)
+vespa_add_test(NAME eval_analyze_onnx_model_test_app COMMAND eval_analyze_onnx_model_test_app
+               DEPENDS eval_analyze_onnx_model_test_app eval_analyze_onnx_model_app)
diff --git a/eval/src/tests/apps/analyze_onnx_model/analyze_onnx_model_test.cpp b/eval/src/tests/apps/analyze_onnx_model/analyze_onnx_model_test.cpp
new file mode 100644
index 00000000000..2c1b2b21b9e
--- /dev/null
+++ b/eval/src/tests/apps/analyze_onnx_model/analyze_onnx_model_test.cpp
@@ -0,0 +1,137 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace vespalib;
+using namespace vespalib::eval::test;
+using vespalib::make_string_short::fmt;
+using vespalib::slime::JsonFormat;
+using vespalib::slime::Inspector;
+
+vespalib::string module_build_path("../../../../");
+vespalib::string binary = module_build_path + "src/apps/analyze_onnx_model/vespa-analyze-onnx-model";
+vespalib::string probe_cmd = binary + " --probe-types";
+
+std::string get_source_dir() {
+    const char *dir = getenv("SOURCE_DIRECTORY");
+    return (dir ? dir : ".");
+}
+std::string source_dir = get_source_dir();
+std::string guess_batch_model = source_dir + "/../../tensor/onnx_wrapper/guess_batch.onnx";
+
+//-----------------------------------------------------------------------------
+
+void read_until_eof(Input &input) {
+    for (auto mem = input.obtain(); mem.size > 0; mem = input.obtain()) {
+        input.evict(mem.size);
+    }
+}
+
+// Output adapter used to write to stdin of a child process
+class ChildIn : public Output {
+    ChildProcess &_child;
+    SimpleBuffer _output;
+public:
+    ChildIn(ChildProcess &child) : _child(child) {}
+    WritableMemory reserve(size_t bytes) override {
+        return _output.reserve(bytes);
+    }
+    Output &commit(size_t bytes) override {
+        _output.commit(bytes);
+        Memory buf = _output.obtain();
+        ASSERT_TRUE(_child.write(buf.data, buf.size));
+        _output.evict(buf.size);
+        return *this;
+    }
+};
+
+// Input adapter used to read from stdout of a child process
+class ChildOut : public Input {
+    ChildProcess &_child;
+    SimpleBuffer _input;
+public:
+    ChildOut(ChildProcess &child)
+      : _child(child)
+    {
+        EXPECT_TRUE(_child.running());
+        EXPECT_TRUE(!_child.failed());
+    }
+    Memory obtain() override {
+        if ((_input.get().size == 0) && !_child.eof()) {
+            WritableMemory buf = _input.reserve(4_Ki);
+            uint32_t res = _child.read(buf.data, buf.size);
+            ASSERT_TRUE((res > 0) || _child.eof());
+            _input.commit(res);
+        }
+        return _input.obtain();
+    }
+    Input &evict(size_t bytes) override {
+        _input.evict(bytes);
+        return *this;
+    }
+};
+
+//-----------------------------------------------------------------------------
+
+void dump_message(const char *prefix, const Slime &slime) {
+    SimpleBuffer buf;
+    slime::JsonFormat::encode(slime, buf, true);
+    auto str = buf.get().make_string();
+    fprintf(stderr, "%s%s\n", prefix, str.c_str());
+}
+
+class Server {
+private:
+    TimeBomb _bomb;
+    ChildProcess _child;
+    ChildIn _child_stdin;
+    ChildOut _child_stdout;
+public:
+    Server(vespalib::string cmd)
+      : _bomb(60),
+        _child(cmd.c_str()),
+        _child_stdin(_child),
+        _child_stdout(_child) {}
+    ~Server();
+    Slime invoke(const Slime &req) {
+        dump_message("request --> ", req);
+        write_compact(req, _child_stdin);
+        Slime reply;
+        ASSERT_TRUE(JsonFormat::decode(_child_stdout, reply));
+        dump_message(" reply <-- ", reply);
+        return reply;
+    }
+};
+Server::~Server() {
+    _child.close();
+    read_until_eof(_child_stdout);
+    ASSERT_TRUE(_child.wait());
+    ASSERT_TRUE(!_child.running());
+    ASSERT_TRUE(!_child.failed());
+}
+
+//-----------------------------------------------------------------------------
+
+TEST_F("require that output types can be probed", Server(probe_cmd)) {
+    Slime params;
+    params.setObject();
+    params.get().setString("model", guess_batch_model);
+    params.get().setObject("inputs");
+    params["inputs"].setString("in1", "tensor(x[3])");
+    params["inputs"].setString("in2", "tensor(x[3])");
+    Slime result = f1.invoke(params);
+    EXPECT_EQUAL(result["outputs"]["out"].asString().make_string(), vespalib::string("tensor(d0[3])"));
+}
+
+//-----------------------------------------------------------------------------
+
+TEST_MAIN_WITH_PROCESS_PROXY() { TEST_RUN_ALL(); }
-- 
cgit v1.2.3