From 7dc8177525bb90a5782d2c2c84155ec3adbe7adc Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Mon, 7 Sep 2020 11:51:53 +0000 Subject: package trampolines with targets, change order to match header --- eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp | 84 +++++++++++------------ 1 file changed, 39 insertions(+), 45 deletions(-) (limited to 'eval') diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp index 28e35827a29..8f1b01a58ab 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp @@ -26,18 +26,6 @@ using vespalib::make_string_short::fmt; namespace vespalib::tensor { -struct Onnx::EvalContext::SelectAdaptParam { - template static auto invoke() { return adapt_param; } -}; - -struct Onnx::EvalContext::SelectConvertParam { - template static auto invoke() { return convert_param; } -}; - -struct Onnx::EvalContext::SelectConvertResult { - template static auto invoke() { return convert_result; } -}; - namespace { struct TypifyOnnxElementType { @@ -75,6 +63,9 @@ struct CreateOnnxTensor { template static Ort::Value invoke(const std::vector &sizes, OrtAllocator *alloc) { return Ort::Value::CreateTensor(alloc, sizes.data(), sizes.size()); } + Ort::Value operator()(const Onnx::TensorType &type, OrtAllocator *alloc) { + return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc); + } }; struct CreateVespaTensorRef { @@ -83,6 +74,9 @@ struct CreateVespaTensorRef { ConstArrayRef cells(value.GetTensorMutableData(), num_cells); return std::make_unique(type_ref, TypedCells(cells)); } + eval::Value::UP operator()(const eval::ValueType &type_ref, Ort::Value &value) { + return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value); + } }; struct CreateVespaTensor { @@ -91,6 +85,9 @@ struct CreateVespaTensor { std::vector cells(num_cells, T{}); return std::make_unique>(type, std::move(cells)); } + eval::Value::UP operator()(const eval::ValueType &type) { + return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type); + } }; //----------------------------------------------------------------------------- @@ -103,32 +100,6 @@ template bool is_same_type(E1 e1, E2 e2) { return typify_invoke<2,MyTypify,IsSameType>(e1, e2); } -Ort::Value create_onnx_tensor(const Onnx::TensorType &type, OrtAllocator *alloc) { - return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc); -} - -eval::Value::UP create_vespa_tensor_ref(const eval::ValueType &type_ref, Ort::Value &value) { - return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value); -} - -eval::Value::UP create_vespa_tensor(const eval::ValueType &type) { - return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type); -} - -//----------------------------------------------------------------------------- - -auto select_adapt_param(eval::ValueType::CellType ct) { - return typify_invoke<1,MyTypify,Onnx::EvalContext::SelectAdaptParam>(ct); -} - -auto select_convert_param(eval::ValueType::CellType ct, Onnx::ElementType et) { - return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertParam>(ct, et); -} - -auto select_convert_result(Onnx::ElementType et, eval::ValueType::CellType ct) { - return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertResult>(et, ct); -} - //----------------------------------------------------------------------------- auto convert_optimize(Onnx::Optimize optimize) { @@ -396,6 +367,29 @@ Onnx::EvalContext::convert_result(EvalContext &self, size_t idx) } } +struct Onnx::EvalContext::SelectAdaptParam { + template static auto invoke() { return adapt_param; } + auto operator()(eval::ValueType::CellType ct) { + return typify_invoke<1,MyTypify,SelectAdaptParam>(ct); + } +}; + +struct Onnx::EvalContext::SelectConvertParam { + template static auto invoke() { return convert_param; } + auto operator()(eval::ValueType::CellType ct, Onnx::ElementType et) { + return typify_invoke<2,MyTypify,SelectConvertParam>(ct, et); + } +}; + +struct Onnx::EvalContext::SelectConvertResult { + template static auto invoke() { return convert_result; } + auto operator()(Onnx::ElementType et, eval::ValueType::CellType ct) { + return typify_invoke<2,MyTypify,SelectConvertResult>(et, ct); + } +}; + +//----------------------------------------------------------------------------- + Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) : _model(model), _wire_info(wire_info), @@ -419,21 +413,21 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) const auto &onnx = _wire_info.onnx_inputs[i]; if (is_same_type(vespa.cell_type(), onnx.elements)) { _param_values.push_back(Ort::Value(nullptr)); - _param_binders.push_back(select_adapt_param(vespa.cell_type())); + _param_binders.push_back(SelectAdaptParam()(vespa.cell_type())); } else { - _param_values.push_back(create_onnx_tensor(onnx, _alloc)); - _param_binders.push_back(select_convert_param(vespa.cell_type(), onnx.elements)); + _param_values.push_back(CreateOnnxTensor()(onnx, _alloc)); + _param_binders.push_back(SelectConvertParam()(vespa.cell_type(), onnx.elements)); } } for (size_t i = 0; i < _model.outputs().size(); ++i) { const auto &vespa = _wire_info.vespa_outputs[i]; const auto &onnx = _wire_info.onnx_outputs[i]; - _result_values.push_back(create_onnx_tensor(onnx, _alloc)); + _result_values.push_back(CreateOnnxTensor()(onnx, _alloc)); if (is_same_type(vespa.cell_type(), onnx.elements)) { - _results.push_back(create_vespa_tensor_ref(vespa, _result_values.back())); + _results.push_back(CreateVespaTensorRef()(vespa, _result_values.back())); } else { - _results.push_back(create_vespa_tensor(vespa)); - _result_converters.emplace_back(i, select_convert_result(onnx.elements, vespa.cell_type())); + _results.push_back(CreateVespaTensor()(vespa)); + _result_converters.emplace_back(i, SelectConvertResult()(onnx.elements, vespa.cell_type())); } } // make sure references to Ort::Value inside _result_values are safe -- cgit v1.2.3