diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-09-25 10:34:53 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-09-28 12:06:35 +0000 |
commit | bb304f3b6961292182f5f480a2789e5921746713 (patch) | |
tree | 360d6b92c1b81000bce46c658d1fd9b51975b338 | |
parent | 55b186000b19bec24008eefeae9e4c23a476e91e (diff) |
add DefaultValueBuilderFactory
-rw-r--r-- | eval/src/vespa/eval/tensor/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/default_value_builder_factory.cpp | 57 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/default_value_builder_factory.h | 24 |
3 files changed, 82 insertions, 0 deletions
diff --git a/eval/src/vespa/eval/tensor/CMakeLists.txt b/eval/src/vespa/eval/tensor/CMakeLists.txt index bc0a4d340b8..810dfd6d0b3 100644 --- a/eval/src/vespa/eval/tensor/CMakeLists.txt +++ b/eval/src/vespa/eval/tensor/CMakeLists.txt @@ -2,6 +2,7 @@ vespa_add_library(eval_tensor OBJECT SOURCES default_tensor_engine.cpp + default_value_builder_factory.cpp tensor.cpp tensor_address.cpp tensor_apply.cpp diff --git a/eval/src/vespa/eval/tensor/default_value_builder_factory.cpp b/eval/src/vespa/eval/tensor/default_value_builder_factory.cpp new file mode 100644 index 00000000000..46301b0b5be --- /dev/null +++ b/eval/src/vespa/eval/tensor/default_value_builder_factory.cpp @@ -0,0 +1,57 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "default_value_builder_factory.h" +#include <vespa/vespalib/util/typify.h> +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/double_value_builder.h> +#include <vespa/eval/tensor/dense/dense_tensor_value_builder.h> +#include <vespa/eval/tensor/mixed/packed_mixed_tensor_builder.h> +#include <vespa/eval/tensor/sparse/sparse_tensor_value_builder.h> + +using namespace vespalib::eval; + +namespace vespalib::tensor { + +//----------------------------------------------------------------------------- + +namespace { + +struct CreateDefaultValueBuilderBase { + template <typename T> static std::unique_ptr<ValueBuilderBase> invoke(const ValueType &type, + size_t num_mapped_dims_in, + size_t subspace_size_in, + size_t expected_subspaces) + { + assert(check_cell_type<T>(type.cell_type())); + if (type.is_double()) { + return std::make_unique<DoubleValueBuilder>(type, num_mapped_dims_in, subspace_size_in, 1); + } + if (type.is_dense()) { + return std::make_unique<DenseTensorValueBuilder<T>>(type, num_mapped_dims_in, subspace_size_in, 1); + } + if (type.is_sparse()) { + return std::make_unique<SparseTensorValueBuilder<T>>(type, num_mapped_dims_in, subspace_size_in, expected_subspaces); + } + return std::make_unique<packed_mixed_tensor::PackedMixedTensorBuilder<T>>(type, num_mapped_dims_in, subspace_size_in, expected_subspaces); + } +}; + +} // namespace <unnamed> + +//----------------------------------------------------------------------------- + +DefaultValueBuilderFactory::DefaultValueBuilderFactory() = default; +DefaultValueBuilderFactory DefaultValueBuilderFactory::_factory; + +std::unique_ptr<ValueBuilderBase> +DefaultValueBuilderFactory::create_value_builder_base(const ValueType &type, + size_t num_mapped_dims_in, + size_t subspace_size_in, + size_t expected_subspaces) const +{ + return typify_invoke<1,TypifyCellType,CreateDefaultValueBuilderBase>(type.cell_type(), type, num_mapped_dims_in, subspace_size_in, expected_subspaces); +} + +//----------------------------------------------------------------------------- + +} diff --git a/eval/src/vespa/eval/tensor/default_value_builder_factory.h b/eval/src/vespa/eval/tensor/default_value_builder_factory.h new file mode 100644 index 00000000000..67b1391ed78 --- /dev/null +++ b/eval/src/vespa/eval/tensor/default_value_builder_factory.h @@ -0,0 +1,24 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/value_type.h> + +namespace vespalib::tensor { + +/** + * A factory that can generate ValueBuilder + * objects appropriate for the requested type. + */ +struct DefaultValueBuilderFactory : eval::ValueBuilderFactory { +private: + DefaultValueBuilderFactory(); + static DefaultValueBuilderFactory _factory; + ~DefaultValueBuilderFactory() override {} +protected: + std::unique_ptr<eval::ValueBuilderBase> create_value_builder_base(const eval::ValueType &type, + size_t num_mapped_in, size_t subspace_size_in, size_t expect_subspaces) const override; +public: + static const DefaultValueBuilderFactory &get() { return _factory; } +}; + +} // namespace |