diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-09-07 11:51:53 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-09-07 11:51:53 +0000 |
commit | 7dc8177525bb90a5782d2c2c84155ec3adbe7adc (patch) | |
tree | c52bf7752176677e79c35787250926b4901e198b /eval | |
parent | ce319c6c3dffd5827bf9450b87629270f86e8511 (diff) |
package trampolines with targets, change order to match header
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp | 84 |
1 files changed, 39 insertions, 45 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp index 28e35827a29..8f1b01a58ab 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp @@ -26,18 +26,6 @@ using vespalib::make_string_short::fmt; namespace vespalib::tensor { -struct Onnx::EvalContext::SelectAdaptParam { - template <typename ...Ts> static auto invoke() { return adapt_param<Ts...>; } -}; - -struct Onnx::EvalContext::SelectConvertParam { - template <typename ...Ts> static auto invoke() { return convert_param<Ts...>; } -}; - -struct Onnx::EvalContext::SelectConvertResult { - template <typename ...Ts> static auto invoke() { return convert_result<Ts...>; } -}; - namespace { struct TypifyOnnxElementType { @@ -75,6 +63,9 @@ struct CreateOnnxTensor { template <typename T> static Ort::Value invoke(const std::vector<int64_t> &sizes, OrtAllocator *alloc) { return Ort::Value::CreateTensor<T>(alloc, sizes.data(), sizes.size()); } + Ort::Value operator()(const Onnx::TensorType &type, OrtAllocator *alloc) { + return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc); + } }; struct CreateVespaTensorRef { @@ -83,6 +74,9 @@ struct CreateVespaTensorRef { ConstArrayRef<T> cells(value.GetTensorMutableData<T>(), num_cells); return std::make_unique<DenseTensorView>(type_ref, TypedCells(cells)); } + eval::Value::UP operator()(const eval::ValueType &type_ref, Ort::Value &value) { + return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value); + } }; struct CreateVespaTensor { @@ -91,6 +85,9 @@ struct CreateVespaTensor { std::vector<T> cells(num_cells, T{}); return std::make_unique<DenseTensor<T>>(type, std::move(cells)); } + eval::Value::UP operator()(const eval::ValueType &type) { + return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type); + } }; //----------------------------------------------------------------------------- @@ -103,32 +100,6 @@ template <typename E1, typename E2> bool is_same_type(E1 e1, E2 e2) { return typify_invoke<2,MyTypify,IsSameType>(e1, e2); } -Ort::Value create_onnx_tensor(const Onnx::TensorType &type, OrtAllocator *alloc) { - return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc); -} - -eval::Value::UP create_vespa_tensor_ref(const eval::ValueType &type_ref, Ort::Value &value) { - return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value); -} - -eval::Value::UP create_vespa_tensor(const eval::ValueType &type) { - return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type); -} - -//----------------------------------------------------------------------------- - -auto select_adapt_param(eval::ValueType::CellType ct) { - return typify_invoke<1,MyTypify,Onnx::EvalContext::SelectAdaptParam>(ct); -} - -auto select_convert_param(eval::ValueType::CellType ct, Onnx::ElementType et) { - return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertParam>(ct, et); -} - -auto select_convert_result(Onnx::ElementType et, eval::ValueType::CellType ct) { - return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertResult>(et, ct); -} - //----------------------------------------------------------------------------- auto convert_optimize(Onnx::Optimize optimize) { @@ -396,6 +367,29 @@ Onnx::EvalContext::convert_result(EvalContext &self, size_t idx) } } +struct Onnx::EvalContext::SelectAdaptParam { + template <typename ...Ts> static auto invoke() { return adapt_param<Ts...>; } + auto operator()(eval::ValueType::CellType ct) { + return typify_invoke<1,MyTypify,SelectAdaptParam>(ct); + } +}; + +struct Onnx::EvalContext::SelectConvertParam { + template <typename ...Ts> static auto invoke() { return convert_param<Ts...>; } + auto operator()(eval::ValueType::CellType ct, Onnx::ElementType et) { + return typify_invoke<2,MyTypify,SelectConvertParam>(ct, et); + } +}; + +struct Onnx::EvalContext::SelectConvertResult { + template <typename ...Ts> static auto invoke() { return convert_result<Ts...>; } + auto operator()(Onnx::ElementType et, eval::ValueType::CellType ct) { + return typify_invoke<2,MyTypify,SelectConvertResult>(et, ct); + } +}; + +//----------------------------------------------------------------------------- + Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) : _model(model), _wire_info(wire_info), @@ -419,21 +413,21 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) const auto &onnx = _wire_info.onnx_inputs[i]; if (is_same_type(vespa.cell_type(), onnx.elements)) { _param_values.push_back(Ort::Value(nullptr)); - _param_binders.push_back(select_adapt_param(vespa.cell_type())); + _param_binders.push_back(SelectAdaptParam()(vespa.cell_type())); } else { - _param_values.push_back(create_onnx_tensor(onnx, _alloc)); - _param_binders.push_back(select_convert_param(vespa.cell_type(), onnx.elements)); + _param_values.push_back(CreateOnnxTensor()(onnx, _alloc)); + _param_binders.push_back(SelectConvertParam()(vespa.cell_type(), onnx.elements)); } } for (size_t i = 0; i < _model.outputs().size(); ++i) { const auto &vespa = _wire_info.vespa_outputs[i]; const auto &onnx = _wire_info.onnx_outputs[i]; - _result_values.push_back(create_onnx_tensor(onnx, _alloc)); + _result_values.push_back(CreateOnnxTensor()(onnx, _alloc)); if (is_same_type(vespa.cell_type(), onnx.elements)) { - _results.push_back(create_vespa_tensor_ref(vespa, _result_values.back())); + _results.push_back(CreateVespaTensorRef()(vespa, _result_values.back())); } else { - _results.push_back(create_vespa_tensor(vespa)); - _result_converters.emplace_back(i, select_convert_result(onnx.elements, vespa.cell_type())); + _results.push_back(CreateVespaTensor()(vespa)); + _result_converters.emplace_back(i, SelectConvertResult()(onnx.elements, vespa.cell_type())); } } // make sure references to Ort::Value inside _result_values are safe |