aboutsummaryrefslogtreecommitdiffstats
path: root/documentapi/src/tests/policyfactory/policyfactory.cpp
blob: e28c27c3da146a2e4c03d964275f75ce059caba0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/documentapi/messagebus/iroutingpolicyfactory.h>
#include <vespa/document/repo/documenttyperepo.h>
#include <vespa/documentapi/messagebus/documentprotocol.h>
#include <vespa/documentapi/messagebus/messages/removedocumentmessage.h>
#include <vespa/messagebus/testlib/receptor.h>
#include <vespa/messagebus/testlib/slobrok.h>
#include <vespa/messagebus/testlib/testserver.h>
#include <vespa/vespalib/testkit/testapp.h>

using document::DocumentTypeRepo;
using namespace documentapi;

///////////////////////////////////////////////////////////////////////////////
//
// Utilities
//
///////////////////////////////////////////////////////////////////////////////

/**
 * Test routing policy whose select() always fails the routing context with
 * DocumentProtocol::ERROR_POLICY_FAILURE, using the policy parameter string as
 * the error message. merge() is never expected to be called by this test.
 */
class MyPolicy final : public mbus::IRoutingPolicy {
private:
    // Parameter handed to the factory; reported verbatim as the error message.
    const string _param;
public:
    explicit MyPolicy(const string &param);
    void select(mbus::RoutingContext &ctx) override;
    void merge(mbus::RoutingContext &ctx) override;
};

// Stores the policy parameter so select() can echo it back as the error text.
MyPolicy::MyPolicy(const string &param)
    : _param(param)
{
}

// Never selects a recipient: the context is immediately marked as failed with
// ERROR_POLICY_FAILURE, carrying the stored parameter as the message.
void
MyPolicy::select(mbus::RoutingContext &ctx)
{
    const auto errorCode = DocumentProtocol::ERROR_POLICY_FAILURE;
    ctx.setError(errorCode, _param);
}

void
MyPolicy::merge(mbus::RoutingContext &ctx)
{
    (void)ctx;
    ASSERT_TRUE(false);
}

/**
 * Routing policy factory that the test registers with DocumentProtocol under
 * the name "MyPolicy"; each call yields a fresh MyPolicy for the given parameter.
 */
struct MyFactory : IRoutingPolicyFactory {
    mbus::IRoutingPolicy::UP createPolicy(const string &param) const override;
};

// Creates a new MyPolicy owning a copy of the route parameter.
mbus::IRoutingPolicy::UP
MyFactory::createPolicy(const string &param) const
{
    mbus::IRoutingPolicy::UP policy = std::make_unique<MyPolicy>(param);
    return policy;
}

// Builds a RemoveDocumentMessage for a fixed document id, with trace level 9
// so the trace printed from each reply contains routing detail.
mbus::Message::UP
createMessage()
{
    document::DocumentId docId("id:ns:type::");
    auto msg = std::make_unique<RemoveDocumentMessage>(std::move(docId));
    msg->getTrace().setLevel(9);
    return msg;
}

///////////////////////////////////////////////////////////////////////////////
//
// Tests
//
///////////////////////////////////////////////////////////////////////////////

TEST_SETUP(Test);

// Maximum time to wait for each reply in Test::Main (presumably generous to
// avoid flakiness on slow hosts).
constexpr vespalib::duration TIMEOUT = 600s;

int
Test::Main()
{
    TEST_INIT("policyfactory_test");

    std::shared_ptr<const DocumentTypeRepo> repo(new DocumentTypeRepo);
    mbus::Slobrok slobrok;
    mbus::TestServer
        srv(mbus::MessageBusParams().addProtocol(std::make_shared<DocumentProtocol>(repo)),
            mbus::RPCNetworkParams(slobrok.config()));
    mbus::Receptor handler;
    mbus::SourceSession::UP src = srv.mb.createSourceSession(mbus::SourceSessionParams().setReplyHandler(handler));

    mbus::Route route = mbus::Route::parse("[MyPolicy]");
    ASSERT_TRUE(src->send(createMessage(), route).isAccepted());
    mbus::Reply::UP reply = static_cast<mbus::Receptor&>(src->getReplyHandler()).getReply(TIMEOUT);
    ASSERT_TRUE(reply);
    fprintf(stderr, "%s", reply->getTrace().toString().c_str());
    EXPECT_EQUAL(1u, reply->getNumErrors());
    EXPECT_EQUAL((uint32_t)mbus::ErrorCode::UNKNOWN_POLICY, reply->getError(0).getCode());

    mbus::IProtocol * obj = srv.mb.getProtocol(DocumentProtocol::NAME);
    DocumentProtocol * protocol = dynamic_cast<DocumentProtocol*>(obj);
    ASSERT_TRUE(protocol != nullptr);
    protocol->putRoutingPolicyFactory("MyPolicy", std::make_shared<MyFactory>());

    ASSERT_TRUE(src->send(createMessage(), route).isAccepted());
    reply = static_cast<mbus::Receptor&>(src->getReplyHandler()).getReply(TIMEOUT);
    ASSERT_TRUE(reply);
    fprintf(stderr, "%s", reply->getTrace().toString().c_str());
    EXPECT_EQUAL(1u, reply->getNumErrors());
    EXPECT_EQUAL((uint32_t)DocumentProtocol::ERROR_POLICY_FAILURE, reply->getError(0).getCode());

    TEST_DONE();
}