summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-05-26 11:52:40 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-05-26 12:13:34 +0000
commitc7d738acb70c80f37804608ec2bb45fcf46748de (patch)
tree0852ef41a5415b31b20f5bbab3ce0b8b96006262 /eval
parent0ad9385f7c77711ab64818a78ce9fcc9cfe2e2dc (diff)
disable use of arena allocator
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp7
-rw-r--r--eval/src/vespa/eval/onnx/onnx_wrapper.cpp3
2 files changed, 9 insertions, 1 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 9b44dd7519e..54f958f8111 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -436,4 +436,11 @@ TEST(OnnxTest, inspect_float_to_int8_conversion) {
//-------------------------------------------------------------------------
}
+TEST(OnnxTest, default_allocator_type) {
+ Ort::AllocatorWithDefaultOptions default_alloc;
+ OrtAllocatorType res = Invalid;
+ Ort::ThrowOnError(Ort::GetApi().MemoryInfoGetType(default_alloc.GetInfo(), &res));
+ fprintf(stderr, "default allocator type: %d\n", int(res));
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
index 3a593f491d8..f848c421c9d 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
@@ -414,7 +414,7 @@ struct Onnx::EvalContext::SelectConvertResult {
Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
: _model(model),
_wire_info(wire_info),
- _cpu_memory(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)),
+ _cpu_memory(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault)),
_param_values(),
_result_values(),
_results(),
@@ -535,6 +535,7 @@ Onnx::Onnx(const vespalib::string &model_file, Optimize optimize)
_options.SetIntraOpNumThreads(1);
_options.SetInterOpNumThreads(1);
_options.SetGraphOptimizationLevel(convert_optimize(optimize));
+ _options.DisableCpuMemArena();
_session = Ort::Session(_shared.env(), model_file.c_str(), _options);
extract_meta_data();
}