summaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-09-25 10:34:53 +0000
committerArne Juul <arnej@verizonmedia.com>2020-09-28 12:06:35 +0000
commitbb304f3b6961292182f5f480a2789e5921746713 (patch)
tree360d6b92c1b81000bce46c658d1fd9b51975b338 /eval/src
parent55b186000b19bec24008eefeae9e4c23a476e91e (diff)
add DefaultValueBuilderFactory
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/vespa/eval/tensor/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/tensor/default_value_builder_factory.cpp57
-rw-r--r--eval/src/vespa/eval/tensor/default_value_builder_factory.h24
3 files changed, 82 insertions, 0 deletions
diff --git a/eval/src/vespa/eval/tensor/CMakeLists.txt b/eval/src/vespa/eval/tensor/CMakeLists.txt
index bc0a4d340b8..810dfd6d0b3 100644
--- a/eval/src/vespa/eval/tensor/CMakeLists.txt
+++ b/eval/src/vespa/eval/tensor/CMakeLists.txt
@@ -2,6 +2,7 @@
vespa_add_library(eval_tensor OBJECT
SOURCES
default_tensor_engine.cpp
+ default_value_builder_factory.cpp
tensor.cpp
tensor_address.cpp
tensor_apply.cpp
diff --git a/eval/src/vespa/eval/tensor/default_value_builder_factory.cpp b/eval/src/vespa/eval/tensor/default_value_builder_factory.cpp
new file mode 100644
index 00000000000..46301b0b5be
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/default_value_builder_factory.cpp
@@ -0,0 +1,57 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "default_value_builder_factory.h"
+#include <vespa/vespalib/util/typify.h>
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/double_value_builder.h>
+#include <vespa/eval/tensor/dense/dense_tensor_value_builder.h>
+#include <vespa/eval/tensor/mixed/packed_mixed_tensor_builder.h>
+#include <vespa/eval/tensor/sparse/sparse_tensor_value_builder.h>
+
+using namespace vespalib::eval;
+
+namespace vespalib::tensor {
+
+//-----------------------------------------------------------------------------
+
+namespace {
+
+struct CreateDefaultValueBuilderBase {
+ template <typename T> static std::unique_ptr<ValueBuilderBase> invoke(const ValueType &type,
+ size_t num_mapped_dims_in,
+ size_t subspace_size_in,
+ size_t expected_subspaces)
+ {
+ assert(check_cell_type<T>(type.cell_type()));
+ if (type.is_double()) {
+ return std::make_unique<DoubleValueBuilder>(type, num_mapped_dims_in, subspace_size_in, 1);
+ }
+ if (type.is_dense()) {
+ return std::make_unique<DenseTensorValueBuilder<T>>(type, num_mapped_dims_in, subspace_size_in, 1);
+ }
+ if (type.is_sparse()) {
+ return std::make_unique<SparseTensorValueBuilder<T>>(type, num_mapped_dims_in, subspace_size_in, expected_subspaces);
+ }
+ return std::make_unique<packed_mixed_tensor::PackedMixedTensorBuilder<T>>(type, num_mapped_dims_in, subspace_size_in, expected_subspaces);
+ }
+};
+
+} // namespace <unnamed>
+
+//-----------------------------------------------------------------------------
+
+DefaultValueBuilderFactory::DefaultValueBuilderFactory() = default;
+DefaultValueBuilderFactory DefaultValueBuilderFactory::_factory;
+
+std::unique_ptr<ValueBuilderBase>
+DefaultValueBuilderFactory::create_value_builder_base(const ValueType &type,
+ size_t num_mapped_dims_in,
+ size_t subspace_size_in,
+ size_t expected_subspaces) const
+{
+ return typify_invoke<1,TypifyCellType,CreateDefaultValueBuilderBase>(type.cell_type(), type, num_mapped_dims_in, subspace_size_in, expected_subspaces);
+}
+
+//-----------------------------------------------------------------------------
+
+}
diff --git a/eval/src/vespa/eval/tensor/default_value_builder_factory.h b/eval/src/vespa/eval/tensor/default_value_builder_factory.h
new file mode 100644
index 00000000000..67b1391ed78
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/default_value_builder_factory.h
@@ -0,0 +1,24 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/value_type.h>
+
+namespace vespalib::tensor {
+
+/**
+ * A factory that can generate ValueBuilder
+ * objects appropriate for the requested type.
+ */
+struct DefaultValueBuilderFactory : eval::ValueBuilderFactory {
+private:
+ DefaultValueBuilderFactory();
+ static DefaultValueBuilderFactory _factory;
+ ~DefaultValueBuilderFactory() override {}
+protected:
+ std::unique_ptr<eval::ValueBuilderBase> create_value_builder_base(const eval::ValueType &type,
+ size_t num_mapped_in, size_t subspace_size_in, size_t expect_subspaces) const override;
+public:
+ static const DefaultValueBuilderFactory &get() { return _factory; }
+};
+
+} // namespace