1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "eval_onnx.h"
#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/value_codec.h>
#include <vespa/log/log.h>
// Sets up logging for this file (vespa/log/log.h); the LOG(...) calls below report under this component name.
LOG_SETUP(".eval.eval.test.eval_onnx");
namespace vespalib::eval::test {
std::vector<TensorSpec> eval_onnx(const Onnx &model, const std::vector<TensorSpec> ¶ms) {
if (params.size() != model.inputs().size()) {
LOG(error, "model with %zu inputs run with %zu parameters", model.inputs().size(), params.size());
return {}; // wrong number of parameters
}
Onnx::WirePlanner planner;
for (size_t i = 0; i < model.inputs().size(); ++i) {
if (!planner.bind_input_type(ValueType::from_spec(params[i].type()), model.inputs()[i])) {
LOG(error, "unable to bind input type: %s -> %s", params[i].type().c_str(), model.inputs()[i].type_as_string().c_str());
return {}; // inconsistent input types
}
}
planner.prepare_output_types(model);
for (size_t i = 0; i < model.outputs().size(); ++i) {
if (planner.make_output_type(model.outputs()[i]).is_error()) {
LOG(error, "unable to make output type: %s -> error", model.outputs()[i].type_as_string().c_str());
return {}; // unable to infer/probe output type
}
}
planner.prepare_output_types(model);
auto wire_info = planner.get_wire_info(model);
try {
Onnx::EvalContext context(model, wire_info);
std::vector<Value::UP> inputs;
for (const auto ¶m: params) {
inputs.push_back(value_from_spec(param, FastValueBuilderFactory::get()));
}
for (size_t i = 0; i < model.inputs().size(); ++i) {
context.bind_param(i, *inputs[i]);
}
context.eval();
std::vector<TensorSpec> results;
for (size_t i = 0; i < model.outputs().size(); ++i) {
results.push_back(spec_from_value(context.get_result(i)));
}
return results;
} catch (const Ort::Exception &ex) {
LOG(error, "model run failed: %s", ex.what());
return {}; // evaluation failed
}
}
} // namespace
|