path: root/eval
author	Henning Baldersheim <balder@yahoo-inc.com>	2023-09-04 11:40:48 +0000
committer	Henning Baldersheim <balder@yahoo-inc.com>	2023-09-04 11:40:48 +0000
commit	fd8d29f8e1ebb51decaccee71aa6119275c11be1 (patch)
tree	9330d2ca829717fe27a53344ddc9a9e9bfc1f4bb /eval
parent	c3ee6199624f99eb695a0d397b800a0fd6ab326d (diff)
Unify and modernize code and layout
Diffstat (limited to 'eval')
-rw-r--r--	eval/src/vespa/eval/onnx/onnx_wrapper.cpp	32
1 file changed, 16 insertions, 16 deletions
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
index 8f9450c2660..2490457cb1d 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
@@ -8,10 +8,6 @@
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/typify.h>
#include <vespa/vespalib/util/classname.h>
-#include <assert.h>
-#include <cmath>
-#include <stdlib.h>
-#include <stdio.h>
#include <type_traits>
#include <vespa/log/log.h>
@@ -171,20 +167,21 @@ private:
public:
OnnxString(const OnnxString &rhs) = delete;
OnnxString &operator=(const OnnxString &rhs) = delete;
- OnnxString(OnnxString &&rhs) = default;
- OnnxString &operator=(OnnxString &&rhs) = default;
+ OnnxString(OnnxString &&rhs) noexcept = default;
+ OnnxString &operator=(OnnxString &&rhs) noexcept = default;
const char *get() const { return _str.get(); }
~OnnxString() = default;
static OnnxString get_input_name(const Ort::Session &session, size_t idx) {
- return OnnxString(session.GetInputNameAllocated(idx, _alloc));
+ return {session.GetInputNameAllocated(idx, _alloc)};
}
static OnnxString get_output_name(const Ort::Session &session, size_t idx) {
- return OnnxString(session.GetOutputNameAllocated(idx, _alloc));
+ return {session.GetOutputNameAllocated(idx, _alloc)};
}
};
Ort::AllocatorWithDefaultOptions OnnxString::_alloc;
-std::vector<Onnx::DimSize> make_dimensions(const Ort::ConstTensorTypeAndShapeInfo &tensor_info) {
+std::vector<Onnx::DimSize>
+make_dimensions(const Ort::ConstTensorTypeAndShapeInfo &tensor_info) {
std::vector<const char *> symbolic_sizes(tensor_info.GetDimensionsCount(), nullptr);
tensor_info.GetSymbolicDimensions(symbolic_sizes.data(), symbolic_sizes.size());
auto shape = tensor_info.GetShape();
@@ -201,13 +198,15 @@ std::vector<Onnx::DimSize> make_dimensions(const Ort::ConstTensorTypeAndShapeInf
return result;
}
-Onnx::TensorInfo make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) {
+Onnx::TensorInfo
+make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) {
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
auto element_type = tensor_info.GetElementType();
return Onnx::TensorInfo{vespalib::string(name.get()), make_dimensions(tensor_info), make_element_type(element_type)};
}
-Onnx::TensorType get_type_of(const Ort::Value &value) {
+Onnx::TensorType
+get_type_of(const Ort::Value &value) {
auto tensor_info = value.GetTensorTypeAndShapeInfo();
auto element_type = tensor_info.GetElementType();
auto shape = tensor_info.GetShape();
@@ -216,10 +215,11 @@ Onnx::TensorType get_type_of(const Ort::Value &value) {
throw Ort::Exception("[onnx wrapper] actual value has unknown dimension size", ORT_FAIL);
}
}
- return Onnx::TensorType(make_element_type(element_type), shape);
+ return {make_element_type(element_type), shape};
}
-std::vector<int64_t> extract_sizes(const ValueType &type) {
+std::vector<int64_t>
+extract_sizes(const ValueType &type) {
std::vector<int64_t> sizes;
for (const auto &dim: type.dimensions()) {
sizes.push_back(dim.size);
@@ -306,7 +306,7 @@ Onnx::WirePlanner::do_model_probe(const Onnx &model)
result_values.emplace_back(nullptr);
}
Ort::RunOptions run_opts(nullptr);
- Ort::Session &session = const_cast<Ort::Session&>(model._session);
+ auto &session = const_cast<Ort::Session&>(model._session);
session.Run(run_opts,
model._input_name_refs.data(), param_values.data(), param_values.size(),
model._output_name_refs.data(), result_values.data(), result_values.size());
@@ -554,7 +554,7 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
const auto &vespa = _wire_info.vespa_inputs[i];
const auto &onnx = _wire_info.onnx_inputs[i];
if (is_same_type(vespa.cell_type(), onnx.elements)) {
- _param_values.push_back(Ort::Value(nullptr));
+ _param_values.emplace_back(nullptr);
_param_binders.push_back(SelectAdaptParam()(vespa.cell_type()));
} else {
_param_values.push_back(CreateOnnxTensor()(onnx, _alloc));
@@ -587,7 +587,7 @@ Onnx::EvalContext::bind_param(size_t i, const Value &param)
void
Onnx::EvalContext::eval()
{
- Ort::Session &session = const_cast<Ort::Session&>(_model._session);
+ auto &session = const_cast<Ort::Session&>(_model._session);
Ort::RunOptions run_opts(nullptr);
session.Run(run_opts,
_model._input_name_refs.data(), _param_values.data(), _param_values.size(),
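
Two of the recurring patterns in this diff, marking the defaulted move operations noexcept and emplacing instead of pushing a temporary, are illustrated by the standalone sketch below. It is not part of the Vespa sources; the Tracked type is hypothetical and only traces which special member std::vector invokes when it reallocates.

#include <cstdio>
#include <vector>

struct Tracked {
    Tracked() = default;
    Tracked(const Tracked &) { std::puts("copy"); }
    // Without noexcept here, std::vector falls back to copying on reallocation
    // (std::move_if_noexcept), since a copy constructor is available.
    Tracked(Tracked &&) noexcept { std::puts("move"); }
    Tracked &operator=(const Tracked &) = default;
    Tracked &operator=(Tracked &&) noexcept = default;
};

int main() {
    std::vector<Tracked> v;
    v.reserve(1);
    v.emplace_back();   // constructed in place, like _param_values.emplace_back(nullptr)
    v.emplace_back();   // exceeds capacity: the existing element is moved ("move"), not copied
}

For a move-only type such as OnnxString, whose copy operations are deleted, std::vector has no copy fallback and moves regardless; the added noexcept there mainly documents the no-throw guarantee and satisfies checks such as clang-tidy's performance-noexcept-move-constructor. The commit itself does not state a motivation beyond "Unify and modernize code and layout".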