diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-09-28 13:37:02 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-09-28 13:37:02 +0000 |
commit | e533d6f97052838eaa47268f086d9bbc560c20a9 (patch) | |
tree | 7d840c20479bc20fb41f2ee45f6ac075a8710b39 /eval | |
parent | 42e08d6a2649d4e9421ba64de1d310e1e82cc262 (diff) |
we need templated SparseTensorValue after all
Diffstat (limited to 'eval')
4 files changed, 18 insertions, 14 deletions
diff --git a/eval/src/tests/tensor/default_value_builder_factory/default_value_builder_factory_test.cpp b/eval/src/tests/tensor/default_value_builder_factory/default_value_builder_factory_test.cpp index 5663dda592f..a04e66ecd38 100644 --- a/eval/src/tests/tensor/default_value_builder_factory/default_value_builder_factory_test.cpp +++ b/eval/src/tests/tensor/default_value_builder_factory/default_value_builder_factory_test.cpp @@ -28,7 +28,7 @@ TEST(MakeInputTest, print_some_test_input) { EXPECT_TRUE(dynamic_cast<DoubleValue *>(dbl.get())); EXPECT_TRUE(dynamic_cast<DenseTensorView *>(trivial.get())); EXPECT_TRUE(dynamic_cast<DenseTensorView *>(dense.get())); - EXPECT_TRUE(dynamic_cast<SparseTensorValue *>(sparse.get())); + EXPECT_TRUE(dynamic_cast<SparseTensorValue<double> *>(sparse.get())); EXPECT_TRUE(dynamic_cast<PackedMixedTensor *>(mixed.get())); EXPECT_EQ(dbl->as_double(), 3.0); diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.cpp index 89cfb89ae8e..ba9c78aeb7b 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.cpp @@ -216,25 +216,29 @@ SparseTensorValueIndex::create_view(const std::vector<size_t> &dims) const //----------------------------------------------------------------------------- template<typename T> -SparseTensorValue::SparseTensorValue(const eval::ValueType &type_in, const SparseTensorValueIndex &index_in, ConstArrayRef<T> cells_in) +SparseTensorValue<T>::SparseTensorValue(const eval::ValueType &type_in, const SparseTensorValueIndex &index_in, ConstArrayRef<T> cells_in) : _type(type_in), _index(index_in.num_mapped_dims), _cells(), _stash(needed_memory_for(index_in.map, cells_in)) { copyMap(_index.map, index_in.map, _stash); - _cells = TypedCells(_stash.copy_array<T>(cells_in)); + _cells = _stash.copy_array<T>(cells_in); } -SparseTensorValue::SparseTensorValue(eval::ValueType &&type_in, SparseTensorValueIndex &&index_in, TypedCells cells_in, Stash &&stash_in) +template<typename T> +SparseTensorValue<T>::SparseTensorValue(eval::ValueType &&type_in, SparseTensorValueIndex &&index_in, ConstArrayRef<T> &&cells_in, Stash &&stash_in) : _type(std::move(type_in)), _index(std::move(index_in)), - _cells(cells_in), + _cells(std::move(cells_in)), _stash(std::move(stash_in)) { } -SparseTensorValue::~SparseTensorValue() = default; +template<typename T> SparseTensorValue<T>::~SparseTensorValue() = default; + +template class SparseTensorValue<float>; +template class SparseTensorValue<double>; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.h index a565d6edc1f..ad916021bbc 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value.h @@ -29,22 +29,22 @@ struct SparseTensorValueIndex : public vespalib::eval::Value::Index * improve CPU cache and TLB hit ratio, relative to SimpleTensor * implementation. */ +template<typename T> class SparseTensorValue : public vespalib::eval::Value { private: eval::ValueType _type; SparseTensorValueIndex _index; - TypedCells _cells; + ConstArrayRef<T> _cells; Stash _stash; public: - template<typename T> SparseTensorValue(const eval::ValueType &type_in, const SparseTensorValueIndex &index_in, ConstArrayRef<T> cells_in); - SparseTensorValue(eval::ValueType &&type_in, SparseTensorValueIndex &&index_in, TypedCells cells_in, Stash &&stash_in); + SparseTensorValue(eval::ValueType &&type_in, SparseTensorValueIndex &&index_in, ConstArrayRef<T> &&cells_in, Stash &&stash_in); ~SparseTensorValue() override; - TypedCells cells() const override { return _cells; } + TypedCells cells() const override { return TypedCells(_cells); } const Index &index() const override { return _index; } diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value_builder.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value_builder.cpp index ca51e101a89..7a670ab8f85 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value_builder.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor_value_builder.cpp @@ -28,10 +28,10 @@ SparseTensorValueBuilder<T>::build(std::unique_ptr<eval::ValueBuilder<T>>) // copy cells to stash: ConstArrayRef<T> tmp_cells = _cells; ConstArrayRef<T> cells_copy = _stash.copy_array<T>(tmp_cells); - return std::make_unique<SparseTensorValue>(std::move(_type), - std::move(_index), - TypedCells(cells_copy), - std::move(_stash)); + return std::make_unique<SparseTensorValue<T>>(std::move(_type), + std::move(_index), + std::move(cells_copy), + std::move(_stash)); } template class SparseTensorValueBuilder<float>; |