diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2020-09-07 14:11:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-07 14:11:01 +0200 |
commit | fa1486543abc712fa816d81be97212813a07faf6 (patch) | |
tree | 5c00ac341714db1baffa7f918f1e32fec5888ef1 /eval | |
parent | db9105023dc06c7ca56a2914735dba663bd21d5c (diff) | |
parent | 7dc8177525bb90a5782d2c2c84155ec3adbe7adc (diff) |
Merge pull request #14303 from vespa-engine/havardpe/use-functions-not-objects
use functions, not objects
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp | 178 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/onnx_wrapper.h | 32 |
2 files changed, 97 insertions, 113 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp index 7c29b20f2f4..8f1b01a58ab 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp @@ -26,9 +26,6 @@ using vespalib::make_string_short::fmt; namespace vespalib::tensor { -using ParamBinder = Onnx::EvalContext::ParamBinder; -using EvalHook = Onnx::EvalContext::EvalHook; - namespace { struct TypifyOnnxElementType { @@ -66,6 +63,9 @@ struct CreateOnnxTensor { template <typename T> static Ort::Value invoke(const std::vector<int64_t> &sizes, OrtAllocator *alloc) { return Ort::Value::CreateTensor<T>(alloc, sizes.data(), sizes.size()); } + Ort::Value operator()(const Onnx::TensorType &type, OrtAllocator *alloc) { + return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc); + } }; struct CreateVespaTensorRef { @@ -74,6 +74,9 @@ struct CreateVespaTensorRef { ConstArrayRef<T> cells(value.GetTensorMutableData<T>(), num_cells); return std::make_unique<DenseTensorView>(type_ref, TypedCells(cells)); } + eval::Value::UP operator()(const eval::ValueType &type_ref, Ort::Value &value) { + return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value); + } }; struct CreateVespaTensor { @@ -82,67 +85,8 @@ struct CreateVespaTensor { std::vector<T> cells(num_cells, T{}); return std::make_unique<DenseTensor<T>>(type, std::move(cells)); } -}; - -template <typename T> -struct ParamAdapter : ParamBinder { - const Onnx::TensorType &type; - const Ort::MemoryInfo &memory; - ParamAdapter(const Onnx::TensorType &type_in, const Ort::MemoryInfo &memory_in) - : type(type_in), memory(memory_in) {} - void bind(const eval::Value &vespa, Ort::Value &onnx) override { - const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef(); - auto cells = unconstify(cells_ref.typify<T>()); - onnx = Ort::Value::CreateTensor<T>(memory, cells.begin(), cells.size(), type.dimensions.data(), type.dimensions.size()); - } -}; - -struct CreateParamAdapter { - template <typename T> static ParamBinder::UP invoke(const Onnx::TensorType &type, const Ort::MemoryInfo &memory) { - return std::make_unique<ParamAdapter<T>>(type, memory); - } -}; - -template <typename SRC, typename DST> -struct ParamConverter : ParamBinder { - void bind(const eval::Value &vespa, Ort::Value &onnx) override { - auto cells = static_cast<const DenseTensorView &>(vespa).cellsRef().typify<SRC>(); - size_t n = cells.size(); - const SRC *src = cells.begin(); - DST *dst = onnx.GetTensorMutableData<DST>(); - for (size_t i = 0; i < n; ++i) { - dst[i] = DST(src[i]); - } - } -}; - -struct CreateParamConverter { - template <typename SRC, typename DST> static ParamBinder::UP invoke() { - return std::make_unique<ParamConverter<SRC,DST>>(); - } -}; - -template <typename SRC, typename DST> -struct ResultConverter : EvalHook { - Ort::Value &onnx; - const eval::Value &vespa; - ResultConverter(Ort::Value &onnx_in, const eval::Value &vespa_in) - : onnx(onnx_in), vespa(vespa_in) {} - void invoke() override { - const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef(); - auto cells = unconstify(cells_ref.typify<DST>()); - size_t n = cells.size(); - DST *dst = cells.begin(); - const SRC *src = onnx.GetTensorMutableData<SRC>(); - for (size_t i = 0; i < n; ++i) { - dst[i] = DST(src[i]); - } - } -}; - -struct CreateResultConverter { - template <typename SRC, typename DST> static EvalHook::UP invoke(Ort::Value &onnx, const eval::Value &vespa) { - return std::make_unique<ResultConverter<SRC,DST>>(onnx, vespa); + eval::Value::UP operator()(const eval::ValueType &type) { + return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type); } }; @@ -156,30 +100,6 @@ template <typename E1, typename E2> bool is_same_type(E1 e1, E2 e2) { return typify_invoke<2,MyTypify,IsSameType>(e1, e2); } -Ort::Value create_onnx_tensor(const Onnx::TensorType &type, OrtAllocator *alloc) { - return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc); -} - -eval::Value::UP create_vespa_tensor_ref(const eval::ValueType &type_ref, Ort::Value &value) { - return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value); -} - -eval::Value::UP create_vespa_tensor(const eval::ValueType &type) { - return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type); -} - -ParamBinder::UP create_param_adapter(eval::ValueType::CellType ct, const Onnx::TensorType &type, const Ort::MemoryInfo &memory) { - return typify_invoke<1,MyTypify,CreateParamAdapter>(ct, type, memory); -} - -ParamBinder::UP create_param_converter(eval::ValueType::CellType ct, Onnx::ElementType et) { - return typify_invoke<2,MyTypify,CreateParamConverter>(ct, et); -} - -EvalHook::UP create_result_converter(Onnx::ElementType et, Ort::Value &onnx, const eval::Value &vespa) { - return typify_invoke<2,MyTypify,CreateResultConverter>(et, vespa.type().cell_type(), onnx, vespa); -} - //----------------------------------------------------------------------------- auto convert_optimize(Onnx::Optimize optimize) { @@ -410,6 +330,66 @@ Onnx::WirePlanner::get_wire_info(const Onnx &model) const Ort::AllocatorWithDefaultOptions Onnx::EvalContext::_alloc; +template <typename T> +void +Onnx::EvalContext::adapt_param(EvalContext &self, size_t idx, const eval::Value ¶m) +{ + const auto &cells_ref = static_cast<const DenseTensorView &>(param).cellsRef(); + auto cells = unconstify(cells_ref.typify<T>()); + const auto &sizes = self._wire_info.onnx_inputs[idx].dimensions; + self._param_values[idx] = Ort::Value::CreateTensor<T>(self._cpu_memory, cells.begin(), cells.size(), sizes.data(), sizes.size()); +} + +template <typename SRC, typename DST> +void +Onnx::EvalContext::convert_param(EvalContext &self, size_t idx, const eval::Value ¶m) +{ + auto cells = static_cast<const DenseTensorView &>(param).cellsRef().typify<SRC>(); + size_t n = cells.size(); + const SRC *src = cells.begin(); + DST *dst = self._param_values[idx].GetTensorMutableData<DST>(); + for (size_t i = 0; i < n; ++i) { + dst[i] = DST(src[i]); + } +} + +template <typename SRC, typename DST> +void +Onnx::EvalContext::convert_result(EvalContext &self, size_t idx) +{ + const auto &cells_ref = static_cast<const DenseTensorView &>(*self._results[idx]).cellsRef(); + auto cells = unconstify(cells_ref.typify<DST>()); + size_t n = cells.size(); + DST *dst = cells.begin(); + const SRC *src = self._result_values[idx].GetTensorMutableData<SRC>(); + for (size_t i = 0; i < n; ++i) { + dst[i] = DST(src[i]); + } +} + +struct Onnx::EvalContext::SelectAdaptParam { + template <typename ...Ts> static auto invoke() { return adapt_param<Ts...>; } + auto operator()(eval::ValueType::CellType ct) { + return typify_invoke<1,MyTypify,SelectAdaptParam>(ct); + } +}; + +struct Onnx::EvalContext::SelectConvertParam { + template <typename ...Ts> static auto invoke() { return convert_param<Ts...>; } + auto operator()(eval::ValueType::CellType ct, Onnx::ElementType et) { + return typify_invoke<2,MyTypify,SelectConvertParam>(ct, et); + } +}; + +struct Onnx::EvalContext::SelectConvertResult { + template <typename ...Ts> static auto invoke() { return convert_result<Ts...>; } + auto operator()(Onnx::ElementType et, eval::ValueType::CellType ct) { + return typify_invoke<2,MyTypify,SelectConvertResult>(et, ct); + } +}; + +//----------------------------------------------------------------------------- + Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) : _model(model), _wire_info(wire_info), @@ -418,7 +398,7 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) _result_values(), _results(), _param_binders(), - _eval_hooks() + _result_converters() { assert(_wire_info.vespa_inputs.size() == _model.inputs().size()); assert(_wire_info.onnx_inputs.size() == _model.inputs().size()); @@ -433,21 +413,21 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) const auto &onnx = _wire_info.onnx_inputs[i]; if (is_same_type(vespa.cell_type(), onnx.elements)) { _param_values.push_back(Ort::Value(nullptr)); - _param_binders.push_back(create_param_adapter(vespa.cell_type(), onnx, _cpu_memory)); + _param_binders.push_back(SelectAdaptParam()(vespa.cell_type())); } else { - _param_values.push_back(create_onnx_tensor(onnx, _alloc)); - _param_binders.push_back(create_param_converter(vespa.cell_type(), onnx.elements)); + _param_values.push_back(CreateOnnxTensor()(onnx, _alloc)); + _param_binders.push_back(SelectConvertParam()(vespa.cell_type(), onnx.elements)); } } for (size_t i = 0; i < _model.outputs().size(); ++i) { const auto &vespa = _wire_info.vespa_outputs[i]; const auto &onnx = _wire_info.onnx_outputs[i]; - _result_values.push_back(create_onnx_tensor(onnx, _alloc)); + _result_values.push_back(CreateOnnxTensor()(onnx, _alloc)); if (is_same_type(vespa.cell_type(), onnx.elements)) { - _results.push_back(create_vespa_tensor_ref(vespa, _result_values.back())); + _results.push_back(CreateVespaTensorRef()(vespa, _result_values.back())); } else { - _results.push_back(create_vespa_tensor(vespa)); - _eval_hooks.push_back(create_result_converter(onnx.elements, _result_values.back(), *_results.back().get())); + _results.push_back(CreateVespaTensor()(vespa)); + _result_converters.emplace_back(i, SelectConvertResult()(onnx.elements, vespa.cell_type())); } } // make sure references to Ort::Value inside _result_values are safe @@ -459,7 +439,7 @@ Onnx::EvalContext::~EvalContext() = default; void Onnx::EvalContext::bind_param(size_t i, const eval::Value ¶m) { - _param_binders[i]->bind(param, _param_values[i]); + _param_binders[i](*this, i, param); } void @@ -470,8 +450,8 @@ Onnx::EvalContext::eval() session.Run(run_opts, _model._input_name_refs.data(), _param_values.data(), _param_values.size(), _model._output_name_refs.data(), _result_values.data(), _result_values.size()); - for (const auto &hook: _eval_hooks) { - hook->invoke(); + for (const auto &entry: _result_converters) { + entry.second(*this, entry.first); } } diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h index 4d2ef6ba50d..85a59824229 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h @@ -95,19 +95,10 @@ public: // all parameter values are expected to be bound per evaluation // output values are pre-allocated and will not change class EvalContext { - public: - struct ParamBinder { - using UP = std::unique_ptr<ParamBinder>; - virtual void bind(const eval::Value &vespa, Ort::Value &onnx) = 0; - virtual ~ParamBinder() {} - }; - struct EvalHook { - using UP = std::unique_ptr<EvalHook>; - virtual void invoke() = 0; - virtual ~EvalHook() {} - }; - private: + using param_fun_t = void (*)(EvalContext &, size_t i, const eval::Value &); + using result_fun_t = void (*)(EvalContext &, size_t i); + static Ort::AllocatorWithDefaultOptions _alloc; const Onnx &_model; @@ -116,10 +107,23 @@ public: std::vector<Ort::Value> _param_values; std::vector<Ort::Value> _result_values; std::vector<eval::Value::UP> _results; - std::vector<ParamBinder::UP> _param_binders; - std::vector<EvalHook::UP> _eval_hooks; + std::vector<param_fun_t> _param_binders; + std::vector<std::pair<size_t,result_fun_t>> _result_converters; + + template <typename T> + static void adapt_param(EvalContext &self, size_t idx, const eval::Value ¶m); + + template <typename SRC, typename DST> + static void convert_param(EvalContext &self, size_t idx, const eval::Value ¶m); + + template <typename SRC, typename DST> + static void convert_result(EvalContext &self, size_t idx); public: + struct SelectAdaptParam; + struct SelectConvertParam; + struct SelectConvertResult; + EvalContext(const Onnx &model, const WireInfo &wire_info); ~EvalContext(); size_t num_params() const { return _param_values.size(); } |