From cb69ce48932ce02f27f5d64e74bcaf2f57b660e7 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Tue, 9 Jul 2019 13:26:07 +0000 Subject: float cells in attribute also assert for cell type consistency in dense tensor view --- eval/src/vespa/eval/tensor/dense/dense_tensor_view.h | 1 + searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 60f85c38659..c09202e50d0 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -46,6 +46,7 @@ public: void accept(TensorVisitor &visitor) const override; protected: void initCellsRef(TypedCells cells_in) { + assert(_typeRef.cell_type() == cells_in.type); _cellsRef = cells_in; } private: diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp index f741002ea5e..11a6839ca59 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp @@ -13,6 +13,7 @@ using vespalib::tensor::Tensor; using vespalib::tensor::DenseTensorView; using vespalib::tensor::MutableDenseTensorView; using vespalib::eval::ValueType; +using CellType = vespalib::eval::ValueType::CellType; namespace search::tensor { @@ -21,12 +22,20 @@ namespace { constexpr size_t MIN_BUFFER_ARRAYS = 1024; constexpr size_t DENSE_TENSOR_ALIGNMENT = 32; +size_t size_of(CellType type) { + switch (type) { + case CellType::DOUBLE: return sizeof(double); + case CellType::FLOAT: return sizeof(float); + } + abort(); +} + } DenseTensorStore::TensorSizeCalc::TensorSizeCalc(const ValueType &type) : _numBoundCells(1u), _numUnboundDims(0u), - _cellSize(sizeof(double)) + _cellSize(size_of(type.cell_type())) { for (const auto & dim : type.dimensions()) { if (dim.is_bound()) { @@ -237,6 +246,7 @@ checkMatchingType(const ValueType &lhs, const ValueType &rhs, size_t numCells) checkNumCells *= rhsItr->size; ++rhsItr; } + assert(lhs.cell_type() == rhs.cell_type()); assert(numCells == checkNumCells); assert(rhsItr == rhsItrEnd); } -- cgit v1.2.3