summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2021-03-09 14:35:57 +0100
committerGitHub <noreply@github.com>2021-03-09 14:35:57 +0100
commita0a9b9b55e0f9908f1d0dbca1c199a4a958f50ba (patch)
tree8bd746554c589df92750802710416a58b71b868a
parent934be121d5c905660b14ea2d8a18798db834fa64 (diff)
parent8a953448867ff62d63992bfd4c64ef4f6ec7419b (diff)
Merge pull request #16857 from vespa-engine/arnej/use-cell-size-utility
avoid explicit switch on cell types
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp13
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h7
3 files changed, 8 insertions, 16 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp
index 362e1b45266..cc43a694a69 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp
@@ -62,13 +62,11 @@ void
DenseTensorAttributeSaver::save_tensor_store(BufferWriter& writer) const
{
const uint32_t docIdLimit(_refs.size());
- const uint32_t cellSize = _tensorStore.getCellSize();
for (uint32_t lid = 0; lid < docIdLimit; ++lid) {
if (_refs[lid].valid()) {
auto raw = _tensorStore.getRawBuffer(_refs[lid]);
writer.write(&tensorIsPresent, sizeof(tensorIsPresent));
- size_t numCells = _tensorStore.getNumCells();
- size_t rawLen = numCells * cellSize;
+ size_t rawLen = _tensorStore.getBufSize();
writer.write(static_cast<const char *>(raw), rawLen);
} else {
writer.write(&tensorIsNotPresent, sizeof(tensorIsNotPresent));
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
index e99ba196224..13796d35dec 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
@@ -6,9 +6,10 @@
#include <vespa/vespalib/util/memory_allocator.h>
using vespalib::datastore::Handle;
+using vespalib::eval::CellType;
+using vespalib::eval::CellTypeUtils;
using vespalib::eval::Value;
using vespalib::eval::ValueType;
-using CellType = vespalib::eval::CellType;
namespace search::tensor {
@@ -17,14 +18,6 @@ namespace {
constexpr size_t MIN_BUFFER_ARRAYS = 1024;
constexpr size_t DENSE_TENSOR_ALIGNMENT = 32;
-size_t size_of(CellType type) {
- switch (type) {
- case CellType::DOUBLE: return sizeof(double);
- case CellType::FLOAT: return sizeof(float);
- }
- abort();
-}
-
size_t my_align(size_t size, size_t alignment) {
size += alignment - 1;
return (size - (size % alignment));
@@ -34,7 +27,7 @@ size_t my_align(size_t size, size_t alignment) {
DenseTensorStore::TensorSizeCalc::TensorSizeCalc(const ValueType &type)
: _numCells(1u),
- _cellSize(size_of(type.cell_type()))
+ _cell_type(type.cell_type())
{
for (const auto &dim: type.dimensions()) {
_numCells *= dim.size;
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
index aa5a7993eaf..dad28642e67 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
@@ -24,10 +24,12 @@ public:
struct TensorSizeCalc
{
size_t _numCells; // product of dimension sizes
- uint32_t _cellSize; // size of a cell (e.g. double => 8, float => 4)
+ vespalib::eval::CellType _cell_type;
TensorSizeCalc(const ValueType &type);
- size_t bufSize() const { return (_numCells * _cellSize); }
+ size_t bufSize() const {
+ return vespalib::eval::CellTypeUtils::mem_size(_cell_type, _numCells);
+ }
size_t alignedSize() const;
};
@@ -60,7 +62,6 @@ public:
const ValueType &type() const { return _type; }
size_t getNumCells() const { return _tensorSizeCalc._numCells; }
- uint32_t getCellSize() const { return _tensorSizeCalc._cellSize; }
size_t getBufSize() const { return _tensorSizeCalc.bufSize(); }
const void *getRawBuffer(RefType ref) const;
vespalib::datastore::Handle<char> allocRawBuffer();