diff options
-rw-r--r-- | eval/src/vespa/eval/onnx/onnx_wrapper.cpp | 49 |
1 files changed, 27 insertions, 22 deletions
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp index 7a6f89ed53e..54c4b863e35 100644 --- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp @@ -302,28 +302,33 @@ Onnx::WirePlanner::need_model_probe(const Onnx &model) const void Onnx::WirePlanner::do_model_probe(const Onnx &model) { - std::vector<Ort::Value> param_values; - param_values.reserve(model.inputs().size()); - for (const auto &input: model.inputs()) { - const auto &pos = _input_types.find(input.name); - assert(pos != _input_types.end()); - auto vespa_type = pos->second; - auto sizes = extract_sizes(vespa_type); - size_t num_cells = vespa_type.dense_subspace_size(); - param_values.push_back(create_empty_onnx_tensor(input.elements, sizes, num_cells, _alloc)); - } - std::vector<Ort::Value> result_values; - result_values.reserve(model.outputs().size()); - for (size_t i = 0; i < model.outputs().size(); ++i) { - result_values.emplace_back(nullptr); - } - Ort::RunOptions run_opts(nullptr); - Ort::Session &session = const_cast<Ort::Session&>(model._session); - session.Run(run_opts, - model._input_name_refs.data(), param_values.data(), param_values.size(), - model._output_name_refs.data(), result_values.data(), result_values.size()); - for (size_t i = 0; i < model.outputs().size(); ++i) { - _output_types.emplace(model.outputs()[i].name, get_type_of(result_values[i])); + try { + std::vector<Ort::Value> param_values; + param_values.reserve(model.inputs().size()); + for (const auto &input: model.inputs()) { + const auto &pos = _input_types.find(input.name); + assert(pos != _input_types.end()); + auto vespa_type = pos->second; + auto sizes = extract_sizes(vespa_type); + size_t num_cells = vespa_type.dense_subspace_size(); + param_values.push_back(create_empty_onnx_tensor(input.elements, sizes, num_cells, _alloc)); + } + std::vector<Ort::Value> result_values; + result_values.reserve(model.outputs().size()); + for (size_t i = 0; i < model.outputs().size(); ++i) { + result_values.emplace_back(nullptr); + } + Ort::RunOptions run_opts(nullptr); + Ort::Session &session = const_cast<Ort::Session&>(model._session); + session.Run(run_opts, + model._input_name_refs.data(), param_values.data(), param_values.size(), + model._output_name_refs.data(), result_values.data(), result_values.size()); + for (size_t i = 0; i < model.outputs().size(); ++i) { + _output_types.emplace(model.outputs()[i].name, get_type_of(result_values[i])); + } + } catch (const Ort::Exception &ex) { + _output_types.clear(); + LOG(warning, "model probe failed: %s", ex.what()); } } |