summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2020-10-13 15:58:31 +0200
committerGitHub <noreply@github.com>2020-10-13 15:58:31 +0200
commit1fb76b9e6dc6028f2678c1c01054f4d37bec6a04 (patch)
tree2d72ab896f51a9649deb896fc88a363891697432 /eval
parentdd5873778526bb92f53c6cc75b3f8945f86710f2 (diff)
parent80a1e55004ef23e16029603921b705a7258eac17 (diff)
Merge pull request #14842 from vespa-engine/havardpe/engine-or-factory-global-switch-with-default
global implementation switch with default
Diffstat (limited to 'eval')
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/eval/engine_or_factory/CMakeLists.txt17
-rw-r--r--eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp25
-rw-r--r--eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp24
-rw-r--r--eval/src/vespa/eval/eval/engine_or_factory.cpp65
-rw-r--r--eval/src/vespa/eval/eval/engine_or_factory.h7
6 files changed, 138 insertions, 1 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index db178342359..eb94a9498cb 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -14,6 +14,7 @@ vespa_define_module(
src/tests/eval/aggr
src/tests/eval/compile_cache
src/tests/eval/compiled_function
+ src/tests/eval/engine_or_factory
src/tests/eval/fast_sparse_map
src/tests/eval/function
src/tests/eval/function_speed
diff --git a/eval/src/tests/eval/engine_or_factory/CMakeLists.txt b/eval/src/tests/eval/engine_or_factory/CMakeLists.txt
new file mode 100644
index 00000000000..f0bd0f63251
--- /dev/null
+++ b/eval/src/tests/eval/engine_or_factory/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_engine_or_factory_test_app TEST
+ SOURCES
+ engine_or_factory_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_engine_or_factory_test_app COMMAND eval_engine_or_factory_test_app)
+vespa_add_executable(eval_engine_or_factory_override_test_app TEST
+ SOURCES
+ engine_or_factory_override_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_engine_or_factory_override_test_app COMMAND eval_engine_or_factory_override_test_app)
diff --git a/eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp b/eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp
new file mode 100644
index 00000000000..8480bb7d39e
--- /dev/null
+++ b/eval/src/tests/eval/engine_or_factory/engine_or_factory_override_test.cpp
@@ -0,0 +1,25 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/vespalib/util/exceptions.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::tensor;
+
+TEST(EngineOrFactoryOverrideTest, set_can_override_get_result) {
+ EngineOrFactory::set(FastValueBuilderFactory::get());
+ EXPECT_EQ(EngineOrFactory::get().to_string(), "FastValueBuilderFactory");
+}
+
+TEST(EngineOrFactoryOverrideTest, set_with_same_value_is_allowed) {
+ EngineOrFactory::set(FastValueBuilderFactory::get());
+}
+
+TEST(EngineOrFactoryOverrideTest, set_with_another_value_is_not_allowed) {
+ EXPECT_THROW(EngineOrFactory::set(DefaultTensorEngine::ref()), vespalib::IllegalStateException);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp b/eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp
new file mode 100644
index 00000000000..6cb9cc0c89c
--- /dev/null
+++ b/eval/src/tests/eval/engine_or_factory/engine_or_factory_test.cpp
@@ -0,0 +1,24 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/vespalib/util/exceptions.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::tensor;
+
+TEST(EngineOrFactoryTest, default_is_default_tensor_engine) {
+ EXPECT_EQ(EngineOrFactory::get().to_string(), "DefaultTensorEngine");
+}
+
+TEST(EngineOrFactoryTest, set_with_same_value_is_allowed) {
+ EngineOrFactory::set(DefaultTensorEngine::ref());
+}
+
+TEST(EngineOrFactoryTest, set_with_another_value_is_not_allowed) {
+ EXPECT_THROW(EngineOrFactory::set(FastValueBuilderFactory::get()), vespalib::IllegalStateException);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/eval/engine_or_factory.cpp b/eval/src/vespa/eval/eval/engine_or_factory.cpp
index d297cc91f6f..67885ede48e 100644
--- a/eval/src/vespa/eval/eval/engine_or_factory.cpp
+++ b/eval/src/vespa/eval/eval/engine_or_factory.cpp
@@ -2,11 +2,29 @@
#include "engine_or_factory.h"
#include "fast_value.h"
-#include <vespa/eval/tensor/default_tensor_engine.h>
+#include "simple_value.h"
#include "value_codec.h"
+#include "simple_tensor_engine.h"
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/eval/tensor/default_value_builder_factory.h>
+#include <vespa/eval/tensor/mixed/packed_mixed_tensor_builder_factory.h>
+#include <vespa/vespalib/util/exceptions.h>
+#include <vespa/vespalib/util/stringfmt.h>
+
+using vespalib::make_string_short::fmt;
namespace vespalib::eval {
+EngineOrFactory EngineOrFactory::_default{tensor::DefaultTensorEngine::ref()};
+
+
+EngineOrFactory
+EngineOrFactory::get_shared(EngineOrFactory hint)
+{
+ static EngineOrFactory shared{hint};
+ return shared;
+}
+
const TensorFunction &
EngineOrFactory::optimize(const TensorFunction &expr, Stash &stash) const {
if (is_engine()) {
@@ -88,4 +106,49 @@ EngineOrFactory::rename(const Value &a, const std::vector<vespalib::string> &fro
return engine().rename(a, from, to, stash);
}
+void
+EngineOrFactory::set(EngineOrFactory wanted)
+{
+ auto engine = get_shared(wanted);
+ if (engine._value != wanted._value) {
+ auto msg = fmt("EngineOrFactory: trying to set implementation to [%s] when [%s] is already in use",
+ wanted.to_string().c_str(), engine.to_string().c_str());
+ throw IllegalStateException(msg);
+ }
+}
+
+EngineOrFactory
+EngineOrFactory::get()
+{
+ return get_shared(_default);
+}
+
+vespalib::string
+EngineOrFactory::to_string() const
+{
+ if (is_engine()) {
+ if (&engine() == &tensor::DefaultTensorEngine::ref()) {
+ return "DefaultTensorEngine";
+ }
+ if (&engine() == &SimpleTensorEngine::ref()) {
+ return "SimpleTensorEngine";
+ }
+ }
+ if (is_factory()) {
+ if (&factory() == &FastValueBuilderFactory::get()) {
+ return "FastValueBuilderFactory";
+ }
+ if (&factory() == &SimpleValueBuilderFactory::get()) {
+ return "SimpleValueBuilderFactory";
+ }
+ if (&factory() == &tensor::DefaultValueBuilderFactory::get()) {
+ return "DefaultValueBuilderFactory";
+ }
+ if (&factory() == &PackedMixedTensorBuilderFactory::get()) {
+ return "PackedMixedTensorBuilderFactory";
+ }
+ }
+ return "???";
+}
+
}
diff --git a/eval/src/vespa/eval/eval/engine_or_factory.h b/eval/src/vespa/eval/eval/engine_or_factory.h
index a95bfda6bd5..e94febd54ea 100644
--- a/eval/src/vespa/eval/eval/engine_or_factory.h
+++ b/eval/src/vespa/eval/eval/engine_or_factory.h
@@ -32,6 +32,8 @@ private:
using engine_t = const TensorEngine *;
using factory_t = const ValueBuilderFactory *;
std::variant<engine_t,factory_t> _value;
+ static EngineOrFactory _default;
+ static EngineOrFactory get_shared(EngineOrFactory hint);
public:
EngineOrFactory(const TensorEngine &engine_in) : _value(&engine_in) {}
EngineOrFactory(const ValueBuilderFactory &factory_in) : _value(&factory_in) {}
@@ -52,6 +54,11 @@ public:
const Value &reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const;
const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const;
const Value &rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const;
+ // global switch with default; call set before get to override the default
+ static void set(EngineOrFactory wanted);
+ static EngineOrFactory get();
+ // try to describe the value held by this object as a human-readable string
+ vespalib::string to_string() const;
};
}