diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-09-06 13:51:59 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-09-06 15:14:36 +0000 |
commit | ce319c6c3dffd5827bf9450b87629270f86e8511 (patch) | |
tree | b6a578dff0ad2e8c1aa44d13539fe1cc3cda1c1d /eval | |
parent | 466bc196eee4571e2197624f17b8a7d8aee38cf0 (diff) |
use functions, not objects
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp | 140 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/onnx_wrapper.h | 32 |
2 files changed, 81 insertions, 91 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp index 7c29b20f2f4..28e35827a29 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp @@ -26,8 +26,17 @@ using vespalib::make_string_short::fmt; namespace vespalib::tensor { -using ParamBinder = Onnx::EvalContext::ParamBinder; -using EvalHook = Onnx::EvalContext::EvalHook; +struct Onnx::EvalContext::SelectAdaptParam { + template <typename ...Ts> static auto invoke() { return adapt_param<Ts...>; } +}; + +struct Onnx::EvalContext::SelectConvertParam { + template <typename ...Ts> static auto invoke() { return convert_param<Ts...>; } +}; + +struct Onnx::EvalContext::SelectConvertResult { + template <typename ...Ts> static auto invoke() { return convert_result<Ts...>; } +}; namespace { @@ -84,68 +93,6 @@ struct CreateVespaTensor { } }; -template <typename T> -struct ParamAdapter : ParamBinder { - const Onnx::TensorType &type; - const Ort::MemoryInfo &memory; - ParamAdapter(const Onnx::TensorType &type_in, const Ort::MemoryInfo &memory_in) - : type(type_in), memory(memory_in) {} - void bind(const eval::Value &vespa, Ort::Value &onnx) override { - const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef(); - auto cells = unconstify(cells_ref.typify<T>()); - onnx = Ort::Value::CreateTensor<T>(memory, cells.begin(), cells.size(), type.dimensions.data(), type.dimensions.size()); - } -}; - -struct CreateParamAdapter { - template <typename T> static ParamBinder::UP invoke(const Onnx::TensorType &type, const Ort::MemoryInfo &memory) { - return std::make_unique<ParamAdapter<T>>(type, memory); - } -}; - -template <typename SRC, typename DST> -struct ParamConverter : ParamBinder { - void bind(const eval::Value &vespa, Ort::Value &onnx) override { - auto cells = static_cast<const DenseTensorView &>(vespa).cellsRef().typify<SRC>(); - size_t n = cells.size(); - const SRC *src = cells.begin(); - DST *dst = onnx.GetTensorMutableData<DST>(); - for (size_t i = 0; i < n; ++i) { - dst[i] = DST(src[i]); - } - } -}; - -struct CreateParamConverter { - template <typename SRC, typename DST> static ParamBinder::UP invoke() { - return std::make_unique<ParamConverter<SRC,DST>>(); - } -}; - -template <typename SRC, typename DST> -struct ResultConverter : EvalHook { - Ort::Value &onnx; - const eval::Value &vespa; - ResultConverter(Ort::Value &onnx_in, const eval::Value &vespa_in) - : onnx(onnx_in), vespa(vespa_in) {} - void invoke() override { - const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef(); - auto cells = unconstify(cells_ref.typify<DST>()); - size_t n = cells.size(); - DST *dst = cells.begin(); - const SRC *src = onnx.GetTensorMutableData<SRC>(); - for (size_t i = 0; i < n; ++i) { - dst[i] = DST(src[i]); - } - } -}; - -struct CreateResultConverter { - template <typename SRC, typename DST> static EvalHook::UP invoke(Ort::Value &onnx, const eval::Value &vespa) { - return std::make_unique<ResultConverter<SRC,DST>>(onnx, vespa); - } -}; - //----------------------------------------------------------------------------- template <typename E> vespalib::string type_name(E enum_value) { @@ -168,16 +115,18 @@ eval::Value::UP create_vespa_tensor(const eval::ValueType &type) { return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type); } -ParamBinder::UP create_param_adapter(eval::ValueType::CellType ct, const Onnx::TensorType &type, const Ort::MemoryInfo &memory) { - return typify_invoke<1,MyTypify,CreateParamAdapter>(ct, type, memory); +//----------------------------------------------------------------------------- + +auto select_adapt_param(eval::ValueType::CellType ct) { + return typify_invoke<1,MyTypify,Onnx::EvalContext::SelectAdaptParam>(ct); } -ParamBinder::UP create_param_converter(eval::ValueType::CellType ct, Onnx::ElementType et) { - return typify_invoke<2,MyTypify,CreateParamConverter>(ct, et); +auto select_convert_param(eval::ValueType::CellType ct, Onnx::ElementType et) { + return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertParam>(ct, et); } -EvalHook::UP create_result_converter(Onnx::ElementType et, Ort::Value &onnx, const eval::Value &vespa) { - return typify_invoke<2,MyTypify,CreateResultConverter>(et, vespa.type().cell_type(), onnx, vespa); +auto select_convert_result(Onnx::ElementType et, eval::ValueType::CellType ct) { + return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertResult>(et, ct); } //----------------------------------------------------------------------------- @@ -410,6 +359,43 @@ Onnx::WirePlanner::get_wire_info(const Onnx &model) const Ort::AllocatorWithDefaultOptions Onnx::EvalContext::_alloc; +template <typename T> +void +Onnx::EvalContext::adapt_param(EvalContext &self, size_t idx, const eval::Value ¶m) +{ + const auto &cells_ref = static_cast<const DenseTensorView &>(param).cellsRef(); + auto cells = unconstify(cells_ref.typify<T>()); + const auto &sizes = self._wire_info.onnx_inputs[idx].dimensions; + self._param_values[idx] = Ort::Value::CreateTensor<T>(self._cpu_memory, cells.begin(), cells.size(), sizes.data(), sizes.size()); +} + +template <typename SRC, typename DST> +void +Onnx::EvalContext::convert_param(EvalContext &self, size_t idx, const eval::Value ¶m) +{ + auto cells = static_cast<const DenseTensorView &>(param).cellsRef().typify<SRC>(); + size_t n = cells.size(); + const SRC *src = cells.begin(); + DST *dst = self._param_values[idx].GetTensorMutableData<DST>(); + for (size_t i = 0; i < n; ++i) { + dst[i] = DST(src[i]); + } +} + +template <typename SRC, typename DST> +void +Onnx::EvalContext::convert_result(EvalContext &self, size_t idx) +{ + const auto &cells_ref = static_cast<const DenseTensorView &>(*self._results[idx]).cellsRef(); + auto cells = unconstify(cells_ref.typify<DST>()); + size_t n = cells.size(); + DST *dst = cells.begin(); + const SRC *src = self._result_values[idx].GetTensorMutableData<SRC>(); + for (size_t i = 0; i < n; ++i) { + dst[i] = DST(src[i]); + } +} + Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) : _model(model), _wire_info(wire_info), @@ -418,7 +404,7 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) _result_values(), _results(), _param_binders(), - _eval_hooks() + _result_converters() { assert(_wire_info.vespa_inputs.size() == _model.inputs().size()); assert(_wire_info.onnx_inputs.size() == _model.inputs().size()); @@ -433,10 +419,10 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) const auto &onnx = _wire_info.onnx_inputs[i]; if (is_same_type(vespa.cell_type(), onnx.elements)) { _param_values.push_back(Ort::Value(nullptr)); - _param_binders.push_back(create_param_adapter(vespa.cell_type(), onnx, _cpu_memory)); + _param_binders.push_back(select_adapt_param(vespa.cell_type())); } else { _param_values.push_back(create_onnx_tensor(onnx, _alloc)); - _param_binders.push_back(create_param_converter(vespa.cell_type(), onnx.elements)); + _param_binders.push_back(select_convert_param(vespa.cell_type(), onnx.elements)); } } for (size_t i = 0; i < _model.outputs().size(); ++i) { @@ -447,7 +433,7 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) _results.push_back(create_vespa_tensor_ref(vespa, _result_values.back())); } else { _results.push_back(create_vespa_tensor(vespa)); - _eval_hooks.push_back(create_result_converter(onnx.elements, _result_values.back(), *_results.back().get())); + _result_converters.emplace_back(i, select_convert_result(onnx.elements, vespa.cell_type())); } } // make sure references to Ort::Value inside _result_values are safe @@ -459,7 +445,7 @@ Onnx::EvalContext::~EvalContext() = default; void Onnx::EvalContext::bind_param(size_t i, const eval::Value ¶m) { - _param_binders[i]->bind(param, _param_values[i]); + _param_binders[i](*this, i, param); } void @@ -470,8 +456,8 @@ Onnx::EvalContext::eval() session.Run(run_opts, _model._input_name_refs.data(), _param_values.data(), _param_values.size(), _model._output_name_refs.data(), _result_values.data(), _result_values.size()); - for (const auto &hook: _eval_hooks) { - hook->invoke(); + for (const auto &entry: _result_converters) { + entry.second(*this, entry.first); } } diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h index 4d2ef6ba50d..85a59824229 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h @@ -95,19 +95,10 @@ public: // all parameter values are expected to be bound per evaluation // output values are pre-allocated and will not change class EvalContext { - public: - struct ParamBinder { - using UP = std::unique_ptr<ParamBinder>; - virtual void bind(const eval::Value &vespa, Ort::Value &onnx) = 0; - virtual ~ParamBinder() {} - }; - struct EvalHook { - using UP = std::unique_ptr<EvalHook>; - virtual void invoke() = 0; - virtual ~EvalHook() {} - }; - private: + using param_fun_t = void (*)(EvalContext &, size_t i, const eval::Value &); + using result_fun_t = void (*)(EvalContext &, size_t i); + static Ort::AllocatorWithDefaultOptions _alloc; const Onnx &_model; @@ -116,10 +107,23 @@ public: std::vector<Ort::Value> _param_values; std::vector<Ort::Value> _result_values; std::vector<eval::Value::UP> _results; - std::vector<ParamBinder::UP> _param_binders; - std::vector<EvalHook::UP> _eval_hooks; + std::vector<param_fun_t> _param_binders; + std::vector<std::pair<size_t,result_fun_t>> _result_converters; + + template <typename T> + static void adapt_param(EvalContext &self, size_t idx, const eval::Value ¶m); + + template <typename SRC, typename DST> + static void convert_param(EvalContext &self, size_t idx, const eval::Value ¶m); + + template <typename SRC, typename DST> + static void convert_result(EvalContext &self, size_t idx); public: + struct SelectAdaptParam; + struct SelectConvertParam; + struct SelectConvertResult; + EvalContext(const Onnx &model, const WireInfo &wire_info); ~EvalContext(); size_t num_params() const { return _param_values.size(); } |