diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2023-09-04 11:40:48 +0000 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2023-09-04 11:40:48 +0000 |
commit | fd8d29f8e1ebb51decaccee71aa6119275c11be1 (patch) | |
tree | 9330d2ca829717fe27a53344ddc9a9e9bfc1f4bb /eval | |
parent | c3ee6199624f99eb695a0d397b800a0fd6ab326d (diff) |
Unify and modernize code and layout
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/onnx/onnx_wrapper.cpp | 32 |
1 files changed, 16 insertions, 16 deletions
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp index 8f9450c2660..2490457cb1d 100644 --- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp @@ -8,10 +8,6 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/typify.h> #include <vespa/vespalib/util/classname.h> -#include <assert.h> -#include <cmath> -#include <stdlib.h> -#include <stdio.h> #include <type_traits> #include <vespa/log/log.h> @@ -171,20 +167,21 @@ private: public: OnnxString(const OnnxString &rhs) = delete; OnnxString &operator=(const OnnxString &rhs) = delete; - OnnxString(OnnxString &&rhs) = default; - OnnxString &operator=(OnnxString &&rhs) = default; + OnnxString(OnnxString &&rhs) noexcept = default; + OnnxString &operator=(OnnxString &&rhs) noexcept = default; const char *get() const { return _str.get(); } ~OnnxString() = default; static OnnxString get_input_name(const Ort::Session &session, size_t idx) { - return OnnxString(session.GetInputNameAllocated(idx, _alloc)); + return {session.GetInputNameAllocated(idx, _alloc)}; } static OnnxString get_output_name(const Ort::Session &session, size_t idx) { - return OnnxString(session.GetOutputNameAllocated(idx, _alloc)); + return {session.GetOutputNameAllocated(idx, _alloc)}; } }; Ort::AllocatorWithDefaultOptions OnnxString::_alloc; -std::vector<Onnx::DimSize> make_dimensions(const Ort::ConstTensorTypeAndShapeInfo &tensor_info) { +std::vector<Onnx::DimSize> +make_dimensions(const Ort::ConstTensorTypeAndShapeInfo &tensor_info) { std::vector<const char *> symbolic_sizes(tensor_info.GetDimensionsCount(), nullptr); tensor_info.GetSymbolicDimensions(symbolic_sizes.data(), symbolic_sizes.size()); auto shape = tensor_info.GetShape(); @@ -201,13 +198,15 @@ std::vector<Onnx::DimSize> make_dimensions(const Ort::ConstTensorTypeAndShapeInf return result; } -Onnx::TensorInfo make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) { +Onnx::TensorInfo +make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); auto element_type = tensor_info.GetElementType(); return Onnx::TensorInfo{vespalib::string(name.get()), make_dimensions(tensor_info), make_element_type(element_type)}; } -Onnx::TensorType get_type_of(const Ort::Value &value) { +Onnx::TensorType +get_type_of(const Ort::Value &value) { auto tensor_info = value.GetTensorTypeAndShapeInfo(); auto element_type = tensor_info.GetElementType(); auto shape = tensor_info.GetShape(); @@ -216,10 +215,11 @@ Onnx::TensorType get_type_of(const Ort::Value &value) { throw Ort::Exception("[onnx wrapper] actual value has unknown dimension size", ORT_FAIL); } } - return Onnx::TensorType(make_element_type(element_type), shape); + return {make_element_type(element_type), shape}; } -std::vector<int64_t> extract_sizes(const ValueType &type) { +std::vector<int64_t> +extract_sizes(const ValueType &type) { std::vector<int64_t> sizes; for (const auto &dim: type.dimensions()) { sizes.push_back(dim.size); @@ -306,7 +306,7 @@ Onnx::WirePlanner::do_model_probe(const Onnx &model) result_values.emplace_back(nullptr); } Ort::RunOptions run_opts(nullptr); - Ort::Session &session = const_cast<Ort::Session&>(model._session); + auto &session = const_cast<Ort::Session&>(model._session); session.Run(run_opts, model._input_name_refs.data(), param_values.data(), param_values.size(), model._output_name_refs.data(), result_values.data(), result_values.size()); @@ -554,7 +554,7 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) const auto &vespa = _wire_info.vespa_inputs[i]; const auto &onnx = _wire_info.onnx_inputs[i]; if (is_same_type(vespa.cell_type(), onnx.elements)) { - _param_values.push_back(Ort::Value(nullptr)); + _param_values.emplace_back(nullptr); _param_binders.push_back(SelectAdaptParam()(vespa.cell_type())); } else { _param_values.push_back(CreateOnnxTensor()(onnx, _alloc)); @@ -587,7 +587,7 @@ Onnx::EvalContext::bind_param(size_t i, const Value ¶m) void Onnx::EvalContext::eval() { - Ort::Session &session = const_cast<Ort::Session&>(_model._session); + auto &session = const_cast<Ort::Session&>(_model._session); Ort::RunOptions run_opts(nullptr); session.Run(run_opts, _model._input_name_refs.data(), _param_values.data(), _param_values.size(), |