diff options
Diffstat (limited to 'searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp')
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp index b4b0b9bbc79..ce8cc11026c 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp @@ -3,6 +3,8 @@ #include "tensor_buffer_type_mapper.h" #include "tensor_buffer_operations.h" #include <algorithm> +#include <cmath> +#include <limits> namespace search::tensor { @@ -12,15 +14,29 @@ TensorBufferTypeMapper::TensorBufferTypeMapper() { } -TensorBufferTypeMapper::TensorBufferTypeMapper(uint32_t max_small_subspaces_type_id, TensorBufferOperations* ops) +TensorBufferTypeMapper::TensorBufferTypeMapper(uint32_t max_small_subspaces_type_id, double grow_factor, TensorBufferOperations* ops) : _array_sizes(), _ops(ops) { _array_sizes.reserve(max_small_subspaces_type_id + 1); _array_sizes.emplace_back(0); // type id 0 uses LargeSubspacesBufferType + uint32_t num_subspaces = 0; + size_t prev_array_size = 0u; + size_t array_size = 0u; for (uint32_t type_id = 1; type_id <= max_small_subspaces_type_id; ++type_id) { - auto num_subspaces = type_id - 1; - _array_sizes.emplace_back(_ops->get_array_size(num_subspaces)); + if (type_id > 1) { + num_subspaces = std::max(num_subspaces + 1, static_cast<uint32_t>(std::floor(num_subspaces * grow_factor))); + } + array_size = _ops->get_buffer_size(num_subspaces); + while (array_size <= prev_array_size) { + ++num_subspaces; + array_size = _ops->get_buffer_size(num_subspaces); + } + if (array_size > std::numeric_limits<uint32_t>::max()) { + break; + } + _array_sizes.emplace_back(array_size); + prev_array_size = array_size; } } @@ -44,4 +60,11 @@ TensorBufferTypeMapper::get_array_size(uint32_t type_id) const return _array_sizes[type_id]; } +uint32_t +TensorBufferTypeMapper::get_max_small_array_type_id(uint32_t max_small_array_type_id) const noexcept +{ + auto clamp_type_id = _array_sizes.size() - 1; + return (clamp_type_id < max_small_array_type_id) ? clamp_type_id : max_small_array_type_id; +} + } |