aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2022-02-09 15:49:05 +0000
committerHåvard Pettersen <havardpe@oath.com>2022-02-09 15:49:05 +0000
commit8eae2d45b4f03eb8f696bc1c13ba0a5768a38f12 (patch)
treeacf9267c3ee890c877b84f6b1ad02cd8683ec17d /eval/src
parente4dd5bce73e2ac856d359098c27195e5aebbbf8b (diff)
handle exceptions caused by model probing
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/vespa/eval/onnx/onnx_wrapper.cpp49
1 files changed, 27 insertions, 22 deletions
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
index 7a6f89ed53e..54c4b863e35 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
@@ -302,28 +302,33 @@ Onnx::WirePlanner::need_model_probe(const Onnx &model) const
void
Onnx::WirePlanner::do_model_probe(const Onnx &model)
{
- std::vector<Ort::Value> param_values;
- param_values.reserve(model.inputs().size());
- for (const auto &input: model.inputs()) {
- const auto &pos = _input_types.find(input.name);
- assert(pos != _input_types.end());
- auto vespa_type = pos->second;
- auto sizes = extract_sizes(vespa_type);
- size_t num_cells = vespa_type.dense_subspace_size();
- param_values.push_back(create_empty_onnx_tensor(input.elements, sizes, num_cells, _alloc));
- }
- std::vector<Ort::Value> result_values;
- result_values.reserve(model.outputs().size());
- for (size_t i = 0; i < model.outputs().size(); ++i) {
- result_values.emplace_back(nullptr);
- }
- Ort::RunOptions run_opts(nullptr);
- Ort::Session &session = const_cast<Ort::Session&>(model._session);
- session.Run(run_opts,
- model._input_name_refs.data(), param_values.data(), param_values.size(),
- model._output_name_refs.data(), result_values.data(), result_values.size());
- for (size_t i = 0; i < model.outputs().size(); ++i) {
- _output_types.emplace(model.outputs()[i].name, get_type_of(result_values[i]));
+ try {
+ std::vector<Ort::Value> param_values;
+ param_values.reserve(model.inputs().size());
+ for (const auto &input: model.inputs()) {
+ const auto &pos = _input_types.find(input.name);
+ assert(pos != _input_types.end());
+ auto vespa_type = pos->second;
+ auto sizes = extract_sizes(vespa_type);
+ size_t num_cells = vespa_type.dense_subspace_size();
+ param_values.push_back(create_empty_onnx_tensor(input.elements, sizes, num_cells, _alloc));
+ }
+ std::vector<Ort::Value> result_values;
+ result_values.reserve(model.outputs().size());
+ for (size_t i = 0; i < model.outputs().size(); ++i) {
+ result_values.emplace_back(nullptr);
+ }
+ Ort::RunOptions run_opts(nullptr);
+ Ort::Session &session = const_cast<Ort::Session&>(model._session);
+ session.Run(run_opts,
+ model._input_name_refs.data(), param_values.data(), param_values.size(),
+ model._output_name_refs.data(), result_values.data(), result_values.size());
+ for (size_t i = 0; i < model.outputs().size(); ++i) {
+ _output_types.emplace(model.outputs()[i].name, get_type_of(result_values[i]));
+ }
+ } catch (const Ort::Exception &ex) {
+ _output_types.clear();
+ LOG(warning, "model probe failed: %s", ex.what());
}
}