aboutsummaryrefslogtreecommitdiffstats
path: root/searchcore/src/vespa/searchcore/bmcluster/storage_api_rpc_bm_feed_handler.cpp
blob: ee4d980546b884c84cf2100828bded6b303dfc7e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "storage_api_rpc_bm_feed_handler.h"
#include "i_bm_distribution.h"
#include "pending_tracker.h"
#include "pending_tracker_hash.h"
#include "storage_reply_error_checker.h"
#include <vespa/storageapi/messageapi/storagecommand.h>
#include <vespa/storage/storageserver/message_dispatcher.h>
#include <vespa/storage/storageserver/rpc/message_codec_provider.h>
#include <vespa/storage/storageserver/rpc/shared_rpc_resources.h>
#include <vespa/storage/storageserver/rpc/storage_api_rpc_service.h>

using document::DocumentTypeRepo;
using storage::rpc::SharedRpcResources;
using storage::rpc::StorageApiRpcService;

namespace search::bmcluster {

/*
 * Message dispatcher handed to the RPC service. Every message it receives is
 * checked for errors (via StorageReplyErrorChecker) and then used to release
 * the PendingTracker that was registered for the message id in retain().
 */
class StorageApiRpcBmFeedHandler::MyMessageDispatcher : public storage::MessageDispatcher,
                                 public StorageReplyErrorChecker
{
    // In-flight message id -> tracker registered by retain().
    PendingTrackerHash _pending_hash;

    // Common handling for both dispatch paths: record any error carried by the
    // message, then release the tracker registered for its message id.
    void handle_incoming(const std::shared_ptr<storage::api::StorageMessage>& incoming) {
        check_error(*incoming);
        release(incoming->getMsgId());
    }
public:
    // Bases and members are explicitly value-initialized here; do not replace with "= default".
    MyMessageDispatcher()
        : storage::MessageDispatcher(),
          StorageReplyErrorChecker(),
          _pending_hash()
    {
    }
    ~MyMessageDispatcher() override;
    void dispatch_sync(std::shared_ptr<storage::api::StorageMessage> msg) override {
        handle_incoming(msg);
    }
    void dispatch_async(std::shared_ptr<storage::api::StorageMessage> msg) override {
        handle_incoming(msg);
    }
    // Associates msg_id with tracker so a later release(msg_id) can find it.
    void retain(uint64_t msg_id, PendingTracker &tracker) { _pending_hash.retain(msg_id, tracker); }
    // Releases the tracker registered for msg_id; an unknown id is counted as an error.
    void release(uint64_t msg_id) {
        auto waiter = _pending_hash.release(msg_id);
        if (waiter == nullptr) {
            ++_errors;
            return;
        }
        waiter->release();
    }
};

StorageApiRpcBmFeedHandler::MyMessageDispatcher::~MyMessageDispatcher() = default;

/*
 * Creates a feed handler that sends storage API commands over RPC.
 *
 * shared_rpc_resources_in: RPC resources shared with the rest of the process; stored by reference.
 * repo:                    document type repo, handed over to the message codec provider.
 * rpc_params:              parameters forwarded to the StorageApiRpcService client.
 * distribution:            used by the base class and to size the address table (get_num_nodes()).
 * distributor:             forwarded to both the base class and the address table.
 *
 * Note: _rpc_client is built from *_message_dispatcher, _shared_rpc_resources and
 * *_message_codec_provider, so it relies on those members being declared (and thus
 * initialized) before it.
 */
StorageApiRpcBmFeedHandler::StorageApiRpcBmFeedHandler(SharedRpcResources& shared_rpc_resources_in,
                                                       std::shared_ptr<const DocumentTypeRepo> repo,
                                                       const StorageApiRpcService::Params& rpc_params,
                                                       const IBmDistribution& distribution,
                                                       bool distributor)
    : StorageApiBmFeedHandlerBase("StorageApiRpcBmFeedHandler", distribution, distributor),
      _addresses(distribution.get_num_nodes(), distributor),
      _no_address_error_count(0u),
      _shared_rpc_resources(shared_rpc_resources_in),
      _message_dispatcher(std::make_unique<MyMessageDispatcher>()),
      // repo is a by-value sink that is not used again: move it rather than copy the shared_ptr.
      _message_codec_provider(std::make_unique<storage::rpc::MessageCodecProvider>(std::move(repo))),
      _rpc_client(std::make_unique<storage::rpc::StorageApiRpcService>(*_message_dispatcher,
                                                                         _shared_rpc_resources,
                                                                         *_message_codec_provider,
                                                                         rpc_params))
{
}

// Defined in this translation unit, where the nested MyMessageDispatcher type is complete.
StorageApiRpcBmFeedHandler::~StorageApiRpcBmFeedHandler()
{
}

/*
 * Routes cmd to a node and sends it over RPC. The tracker is registered under the
 * command's message id and is released when the dispatcher receives a message with
 * that id. If the routed node has no known address, the command is dropped and
 * counted in _no_address_error_count (the tracker is not retained in that case).
 */
void
StorageApiRpcBmFeedHandler::send_cmd(std::shared_ptr<storage::api::StorageCommand> cmd, PendingTracker& tracker)
{
    const uint32_t target_node = route_cmd(*cmd);
    if (!_addresses.has_address(target_node)) {
        ++_no_address_error_count;
        return;
    }
    cmd->setAddress(_addresses.get_address(target_node));
    // Retain before sending: presumably a response may be dispatched before
    // send_rpc_v1_request returns, and release() must then find the tracker.
    _message_dispatcher->retain(cmd->getMsgId(), tracker);
    _rpc_client->send_rpc_v1_request(std::move(cmd));
}

// Intentionally empty: this RPC-based handler attaches nothing to the tracker.
void
StorageApiRpcBmFeedHandler::attach_bucket_info_queue(PendingTracker& /* tracker */)
{
    // no-op
}

/*
 * Total error count: errors recorded by the message dispatcher (error checks and
 * releases of unknown message ids) plus commands dropped because their target
 * node had no address.
 */
uint32_t
StorageApiRpcBmFeedHandler::get_error_count() const
{
    const auto dispatcher_errors = _message_dispatcher->get_error_count();
    return dispatcher_errors + _no_address_error_count;
}

}