// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include #include #include #include #include #include #include #include #include #include #include using namespace mbus; using vespalib::make_string; static const duration TIMEOUT = 120s; class StringList : public std::vector { public: StringList &add(const string &str); }; StringList & StringList::add(const string &str) { std::vector::push_back(str); return *this; } class CustomPolicyFactory : public SimpleProtocol::IPolicyFactory { private: friend class CustomPolicy; bool _forward; std::vector _expectedAll; std::vector _expectedMatched; public: CustomPolicyFactory(bool forward, const std::vector &all, const std::vector &matched); ~CustomPolicyFactory() override; IRoutingPolicy::UP create(const string ¶m) override; }; CustomPolicyFactory::~CustomPolicyFactory() = default; class CustomPolicy : public IRoutingPolicy { private: CustomPolicyFactory &_factory; public: explicit CustomPolicy(CustomPolicyFactory &factory); void select(RoutingContext &ctx) override; void merge(RoutingContext &ctx) override; }; CustomPolicy::CustomPolicy(CustomPolicyFactory &factory) : _factory(factory) { } void CustomPolicy::select(RoutingContext &ctx) { auto reply = std::make_unique(); reply->getTrace().setLevel(9); const std::vector &all = ctx.getAllRecipients(); if (_factory._expectedAll.size() == all.size()) { ctx.trace(1, make_string("Got %d expected recipients.", (uint32_t)all.size())); for (const auto & route : all) { if (find(_factory._expectedAll.begin(), _factory._expectedAll.end(), route.toString()) != _factory._expectedAll.end()) { ctx.trace(1, make_string("Got expected recipient '%s'.", route.toString().c_str())); } else { reply->addError(Error(ErrorCode::APP_FATAL_ERROR, make_string("Matched recipient '%s' not expected.", route.toString().c_str()))); } } } else { reply->addError(Error(ErrorCode::APP_FATAL_ERROR, make_string("Expected %d recipients, got %d.", (uint32_t)_factory._expectedAll.size(), (uint32_t)all.size()))); } if (ctx.getNumRecipients() == all.size()) { for (uint32_t i = 0; i < all.size(); ++i) { if (all[i].toString() == ctx.getRecipient(i).toString()) { ctx.trace(1, make_string("getRecipient(%d) matches getAllRecipients()[%d]", i, i)); } else { reply->addError(Error(ErrorCode::APP_FATAL_ERROR, make_string("getRecipient(%d) differs from getAllRecipients()[%d]", i, i))); } } } else { reply->addError(Error(ErrorCode::APP_FATAL_ERROR, "getNumRecipients() differs from getAllRecipients().size()")); } std::vector matched; ctx.getMatchedRecipients(matched); if (_factory._expectedMatched.size() == matched.size()) { ctx.trace(1, make_string("Got %d expected recipients.", (uint32_t)matched.size())); for (auto & route : matched) { if (find(_factory._expectedMatched.begin(), _factory._expectedMatched.end(), route.toString()) != _factory._expectedMatched.end()) { ctx.trace(1, make_string("Got matched recipient '%s'.", route.toString().c_str())); } else { reply->addError(Error(ErrorCode::APP_FATAL_ERROR, make_string("Matched recipient '%s' not expected.", route.toString().c_str()))); } } } else { reply->addError(Error(ErrorCode::APP_FATAL_ERROR, make_string("Expected %d matched recipients, got %d.", (uint32_t)_factory._expectedMatched.size(), (uint32_t)matched.size()))); } if (!reply->hasErrors() && _factory._forward) { for (auto & route : matched) { ctx.addChild(route); } } else { ctx.setReply(std::move(reply)); } } void CustomPolicy::merge(RoutingContext &ctx) { auto ret = std::make_unique(); for (RoutingNodeIterator it = ctx.getChildIterator(); it.isValid(); it.next()) { const Reply &reply = it.getReplyRef(); for (uint32_t i = 0; i < reply.getNumErrors(); ++i) { ret->addError(reply.getError(i)); } } ctx.setReply(std::move(ret)); } CustomPolicyFactory::CustomPolicyFactory(bool forward, const std::vector &all, const std::vector &matched) : _forward(forward), _expectedAll(all), _expectedMatched(matched) { // empty } IRoutingPolicy::UP CustomPolicyFactory::create(const string &) { return IRoutingPolicy::UP(new CustomPolicy(*this)); } Message::UP createMessage(const string &msg) { Message::UP ret(new SimpleMessage(msg)); ret->getTrace().setLevel(9); return ret; } //////////////////////////////////////////////////////////////////////////////// // // Setup // //////////////////////////////////////////////////////////////////////////////// class TestData { public: Slobrok _slobrok; RetryTransientErrorsPolicy::SP _retryPolicy; TestServer _srcServer; SourceSession::UP _srcSession; Receptor _srcHandler; TestServer _dstServer; DestinationSession::UP _dstSession; Receptor _dstHandler; public: TestData(); ~TestData(); bool start(); }; class Test : public vespalib::TestApp { private: static Message::UP createMessage(const string &msg); public: int Main() override; void testSingleDirective(TestData &data); void testMoreDirectives(TestData &data); void testRecipientsRemain(TestData &data); void testConstRoute(TestData &data); }; TEST_APPHOOK(Test); TestData::TestData() : _slobrok(), _retryPolicy(std::make_shared()), _srcServer(MessageBusParams().setRetryPolicy(_retryPolicy).addProtocol(std::make_shared()), RPCNetworkParams(_slobrok.config())), _srcSession(), _srcHandler(), _dstServer(MessageBusParams().addProtocol(std::make_shared()), RPCNetworkParams(_slobrok.config()).setIdentity(Identity("dst"))), _dstSession(), _dstHandler() { _retryPolicy->setBaseDelay(0); } TestData::~TestData() = default; bool TestData::start() { _srcSession = _srcServer.mb.createSourceSession(SourceSessionParams().setReplyHandler(_srcHandler)); if ( ! _srcSession) { return false; } _dstSession = _dstServer.mb.createDestinationSession(DestinationSessionParams().setName("session").setMessageHandler(_dstHandler)); if ( ! _dstSession) { return false; } if (!_srcServer.waitSlobrok("dst/session", 1u)) { return false; } return true; } Message::UP Test::createMessage(const string &msg) { auto ret = std::make_unique(msg); ret->getTrace().setLevel(9); return ret; } int Test::Main() { TEST_INIT("routingcontext_test"); TestData data; ASSERT_TRUE(data.start()); testSingleDirective(data); TEST_FLUSH(); testMoreDirectives(data); TEST_FLUSH(); testRecipientsRemain(data); TEST_FLUSH(); testConstRoute(data); TEST_FLUSH(); TEST_DONE(); } //////////////////////////////////////////////////////////////////////////////// // // Tests // //////////////////////////////////////////////////////////////////////////////// void Test::testSingleDirective(TestData &data) { IProtocol::SP protocol(new SimpleProtocol()); auto &simple = dynamic_cast(*protocol); simple.addPolicyFactory("Custom", SimpleProtocol::IPolicyFactory::SP(new CustomPolicyFactory( false, StringList().add("foo").add("bar").add("baz/cox"), StringList().add("foo").add("bar")))); data._srcServer.mb.putProtocol(protocol); data._srcServer.mb.setupRouting(RoutingSpec().addTable(RoutingTableSpec(SimpleProtocol::NAME) .addRoute(RouteSpec("myroute").addHop("myhop")) .addHop(HopSpec("myhop", "[Custom]") .addRecipient("foo") .addRecipient("bar") .addRecipient("baz/cox")))); for (uint32_t i = 0; i < 2; ++i) { EXPECT_TRUE(data._srcSession->send(createMessage("msg"), "myroute").isAccepted()); Reply::UP reply = data._srcHandler.getReply(); ASSERT_TRUE(reply); printf("%s", reply->getTrace().toString().c_str()); EXPECT_TRUE(!reply->hasErrors()); } } void Test::testMoreDirectives(TestData &data) { IProtocol::SP protocol(new SimpleProtocol()); auto &simple = dynamic_cast(*protocol); simple.addPolicyFactory("Custom", SimpleProtocol::IPolicyFactory::SP(new CustomPolicyFactory( false, StringList().add("foo").add("foo/bar").add("foo/bar0/baz").add("foo/bar1/baz").add("foo/bar/baz/cox"), StringList().add("foo/bar0/baz").add("foo/bar1/baz")))); data._srcServer.mb.putProtocol(protocol); data._srcServer.mb.setupRouting(RoutingSpec().addTable(RoutingTableSpec(SimpleProtocol::NAME) .addRoute(RouteSpec("myroute").addHop("myhop")) .addHop(HopSpec("myhop", "foo/[Custom]/baz") .addRecipient("foo") .addRecipient("foo/bar") .addRecipient("foo/bar0/baz") .addRecipient("foo/bar1/baz") .addRecipient("foo/bar/baz/cox")))); for (uint32_t i = 0; i < 2; ++i) { EXPECT_TRUE(data._srcSession->send(createMessage("msg"), "myroute").isAccepted()); Reply::UP reply = data._srcHandler.getReply(); ASSERT_TRUE(reply); printf("%s", reply->getTrace().toString().c_str()); EXPECT_TRUE(!reply->hasErrors()); } } void Test::testRecipientsRemain(TestData &data) { auto protocol = std::make_shared(); auto &simple = dynamic_cast(*protocol); simple.addPolicyFactory("First", std::make_shared(true, StringList().add("foo/bar"), StringList().add("foo/[Second]"))); simple.addPolicyFactory("Second", std::make_shared(false, StringList().add("foo/bar"), StringList().add("foo/bar"))); data._srcServer.mb.putProtocol(protocol); data._srcServer.mb.setupRouting(RoutingSpec().addTable(RoutingTableSpec(SimpleProtocol::NAME) .addRoute(RouteSpec("myroute").addHop("myhop")) .addHop(HopSpec("myhop", "[First]/[Second]") .addRecipient("foo/bar")))); for (uint32_t i = 0; i < 2; ++i) { EXPECT_TRUE(data._srcSession->send(createMessage("msg"), "myroute").isAccepted()); Reply::UP reply = data._srcHandler.getReply(); ASSERT_TRUE(reply); printf("%s", reply->getTrace().toString().c_str()); EXPECT_TRUE(!reply->hasErrors()); } } void Test::testConstRoute(TestData &data) { auto protocol = std::make_shared(); auto &simple = dynamic_cast(*protocol); simple.addPolicyFactory("DocumentRouteSelector", std::make_shared(true, StringList().add("dst"), StringList().add("dst"))); data._srcServer.mb.putProtocol(protocol); data._srcServer.mb.setupRouting(RoutingSpec().addTable(RoutingTableSpec(SimpleProtocol::NAME) .addRoute(RouteSpec("default").addHop("indexing")) .addHop(HopSpec("indexing", "[DocumentRouteSelector]").addRecipient("dst")) .addHop(HopSpec("dst", "dst/session")))); for (uint32_t i = 0; i < 2; ++i) { EXPECT_TRUE(data._srcSession->send(createMessage("msg"), Route::parse("route:default")).isAccepted()); Message::UP msg = data._dstHandler.getMessage(TIMEOUT); ASSERT_TRUE(msg); data._dstSession->acknowledge(std::move(msg)); Reply::UP reply = data._srcHandler.getReply(); ASSERT_TRUE(reply); printf("%s", reply->getTrace().toString().c_str()); EXPECT_TRUE(!reply->hasErrors()); } }