summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/apps/analyze_onnx_model/analyze_onnx_model_test.cpp
blob: 2c1b2b21b9e14eacb4c5043bae381bfc72bb111e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/testkit/time_bomb.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/util/child_process.h>
#include <vespa/vespalib/data/input.h>
#include <vespa/vespalib/data/output.h>
#include <vespa/vespalib/data/simple_buffer.h>
#include <vespa/vespalib/util/size_literals.h>
#include <vespa/eval/eval/test/test_io.h>

using namespace vespalib;
using namespace vespalib::eval::test;
using vespalib::make_string_short::fmt;
using vespalib::slime::JsonFormat;
using vespalib::slime::Inspector;

// Path to the vespa-analyze-onnx-model binary, plus the command line that
// starts it with --probe-types. The paths are relative; presumably they are
// resolved against the test's working (build) directory.
vespalib::string module_build_path{"../../../../"};
vespalib::string binary{module_build_path + "src/apps/analyze_onnx_model/vespa-analyze-onnx-model"};
vespalib::string probe_cmd{binary + " --probe-types"};

// Returns the directory holding this test's sources, taken from the
// SOURCE_DIRECTORY environment variable.
//
// Falls back to "." when the variable is unset OR empty. Without the empty
// check, an empty value would turn every path built as
// `source_dir + "/..."` into an absolute path rooted at "/".
std::string get_source_dir() {
    const char *dir = getenv("SOURCE_DIRECTORY");
    if (dir == nullptr || *dir == '\0') {
        return ".";
    }
    return dir;
}
std::string source_dir = get_source_dir();
std::string guess_batch_model = source_dir + "/../../tensor/onnx_wrapper/guess_batch.onnx";

//-----------------------------------------------------------------------------

// Consumes and discards everything `input` yields. Stops at the first
// empty chunk returned by obtain().
void read_until_eof(Input &input) {
    Memory chunk = input.obtain();
    while (chunk.size > 0) {
        input.evict(chunk.size);
        chunk = input.obtain();
    }
}

// Output adapter used to write to stdin of a child process.
//
// Data is staged in a local SimpleBuffer. Everything committed is written
// to the child immediately, so nothing is held back between commits.
class ChildIn : public Output {
    ChildProcess &_child;
    SimpleBuffer _output;
public:
    // explicit: a ChildProcess should never silently convert into an Output.
    explicit ChildIn(ChildProcess &child) : _child(child) {}
    WritableMemory reserve(size_t bytes) override {
        return _output.reserve(bytes);
    }
    // Publishes `bytes` reserved bytes and forwards the whole pending buffer
    // to the child's stdin. The buffer is emptied only after the write
    // succeeded.
    Output &commit(size_t bytes) override {
        _output.commit(bytes);
        Memory buf = _output.obtain();
        ASSERT_TRUE(_child.write(buf.data, buf.size));
        _output.evict(buf.size);
        return *this;
    }
};

// Input adapter used to read from stdout of a child process.
//
// Bytes are pulled from the child on demand, at most 4 KiB at a time, and
// only once the local buffer has been fully consumed. When the buffer is
// empty and the child has reported EOF, obtain() returns an empty Memory.
// read_until_eof relies on that signal.
class ChildOut : public Input {
    ChildProcess &_child;
    SimpleBuffer _input;
public:
    // explicit: a ChildProcess should never silently convert into an Input.
    explicit ChildOut(ChildProcess &child)
      : _child(child)
    {
        // the child must have started successfully before we read from it
        EXPECT_TRUE(_child.running());
        EXPECT_TRUE(!_child.failed());
    }
    Memory obtain() override {
        if ((_input.get().size == 0) && !_child.eof()) {
            WritableMemory buf = _input.reserve(4_Ki);
            uint32_t res = _child.read(buf.data, buf.size);
            // a zero-byte read is only acceptable once the child hit EOF
            ASSERT_TRUE((res > 0) || _child.eof());
            _input.commit(res);
        }
        return _input.obtain();
    }
    Input &evict(size_t bytes) override {
        _input.evict(bytes);
        return *this;
    }
};

//-----------------------------------------------------------------------------

void dump_message(const char *prefix, const Slime &slime) {
    SimpleBuffer buf;
    slime::JsonFormat::encode(slime, buf, true);
    auto str = buf.get().make_string();
    fprintf(stderr, "%s%s\n", prefix, str.c_str());
}

class Server {
private:
    TimeBomb _bomb;
    ChildProcess _child;
    ChildIn _child_stdin;
    ChildOut _child_stdout;
public:
    Server(vespalib::string cmd)
      : _bomb(60),
        _child(cmd.c_str()),
        _child_stdin(_child),
        _child_stdout(_child) {}
    ~Server();
    Slime invoke(const Slime &req) {
        dump_message("request --> ", req);
        write_compact(req, _child_stdin);
        Slime reply;
        ASSERT_TRUE(JsonFormat::decode(_child_stdout, reply));
        dump_message("  reply <-- ", reply);
        return reply;
    }
};
Server::~Server() {
    _child.close();
    read_until_eof(_child_stdout);
    ASSERT_TRUE(_child.wait());
    ASSERT_TRUE(!_child.running());
    ASSERT_TRUE(!_child.failed());
}

//-----------------------------------------------------------------------------

TEST_F("require that output types can be probed", Server(probe_cmd)) {
    // Request: the model to analyze plus the declared type of each input.
    Slime request;
    auto &root = request.setObject();
    root.setString("model", guess_batch_model);
    auto &inputs = root.setObject("inputs");
    inputs.setString("in1", "tensor<float>(x[3])");
    inputs.setString("in2", "tensor<float>(x[3])");
    // The probed output type is reported under outputs.out in the reply.
    Slime result = f1.invoke(request);
    EXPECT_EQUAL(result["outputs"]["out"].asString().make_string(), vespalib::string("tensor<float>(d0[3])"));
}

//-----------------------------------------------------------------------------

// Test-kit entry point; invokes TEST_RUN_ALL() for the tests in this file.
// NOTE(review): the PROCESS_PROXY variant presumably exists because this
// test spawns child processes (see Server); confirm against the test kit.
TEST_MAIN_WITH_PROCESS_PROXY() { TEST_RUN_ALL(); }