aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp29
1 files changed, 26 insertions, 3 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp
index b4b0b9bbc79..ce8cc11026c 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_type_mapper.cpp
@@ -3,6 +3,8 @@
#include "tensor_buffer_type_mapper.h"
#include "tensor_buffer_operations.h"
#include <algorithm>
+#include <cmath>
+#include <limits>
namespace search::tensor {
@@ -12,15 +14,29 @@ TensorBufferTypeMapper::TensorBufferTypeMapper()
{
}
-TensorBufferTypeMapper::TensorBufferTypeMapper(uint32_t max_small_subspaces_type_id, TensorBufferOperations* ops)
+TensorBufferTypeMapper::TensorBufferTypeMapper(uint32_t max_small_subspaces_type_id, double grow_factor, TensorBufferOperations* ops)
: _array_sizes(),
_ops(ops)
{
_array_sizes.reserve(max_small_subspaces_type_id + 1);
_array_sizes.emplace_back(0); // type id 0 uses LargeSubspacesBufferType
+ uint32_t num_subspaces = 0;
+ size_t prev_array_size = 0u;
+ size_t array_size = 0u;
for (uint32_t type_id = 1; type_id <= max_small_subspaces_type_id; ++type_id) {
- auto num_subspaces = type_id - 1;
- _array_sizes.emplace_back(_ops->get_array_size(num_subspaces));
+ if (type_id > 1) {
+ num_subspaces = std::max(num_subspaces + 1, static_cast<uint32_t>(std::floor(num_subspaces * grow_factor)));
+ }
+ array_size = _ops->get_buffer_size(num_subspaces);
+ while (array_size <= prev_array_size) {
+ ++num_subspaces;
+ array_size = _ops->get_buffer_size(num_subspaces);
+ }
+ if (array_size > std::numeric_limits<uint32_t>::max()) {
+ break;
+ }
+ _array_sizes.emplace_back(array_size);
+ prev_array_size = array_size;
}
}
@@ -44,4 +60,11 @@ TensorBufferTypeMapper::get_array_size(uint32_t type_id) const
return _array_sizes[type_id];
}
+uint32_t
+TensorBufferTypeMapper::get_max_small_array_type_id(uint32_t max_small_array_type_id) const noexcept
+{
+ auto clamp_type_id = _array_sizes.size() - 1;
+ return (clamp_type_id < max_small_array_type_id) ? clamp_type_id : max_small_array_type_id;
+}
+
}