diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2020-10-13 15:58:31 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-13 15:58:31 +0200 |
commit | 1fb76b9e6dc6028f2678c1c01054f4d37bec6a04 (patch) | |
tree | 2d72ab896f51a9649deb896fc88a363891697432 /eval | |
parent | dd5873778526bb92f53c6cc75b3f8945f86710f2 (diff) | |
parent | 80a1e55004ef23e16029603921b705a7258eac17 (diff) |
Merge pull request #14842 from vespa-engine/havardpe/engine-or-factory-global-switch-with-default
global implementation switch with default
Diffstat (limited to 'eval')
-rw-r--r-- | eval/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/tests/eval/engine_or_factory/CMakeLists.txt | 17 | ||||
-rw-r--r-- | eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp | 25 | ||||
-rw-r--r-- | eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp | 24 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/engine_or_factory.cpp | 65 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/engine_or_factory.h | 7 |
6 files changed, 138 insertions, 1 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index db178342359..eb94a9498cb 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -14,6 +14,7 @@ vespa_define_module( src/tests/eval/aggr src/tests/eval/compile_cache src/tests/eval/compiled_function + src/tests/eval/engine_or_factory src/tests/eval/fast_sparse_map src/tests/eval/function src/tests/eval/function_speed diff --git a/eval/src/tests/eval/engine_or_factory/CMakeLists.txt b/eval/src/tests/eval/engine_or_factory/CMakeLists.txt new file mode 100644 index 00000000000..f0bd0f63251 --- /dev/null +++ b/eval/src/tests/eval/engine_or_factory/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_engine_or_factory_test_app TEST + SOURCES + engine_or_factory_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_engine_or_factory_test_app COMMAND eval_engine_or_factory_test_app) +vespa_add_executable(eval_engine_or_factory_override_test_app TEST + SOURCES + engine_or_factory_override_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_engine_or_factory_override_test_app COMMAND eval_engine_or_factory_override_test_app) diff --git a/eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp b/eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp new file mode 100644 index 00000000000..8480bb7d39e --- /dev/null +++ b/eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp @@ -0,0 +1,25 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/eval/fast_value.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::tensor; + +TEST(EngineOrFactoryOverrideTest, set_can_override_get_result) { + EngineOrFactory::set(FastValueBuilderFactory::get()); + EXPECT_EQ(EngineOrFactory::get().to_string(), "FastValueBuilderFactory"); +} + +TEST(EngineOrFactoryOverrideTest, set_with_same_value_is_allowed) { + EngineOrFactory::set(FastValueBuilderFactory::get()); +} + +TEST(EngineOrFactoryOverrideTest, set_with_another_value_is_not_allowed) { + EXPECT_THROW(EngineOrFactory::set(DefaultTensorEngine::ref()), vespalib::IllegalStateException); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp b/eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp new file mode 100644 index 00000000000..6cb9cc0c89c --- /dev/null +++ b/eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp @@ -0,0 +1,24 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/engine_or_factory.h> +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/eval/fast_value.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::tensor; + +TEST(EngineOrFactoryTest, default_is_default_tensor_engine) { + EXPECT_EQ(EngineOrFactory::get().to_string(), "DefaultTensorEngine"); +} + +TEST(EngineOrFactoryTest, set_with_same_value_is_allowed) { + EngineOrFactory::set(DefaultTensorEngine::ref()); +} + +TEST(EngineOrFactoryTest, set_with_another_value_is_not_allowed) { + EXPECT_THROW(EngineOrFactory::set(FastValueBuilderFactory::get()), vespalib::IllegalStateException); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/vespa/eval/eval/engine_or_factory.cpp b/eval/src/vespa/eval/eval/engine_or_factory.cpp index d297cc91f6f..67885ede48e 100644 --- a/eval/src/vespa/eval/eval/engine_or_factory.cpp +++ b/eval/src/vespa/eval/eval/engine_or_factory.cpp @@ -2,11 +2,29 @@ #include "engine_or_factory.h" #include "fast_value.h" -#include <vespa/eval/tensor/default_tensor_engine.h> +#include "simple_value.h" #include "value_codec.h" +#include "simple_tensor_engine.h" +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/tensor/default_value_builder_factory.h> +#include <vespa/eval/tensor/mixed/packed_mixed_tensor_builder_factory.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/util/stringfmt.h> + +using vespalib::make_string_short::fmt; namespace vespalib::eval { +EngineOrFactory EngineOrFactory::_default{tensor::DefaultTensorEngine::ref()}; + + +EngineOrFactory +EngineOrFactory::get_shared(EngineOrFactory hint) +{ + static EngineOrFactory shared{hint}; + return shared; +} + const TensorFunction & EngineOrFactory::optimize(const TensorFunction &expr, Stash &stash) const { if (is_engine()) { @@ -88,4 +106,49 @@ EngineOrFactory::rename(const Value &a, const std::vector<vespalib::string> &fro return engine().rename(a, from, to, stash); } +void +EngineOrFactory::set(EngineOrFactory wanted) +{ + auto engine = get_shared(wanted); + if (engine._value != wanted._value) { + auto msg = fmt("EngineOrFactory: trying to set implementation to [%s] when [%s] is already in use", + wanted.to_string().c_str(), engine.to_string().c_str()); + throw IllegalStateException(msg); + } +} + +EngineOrFactory +EngineOrFactory::get() +{ + return get_shared(_default); +} + +vespalib::string +EngineOrFactory::to_string() const +{ + if (is_engine()) { + if (&engine() == &tensor::DefaultTensorEngine::ref()) { + return "DefaultTensorEngine"; + } + if (&engine() == &SimpleTensorEngine::ref()) { + return "SimpleTensorEngine"; + } + } + if (is_factory()) { + if (&factory() == &FastValueBuilderFactory::get()) { + return "FastValueBuilderFactory"; + } + if (&factory() == &SimpleValueBuilderFactory::get()) { + return "SimpleValueBuilderFactory"; + } + if (&factory() == &tensor::DefaultValueBuilderFactory::get()) { + return "DefaultValueBuilderFactory"; + } + if (&factory() == &PackedMixedTensorBuilderFactory::get()) { + return "PackedMixedTensorBuilderFactory"; + } + } + return "???"; +} + } diff --git a/eval/src/vespa/eval/eval/engine_or_factory.h b/eval/src/vespa/eval/eval/engine_or_factory.h index a95bfda6bd5..e94febd54ea 100644 --- a/eval/src/vespa/eval/eval/engine_or_factory.h +++ b/eval/src/vespa/eval/eval/engine_or_factory.h @@ -32,6 +32,8 @@ private: using engine_t = const TensorEngine *; using factory_t = const ValueBuilderFactory *; std::variant<engine_t,factory_t> _value; + static EngineOrFactory _default; + static EngineOrFactory get_shared(EngineOrFactory hint); public: EngineOrFactory(const TensorEngine &engine_in) : _value(&engine_in) {} EngineOrFactory(const ValueBuilderFactory &factory_in) : _value(&factory_in) {} @@ -52,6 +54,11 @@ public: const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const; const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const; const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const; + // global switch with default; call set before get to override the default + static void set(EngineOrFactory wanted); + static EngineOrFactory get(); + // try to describe the value held by this object as a human-readable string + vespalib::string to_string() const; }; } |