summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-09-06 13:51:59 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-09-06 15:14:36 +0000
commitce319c6c3dffd5827bf9450b87629270f86e8511 (patch)
treeb6a578dff0ad2e8c1aa44d13539fe1cc3cda1c1d /eval
parent466bc196eee4571e2197624f17b8a7d8aee38cf0 (diff)
use functions, not objects
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp140
-rw-r--r--eval/src/vespa/eval/tensor/dense/onnx_wrapper.h32
2 files changed, 81 insertions, 91 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
index 7c29b20f2f4..28e35827a29 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
@@ -26,8 +26,17 @@ using vespalib::make_string_short::fmt;
namespace vespalib::tensor {
-using ParamBinder = Onnx::EvalContext::ParamBinder;
-using EvalHook = Onnx::EvalContext::EvalHook;
+struct Onnx::EvalContext::SelectAdaptParam {
+ template <typename ...Ts> static auto invoke() { return adapt_param<Ts...>; }
+};
+
+struct Onnx::EvalContext::SelectConvertParam {
+ template <typename ...Ts> static auto invoke() { return convert_param<Ts...>; }
+};
+
+struct Onnx::EvalContext::SelectConvertResult {
+ template <typename ...Ts> static auto invoke() { return convert_result<Ts...>; }
+};
namespace {
@@ -84,68 +93,6 @@ struct CreateVespaTensor {
}
};
-template <typename T>
-struct ParamAdapter : ParamBinder {
- const Onnx::TensorType &type;
- const Ort::MemoryInfo &memory;
- ParamAdapter(const Onnx::TensorType &type_in, const Ort::MemoryInfo &memory_in)
- : type(type_in), memory(memory_in) {}
- void bind(const eval::Value &vespa, Ort::Value &onnx) override {
- const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef();
- auto cells = unconstify(cells_ref.typify<T>());
- onnx = Ort::Value::CreateTensor<T>(memory, cells.begin(), cells.size(), type.dimensions.data(), type.dimensions.size());
- }
-};
-
-struct CreateParamAdapter {
- template <typename T> static ParamBinder::UP invoke(const Onnx::TensorType &type, const Ort::MemoryInfo &memory) {
- return std::make_unique<ParamAdapter<T>>(type, memory);
- }
-};
-
-template <typename SRC, typename DST>
-struct ParamConverter : ParamBinder {
- void bind(const eval::Value &vespa, Ort::Value &onnx) override {
- auto cells = static_cast<const DenseTensorView &>(vespa).cellsRef().typify<SRC>();
- size_t n = cells.size();
- const SRC *src = cells.begin();
- DST *dst = onnx.GetTensorMutableData<DST>();
- for (size_t i = 0; i < n; ++i) {
- dst[i] = DST(src[i]);
- }
- }
-};
-
-struct CreateParamConverter {
- template <typename SRC, typename DST> static ParamBinder::UP invoke() {
- return std::make_unique<ParamConverter<SRC,DST>>();
- }
-};
-
-template <typename SRC, typename DST>
-struct ResultConverter : EvalHook {
- Ort::Value &onnx;
- const eval::Value &vespa;
- ResultConverter(Ort::Value &onnx_in, const eval::Value &vespa_in)
- : onnx(onnx_in), vespa(vespa_in) {}
- void invoke() override {
- const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef();
- auto cells = unconstify(cells_ref.typify<DST>());
- size_t n = cells.size();
- DST *dst = cells.begin();
- const SRC *src = onnx.GetTensorMutableData<SRC>();
- for (size_t i = 0; i < n; ++i) {
- dst[i] = DST(src[i]);
- }
- }
-};
-
-struct CreateResultConverter {
- template <typename SRC, typename DST> static EvalHook::UP invoke(Ort::Value &onnx, const eval::Value &vespa) {
- return std::make_unique<ResultConverter<SRC,DST>>(onnx, vespa);
- }
-};
-
//-----------------------------------------------------------------------------
template <typename E> vespalib::string type_name(E enum_value) {
@@ -168,16 +115,18 @@ eval::Value::UP create_vespa_tensor(const eval::ValueType &type) {
return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type);
}
-ParamBinder::UP create_param_adapter(eval::ValueType::CellType ct, const Onnx::TensorType &type, const Ort::MemoryInfo &memory) {
- return typify_invoke<1,MyTypify,CreateParamAdapter>(ct, type, memory);
+//-----------------------------------------------------------------------------
+
+auto select_adapt_param(eval::ValueType::CellType ct) {
+ return typify_invoke<1,MyTypify,Onnx::EvalContext::SelectAdaptParam>(ct);
}
-ParamBinder::UP create_param_converter(eval::ValueType::CellType ct, Onnx::ElementType et) {
- return typify_invoke<2,MyTypify,CreateParamConverter>(ct, et);
+auto select_convert_param(eval::ValueType::CellType ct, Onnx::ElementType et) {
+ return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertParam>(ct, et);
}
-EvalHook::UP create_result_converter(Onnx::ElementType et, Ort::Value &onnx, const eval::Value &vespa) {
- return typify_invoke<2,MyTypify,CreateResultConverter>(et, vespa.type().cell_type(), onnx, vespa);
+auto select_convert_result(Onnx::ElementType et, eval::ValueType::CellType ct) {
+ return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertResult>(et, ct);
}
//-----------------------------------------------------------------------------
@@ -410,6 +359,43 @@ Onnx::WirePlanner::get_wire_info(const Onnx &model) const
Ort::AllocatorWithDefaultOptions Onnx::EvalContext::_alloc;
+template <typename T>
+void
+Onnx::EvalContext::adapt_param(EvalContext &self, size_t idx, const eval::Value &param)
+{
+ const auto &cells_ref = static_cast<const DenseTensorView &>(param).cellsRef();
+ auto cells = unconstify(cells_ref.typify<T>());
+ const auto &sizes = self._wire_info.onnx_inputs[idx].dimensions;
+ self._param_values[idx] = Ort::Value::CreateTensor<T>(self._cpu_memory, cells.begin(), cells.size(), sizes.data(), sizes.size());
+}
+
+template <typename SRC, typename DST>
+void
+Onnx::EvalContext::convert_param(EvalContext &self, size_t idx, const eval::Value &param)
+{
+ auto cells = static_cast<const DenseTensorView &>(param).cellsRef().typify<SRC>();
+ size_t n = cells.size();
+ const SRC *src = cells.begin();
+ DST *dst = self._param_values[idx].GetTensorMutableData<DST>();
+ for (size_t i = 0; i < n; ++i) {
+ dst[i] = DST(src[i]);
+ }
+}
+
+template <typename SRC, typename DST>
+void
+Onnx::EvalContext::convert_result(EvalContext &self, size_t idx)
+{
+ const auto &cells_ref = static_cast<const DenseTensorView &>(*self._results[idx]).cellsRef();
+ auto cells = unconstify(cells_ref.typify<DST>());
+ size_t n = cells.size();
+ DST *dst = cells.begin();
+ const SRC *src = self._result_values[idx].GetTensorMutableData<SRC>();
+ for (size_t i = 0; i < n; ++i) {
+ dst[i] = DST(src[i]);
+ }
+}
+
Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
: _model(model),
_wire_info(wire_info),
@@ -418,7 +404,7 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
_result_values(),
_results(),
_param_binders(),
- _eval_hooks()
+ _result_converters()
{
assert(_wire_info.vespa_inputs.size() == _model.inputs().size());
assert(_wire_info.onnx_inputs.size() == _model.inputs().size());
@@ -433,10 +419,10 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
const auto &onnx = _wire_info.onnx_inputs[i];
if (is_same_type(vespa.cell_type(), onnx.elements)) {
_param_values.push_back(Ort::Value(nullptr));
- _param_binders.push_back(create_param_adapter(vespa.cell_type(), onnx, _cpu_memory));
+ _param_binders.push_back(select_adapt_param(vespa.cell_type()));
} else {
_param_values.push_back(create_onnx_tensor(onnx, _alloc));
- _param_binders.push_back(create_param_converter(vespa.cell_type(), onnx.elements));
+ _param_binders.push_back(select_convert_param(vespa.cell_type(), onnx.elements));
}
}
for (size_t i = 0; i < _model.outputs().size(); ++i) {
@@ -447,7 +433,7 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
_results.push_back(create_vespa_tensor_ref(vespa, _result_values.back()));
} else {
_results.push_back(create_vespa_tensor(vespa));
- _eval_hooks.push_back(create_result_converter(onnx.elements, _result_values.back(), *_results.back().get()));
+ _result_converters.emplace_back(i, select_convert_result(onnx.elements, vespa.cell_type()));
}
}
// make sure references to Ort::Value inside _result_values are safe
@@ -459,7 +445,7 @@ Onnx::EvalContext::~EvalContext() = default;
void
Onnx::EvalContext::bind_param(size_t i, const eval::Value &param)
{
- _param_binders[i]->bind(param, _param_values[i]);
+ _param_binders[i](*this, i, param);
}
void
@@ -470,8 +456,8 @@ Onnx::EvalContext::eval()
session.Run(run_opts,
_model._input_name_refs.data(), _param_values.data(), _param_values.size(),
_model._output_name_refs.data(), _result_values.data(), _result_values.size());
- for (const auto &hook: _eval_hooks) {
- hook->invoke();
+ for (const auto &entry: _result_converters) {
+ entry.second(*this, entry.first);
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
index 4d2ef6ba50d..85a59824229 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
@@ -95,19 +95,10 @@ public:
// all parameter values are expected to be bound per evaluation
// output values are pre-allocated and will not change
class EvalContext {
- public:
- struct ParamBinder {
- using UP = std::unique_ptr<ParamBinder>;
- virtual void bind(const eval::Value &vespa, Ort::Value &onnx) = 0;
- virtual ~ParamBinder() {}
- };
- struct EvalHook {
- using UP = std::unique_ptr<EvalHook>;
- virtual void invoke() = 0;
- virtual ~EvalHook() {}
- };
-
private:
+ using param_fun_t = void (*)(EvalContext &, size_t i, const eval::Value &);
+ using result_fun_t = void (*)(EvalContext &, size_t i);
+
static Ort::AllocatorWithDefaultOptions _alloc;
const Onnx &_model;
@@ -116,10 +107,23 @@ public:
std::vector<Ort::Value> _param_values;
std::vector<Ort::Value> _result_values;
std::vector<eval::Value::UP> _results;
- std::vector<ParamBinder::UP> _param_binders;
- std::vector<EvalHook::UP> _eval_hooks;
+ std::vector<param_fun_t> _param_binders;
+ std::vector<std::pair<size_t,result_fun_t>> _result_converters;
+
+ template <typename T>
+ static void adapt_param(EvalContext &self, size_t idx, const eval::Value &param);
+
+ template <typename SRC, typename DST>
+ static void convert_param(EvalContext &self, size_t idx, const eval::Value &param);
+
+ template <typename SRC, typename DST>
+ static void convert_result(EvalContext &self, size_t idx);
public:
+ struct SelectAdaptParam;
+ struct SelectConvertParam;
+ struct SelectConvertResult;
+
EvalContext(const Onnx &model, const WireInfo &wire_info);
~EvalContext();
size_t num_params() const { return _param_values.size(); }