summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-09-07 11:51:53 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-09-07 11:51:53 +0000
commit7dc8177525bb90a5782d2c2c84155ec3adbe7adc (patch)
treec52bf7752176677e79c35787250926b4901e198b /eval
parentce319c6c3dffd5827bf9450b87629270f86e8511 (diff)
package trampolines with targets, change order to match header
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp84
1 files changed, 39 insertions, 45 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
index 28e35827a29..8f1b01a58ab 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
@@ -26,18 +26,6 @@ using vespalib::make_string_short::fmt;
namespace vespalib::tensor {
-struct Onnx::EvalContext::SelectAdaptParam {
- template <typename ...Ts> static auto invoke() { return adapt_param<Ts...>; }
-};
-
-struct Onnx::EvalContext::SelectConvertParam {
- template <typename ...Ts> static auto invoke() { return convert_param<Ts...>; }
-};
-
-struct Onnx::EvalContext::SelectConvertResult {
- template <typename ...Ts> static auto invoke() { return convert_result<Ts...>; }
-};
-
namespace {
struct TypifyOnnxElementType {
@@ -75,6 +63,9 @@ struct CreateOnnxTensor {
template <typename T> static Ort::Value invoke(const std::vector<int64_t> &sizes, OrtAllocator *alloc) {
return Ort::Value::CreateTensor<T>(alloc, sizes.data(), sizes.size());
}
+ Ort::Value operator()(const Onnx::TensorType &type, OrtAllocator *alloc) {
+ return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc);
+ }
};
struct CreateVespaTensorRef {
@@ -83,6 +74,9 @@ struct CreateVespaTensorRef {
ConstArrayRef<T> cells(value.GetTensorMutableData<T>(), num_cells);
return std::make_unique<DenseTensorView>(type_ref, TypedCells(cells));
}
+ eval::Value::UP operator()(const eval::ValueType &type_ref, Ort::Value &value) {
+ return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value);
+ }
};
struct CreateVespaTensor {
@@ -91,6 +85,9 @@ struct CreateVespaTensor {
std::vector<T> cells(num_cells, T{});
return std::make_unique<DenseTensor<T>>(type, std::move(cells));
}
+ eval::Value::UP operator()(const eval::ValueType &type) {
+ return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type);
+ }
};
//-----------------------------------------------------------------------------
@@ -103,32 +100,6 @@ template <typename E1, typename E2> bool is_same_type(E1 e1, E2 e2) {
return typify_invoke<2,MyTypify,IsSameType>(e1, e2);
}
-Ort::Value create_onnx_tensor(const Onnx::TensorType &type, OrtAllocator *alloc) {
- return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc);
-}
-
-eval::Value::UP create_vespa_tensor_ref(const eval::ValueType &type_ref, Ort::Value &value) {
- return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value);
-}
-
-eval::Value::UP create_vespa_tensor(const eval::ValueType &type) {
- return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type);
-}
-
-//-----------------------------------------------------------------------------
-
-auto select_adapt_param(eval::ValueType::CellType ct) {
- return typify_invoke<1,MyTypify,Onnx::EvalContext::SelectAdaptParam>(ct);
-}
-
-auto select_convert_param(eval::ValueType::CellType ct, Onnx::ElementType et) {
- return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertParam>(ct, et);
-}
-
-auto select_convert_result(Onnx::ElementType et, eval::ValueType::CellType ct) {
- return typify_invoke<2,MyTypify,Onnx::EvalContext::SelectConvertResult>(et, ct);
-}
-
//-----------------------------------------------------------------------------
auto convert_optimize(Onnx::Optimize optimize) {
@@ -396,6 +367,29 @@ Onnx::EvalContext::convert_result(EvalContext &self, size_t idx)
}
}
+struct Onnx::EvalContext::SelectAdaptParam {
+ template <typename ...Ts> static auto invoke() { return adapt_param<Ts...>; }
+ auto operator()(eval::ValueType::CellType ct) {
+ return typify_invoke<1,MyTypify,SelectAdaptParam>(ct);
+ }
+};
+
+struct Onnx::EvalContext::SelectConvertParam {
+ template <typename ...Ts> static auto invoke() { return convert_param<Ts...>; }
+ auto operator()(eval::ValueType::CellType ct, Onnx::ElementType et) {
+ return typify_invoke<2,MyTypify,SelectConvertParam>(ct, et);
+ }
+};
+
+struct Onnx::EvalContext::SelectConvertResult {
+ template <typename ...Ts> static auto invoke() { return convert_result<Ts...>; }
+ auto operator()(Onnx::ElementType et, eval::ValueType::CellType ct) {
+ return typify_invoke<2,MyTypify,SelectConvertResult>(et, ct);
+ }
+};
+
+//-----------------------------------------------------------------------------
+
Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
: _model(model),
_wire_info(wire_info),
@@ -419,21 +413,21 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
const auto &onnx = _wire_info.onnx_inputs[i];
if (is_same_type(vespa.cell_type(), onnx.elements)) {
_param_values.push_back(Ort::Value(nullptr));
- _param_binders.push_back(select_adapt_param(vespa.cell_type()));
+ _param_binders.push_back(SelectAdaptParam()(vespa.cell_type()));
} else {
- _param_values.push_back(create_onnx_tensor(onnx, _alloc));
- _param_binders.push_back(select_convert_param(vespa.cell_type(), onnx.elements));
+ _param_values.push_back(CreateOnnxTensor()(onnx, _alloc));
+ _param_binders.push_back(SelectConvertParam()(vespa.cell_type(), onnx.elements));
}
}
for (size_t i = 0; i < _model.outputs().size(); ++i) {
const auto &vespa = _wire_info.vespa_outputs[i];
const auto &onnx = _wire_info.onnx_outputs[i];
- _result_values.push_back(create_onnx_tensor(onnx, _alloc));
+ _result_values.push_back(CreateOnnxTensor()(onnx, _alloc));
if (is_same_type(vespa.cell_type(), onnx.elements)) {
- _results.push_back(create_vespa_tensor_ref(vespa, _result_values.back()));
+ _results.push_back(CreateVespaTensorRef()(vespa, _result_values.back()));
} else {
- _results.push_back(create_vespa_tensor(vespa));
- _result_converters.emplace_back(i, select_convert_result(onnx.elements, vespa.cell_type()));
+ _results.push_back(CreateVespaTensor()(vespa));
+ _result_converters.emplace_back(i, SelectConvertResult()(onnx.elements, vespa.cell_type()));
}
}
// make sure references to Ort::Value inside _result_values are safe