summaryrefslogtreecommitdiffstats
path: root/fnet
diff options
context:
space:
mode:
authorTor Brede Vekterli <vekterli@yahooinc.com>2022-06-22 15:44:57 +0000
committerTor Brede Vekterli <vekterli@yahooinc.com>2022-06-29 11:20:24 +0000
commitcc44b799f0d78a5e26f12ecb8b868301095570c4 (patch)
tree374f50996663fbdfa85d529202c0e7cccb99648d /fnet
parentcbe98d69506bf60f7fcf7681eb99a79589300882 (diff)
Support mTLS connection-level capabilities and RPC access filtering in C++
Adds the following: * Named capabilities and capability sets that represent (respectively) a single Vespa access API (such as Document API, search API etc) or a concrete subset of individual capabilities that make up a particular Vespa service (such as a content node). * A new `capabilities` array field to the mTLS authorization policies that allows for constraining what requests sent over a particular connection are allowed to actually do. Capabilities are referenced by name and may include any combination of capability sets and individual capabilities. If multiple capabilities/sets are configured, the resulting set of capabilities is the union set of all of them. * An FRT RPC-level access filter that can be set up as part of RPC method definitions. If set, filters are invoked prior to RPC methods. * A new `PERMISSION_DENIED` error code to FRT RPC that is invoked if an access filter denies a request. This also GCs the unused `AssumedRoles` concept which is now deprecated in favor of capabilities. Note: this is **not yet** a public or stable API, and capability names/semantics may change at any time.
Diffstat (limited to 'fnet')
-rw-r--r--fnet/src/tests/frt/rpc/invoke.cpp55
-rw-r--r--fnet/src/tests/info/info.cpp2
-rw-r--r--fnet/src/vespa/fnet/connection.cpp10
-rw-r--r--fnet/src/vespa/fnet/connection.h14
-rw-r--r--fnet/src/vespa/fnet/frt/CMakeLists.txt1
-rw-r--r--fnet/src/vespa/fnet/frt/error.cpp54
-rw-r--r--fnet/src/vespa/fnet/frt/error.h31
-rw-r--r--fnet/src/vespa/fnet/frt/invoker.cpp6
-rw-r--r--fnet/src/vespa/fnet/frt/reflection.cpp17
-rw-r--r--fnet/src/vespa/fnet/frt/reflection.h9
-rw-r--r--fnet/src/vespa/fnet/frt/request_access_filter.h24
-rw-r--r--fnet/src/vespa/fnet/frt/require_capability.cpp13
-rw-r--r--fnet/src/vespa/fnet/frt/require_capability.h21
13 files changed, 207 insertions, 50 deletions
diff --git a/fnet/src/tests/frt/rpc/invoke.cpp b/fnet/src/tests/frt/rpc/invoke.cpp
index 1fbd356b239..e1912985379 100644
--- a/fnet/src/tests/frt/rpc/invoke.cpp
+++ b/fnet/src/tests/frt/rpc/invoke.cpp
@@ -7,8 +7,10 @@
#include <vespa/fnet/frt/target.h>
#include <vespa/fnet/frt/rpcrequest.h>
#include <vespa/fnet/frt/invoker.h>
+#include <vespa/fnet/frt/request_access_filter.h>
#include <mutex>
#include <condition_variable>
+#include <string_view>
using vespalib::SocketSpec;
using vespalib::BenchmarkTimer;
@@ -175,11 +177,25 @@ public:
//-------------------------------------------------------------
+struct MyAccessFilter : FRT_RequestAccessFilter {
+ ~MyAccessFilter() override = default;
+
+ constexpr static std::string_view WRONG_KEY = "...mellon!";
+ constexpr static std::string_view CORRECT_KEY = "let me in, I have cake";
+
+ bool allow(FRT_RPCRequest& req) const noexcept override {
+ const auto& req_param = req.GetParams()->GetValue(0)._string;
+ const auto magic_key = std::string_view(req_param._str, req_param._len);
+ return (magic_key == CORRECT_KEY);
+ }
+};
+
class TestRPC : public FRT_Invokable
{
private:
- uint32_t _intValue;
- RequestLatch _detached_req;
+ uint32_t _intValue;
+ RequestLatch _detached_req;
+ std::atomic<bool> _restricted_method_was_invoked;
TestRPC(const TestRPC &);
TestRPC &operator=(const TestRPC &);
@@ -187,7 +203,8 @@ private:
public:
TestRPC(FRT_Supervisor *supervisor)
: _intValue(0),
- _detached_req()
+ _detached_req(),
+ _restricted_method_was_invoked(false)
{
FRT_ReflectionBuilder rb(supervisor);
@@ -201,6 +218,9 @@ public:
FRT_METHOD(TestRPC::RPC_GetValue), this);
rb.DefineMethod("test", "iibb", "i",
FRT_METHOD(TestRPC::RPC_Test), this);
+ rb.DefineMethod("accessRestricted", "s", "",
+ FRT_METHOD(TestRPC::RPC_AccessRestricted), this);
+ rb.RequestAccessFilter(std::make_unique<MyAccessFilter>());
}
void RPC_Test(FRT_RPCRequest *req)
@@ -244,6 +264,16 @@ public:
req->GetReturn()->AddInt32(_intValue);
}
+ void RPC_AccessRestricted([[maybe_unused]] FRT_RPCRequest *req)
+ {
+ // We'll only get here if the access filter lets us in
+ _restricted_method_was_invoked.store(true);
+ }
+
+ bool restricted_method_was_invoked() const noexcept {
+ return _restricted_method_was_invoked.load();
+ }
+
RequestLatch &detached_req() { return _detached_req; }
};
@@ -264,6 +294,7 @@ public:
FRT_Target *make_bad_target() { return _client.supervisor().GetTarget("bogus address"); }
RequestLatch &detached_req() { return _testRPC.detached_req(); }
EchoTest &echo() { return _echoTest; }
+ const TestRPC& server_instance() const noexcept { return _testRPC; }
Fixture()
: _client(crypto),
@@ -421,6 +452,24 @@ TEST_F("require that parameters can be echoed as return values", Fixture()) {
EXPECT_TRUE(req.get().GetParams()->Equals(req.get().GetReturn()));
}
+TEST_F("request denied by access filter returns PERMISSION_DENIED and does not invoke server method", Fixture()) {
+ MyReq req("accessRestricted");
+ auto key = MyAccessFilter::WRONG_KEY;
+ req.get().GetParams()->AddString(key.data(), key.size());
+ f1.target().InvokeSync(req.borrow(), timeout);
+ EXPECT_EQUAL(req.get().GetErrorCode(), FRTE_RPC_PERMISSION_DENIED);
+ EXPECT_FALSE(f1.server_instance().restricted_method_was_invoked());
+}
+
+TEST_F("request allowed by access filter invokes server method as usual", Fixture()) {
+ MyReq req("accessRestricted");
+ auto key = MyAccessFilter::CORRECT_KEY;
+ req.get().GetParams()->AddString(key.data(), key.size());
+ f1.target().InvokeSync(req.borrow(), timeout);
+ ASSERT_FALSE(req.get().IsError());
+ EXPECT_TRUE(f1.server_instance().restricted_method_was_invoked());
+}
+
TEST_MAIN() {
crypto = my_crypto_engine();
TEST_RUN_ALL();
diff --git a/fnet/src/tests/info/info.cpp b/fnet/src/tests/info/info.cpp
index 0d4e0f90a09..4271546e647 100644
--- a/fnet/src/tests/info/info.cpp
+++ b/fnet/src/tests/info/info.cpp
@@ -80,7 +80,7 @@ TEST("size of important objects")
EXPECT_EQUAL(MUTEX_SIZE + sizeof(std::string) + 112u, sizeof(FNET_IOComponent));
EXPECT_EQUAL(32u, sizeof(FNET_Channel));
EXPECT_EQUAL(40u, sizeof(FNET_PacketQueue_NoLock));
- EXPECT_EQUAL(MUTEX_SIZE + sizeof(std::string) + 408u, sizeof(FNET_Connection));
+ EXPECT_EQUAL(MUTEX_SIZE + sizeof(std::string) + 416u, sizeof(FNET_Connection));
EXPECT_EQUAL(48u, sizeof(std::condition_variable));
EXPECT_EQUAL(56u, sizeof(FNET_DataBuffer));
EXPECT_EQUAL(8u, sizeof(FNET_Context));
diff --git a/fnet/src/vespa/fnet/connection.cpp b/fnet/src/vespa/fnet/connection.cpp
index 2677445e35d..26367c904b2 100644
--- a/fnet/src/vespa/fnet/connection.cpp
+++ b/fnet/src/vespa/fnet/connection.cpp
@@ -9,6 +9,7 @@
#include "config.h"
#include "transport_thread.h"
#include "transport.h"
+#include <vespa/vespalib/net/connection_auth_context.h>
#include <vespa/vespalib/net/socket_spec.h>
#include <vespa/log/log.h>
@@ -241,6 +242,8 @@ FNET_Connection::handshake()
break;
case vespalib::CryptoSocket::HandshakeResult::DONE: {
LOG(debug, "Connection(%s): handshake done with peer %s", GetSpec(), GetPeerSpec().c_str());
+ _auth_context = _socket->make_auth_context();
+ assert(_auth_context);
EnableReadEvent(true);
EnableWriteEvent(writePendingAfterConnect());
_flags._framed = (_socket->min_read_buffer_size() > 1);
@@ -764,3 +767,10 @@ FNET_Connection::GetPeerSpec() const
{
return vespalib::SocketAddress::peer_address(_socket->get_fd()).spec();
}
+
+const vespalib::net::ConnectionAuthContext&
+FNET_Connection::auth_context() const noexcept
+{
+ assert(_auth_context);
+ return *_auth_context;
+}
diff --git a/fnet/src/vespa/fnet/connection.h b/fnet/src/vespa/fnet/connection.h
index 15150ffbb07..4d66f22ce2b 100644
--- a/fnet/src/vespa/fnet/connection.h
+++ b/fnet/src/vespa/fnet/connection.h
@@ -18,6 +18,8 @@ class FNET_IPacketStreamer;
class FNET_IServerAdapter;
class FNET_IPacketHandler;
+namespace vespalib::net { class ConnectionAuthContext; }
+
/**
* Interface implemented by objects that want to perform connection
* cleanup. Use the SetCleanupHandler method to register with a
@@ -96,7 +98,7 @@ private:
using ResolveHandlerSP = std::shared_ptr<ResolveHandler>;
FNET_IPacketStreamer *_streamer; // custom packet streamer
FNET_IServerAdapter *_serverAdapter; // only on server side
- vespalib::CryptoSocket::UP _socket; // socket for this conn
+ vespalib::CryptoSocket::UP _socket; // socket for this conn
ResolveHandlerSP _resolve_handler; // async resolve callback
FNET_Context _context; // connection context
std::atomic<State> _state; // connection state. May be polled outside lock
@@ -115,6 +117,8 @@ private:
FNET_IConnectionCleanupHandler *_cleanup; // cleanup handler
+ std::unique_ptr<vespalib::net::ConnectionAuthContext> _auth_context;
+
static std::atomic<uint64_t> _num_connections; // total number of connections
@@ -277,7 +281,7 @@ public:
/**
* Destructor.
**/
- ~FNET_Connection();
+ ~FNET_Connection() override;
/**
@@ -504,6 +508,12 @@ public:
uint32_t getInputBufferSize() const { return _input.GetBufSize(); }
/**
+ * Returns the connection's auth context. Must only be called _after_ the
+ * handshake phase has completed.
+ */
+ const vespalib::net::ConnectionAuthContext& auth_context() const noexcept;
+
+ /**
* @return the total number of connection objects
**/
static uint64_t get_num_connections() {
diff --git a/fnet/src/vespa/fnet/frt/CMakeLists.txt b/fnet/src/vespa/fnet/frt/CMakeLists.txt
index c7bcbe27041..fa9623b950a 100644
--- a/fnet/src/vespa/fnet/frt/CMakeLists.txt
+++ b/fnet/src/vespa/fnet/frt/CMakeLists.txt
@@ -5,6 +5,7 @@ vespa_add_library(fnet_frt OBJECT
invoker.cpp
packets.cpp
reflection.cpp
+ require_capability.cpp
rpcrequest.cpp
supervisor.cpp
target.cpp
diff --git a/fnet/src/vespa/fnet/frt/error.cpp b/fnet/src/vespa/fnet/frt/error.cpp
index 6af9ea39757..fb91924bf35 100644
--- a/fnet/src/vespa/fnet/frt/error.cpp
+++ b/fnet/src/vespa/fnet/frt/error.cpp
@@ -12,19 +12,20 @@ FRT_GetErrorCodeName(uint32_t errorCode)
errorCode <= FRTE_RPC_LAST)
{
switch (errorCode) {
- case FRTE_RPC_GENERAL_ERROR: return "FRTE_RPC_GENERAL_ERROR";
- case FRTE_RPC_NOT_IMPLEMENTED: return "FRTE_RPC_NOT_IMPLEMENTED";
- case FRTE_RPC_ABORT: return "FRTE_RPC_ABORT";
- case FRTE_RPC_TIMEOUT: return "FRTE_RPC_TIMEOUT";
- case FRTE_RPC_CONNECTION: return "FRTE_RPC_CONNECTION";
- case FRTE_RPC_BAD_REQUEST: return "FRTE_RPC_BAD_REQUEST";
- case FRTE_RPC_NO_SUCH_METHOD: return "FRTE_RPC_NO_SUCH_METHOD";
- case FRTE_RPC_WRONG_PARAMS: return "FRTE_RPC_WRONG_PARAMS";
- case FRTE_RPC_OVERLOAD: return "FRTE_RPC_OVERLOAD";
- case FRTE_RPC_WRONG_RETURN: return "FRTE_RPC_WRONG_RETURN";
- case FRTE_RPC_BAD_REPLY: return "FRTE_RPC_BAD_REPLY";
- case FRTE_RPC_METHOD_FAILED: return "FRTE_RPC_METHOD_FAILED";
- default: return "[UNKNOWN RPC ERROR]";
+ case FRTE_RPC_GENERAL_ERROR: return "FRTE_RPC_GENERAL_ERROR";
+ case FRTE_RPC_NOT_IMPLEMENTED: return "FRTE_RPC_NOT_IMPLEMENTED";
+ case FRTE_RPC_ABORT: return "FRTE_RPC_ABORT";
+ case FRTE_RPC_TIMEOUT: return "FRTE_RPC_TIMEOUT";
+ case FRTE_RPC_CONNECTION: return "FRTE_RPC_CONNECTION";
+ case FRTE_RPC_BAD_REQUEST: return "FRTE_RPC_BAD_REQUEST";
+ case FRTE_RPC_NO_SUCH_METHOD: return "FRTE_RPC_NO_SUCH_METHOD";
+ case FRTE_RPC_WRONG_PARAMS: return "FRTE_RPC_WRONG_PARAMS";
+ case FRTE_RPC_OVERLOAD: return "FRTE_RPC_OVERLOAD";
+ case FRTE_RPC_WRONG_RETURN: return "FRTE_RPC_WRONG_RETURN";
+ case FRTE_RPC_BAD_REPLY: return "FRTE_RPC_BAD_REPLY";
+ case FRTE_RPC_METHOD_FAILED: return "FRTE_RPC_METHOD_FAILED";
+ case FRTE_RPC_PERMISSION_DENIED: return "FRTE_RPC_PERMISSION_DENIED";
+ default: return "[UNKNOWN RPC ERROR]";
}
}
return "[UNKNOWN ERROR]";
@@ -41,19 +42,20 @@ FRT_GetDefaultErrorMessage(uint32_t errorCode)
errorCode <= FRTE_RPC_LAST)
{
switch (errorCode) {
- case FRTE_RPC_GENERAL_ERROR: return "(RPC) General error";
- case FRTE_RPC_NOT_IMPLEMENTED: return "(RPC) Not implemented";
- case FRTE_RPC_ABORT: return "(RPC) Invocation aborted";
- case FRTE_RPC_TIMEOUT: return "(RPC) Invocation timed out";
- case FRTE_RPC_CONNECTION: return "(RPC) Connection error";
- case FRTE_RPC_BAD_REQUEST: return "(RPC) Bad request packet";
- case FRTE_RPC_NO_SUCH_METHOD: return "(RPC) No such method";
- case FRTE_RPC_WRONG_PARAMS: return "(RPC) Illegal parameters";
- case FRTE_RPC_OVERLOAD: return "(RPC) Request dropped due to server overload";
- case FRTE_RPC_WRONG_RETURN: return "(RPC) Illegal return values";
- case FRTE_RPC_BAD_REPLY: return "(RPC) Bad reply packet";
- case FRTE_RPC_METHOD_FAILED: return "(RPC) Method failed";
- default: return "[UNKNOWN RPC ERROR]";
+ case FRTE_RPC_GENERAL_ERROR: return "(RPC) General error";
+ case FRTE_RPC_NOT_IMPLEMENTED: return "(RPC) Not implemented";
+ case FRTE_RPC_ABORT: return "(RPC) Invocation aborted";
+ case FRTE_RPC_TIMEOUT: return "(RPC) Invocation timed out";
+ case FRTE_RPC_CONNECTION: return "(RPC) Connection error";
+ case FRTE_RPC_BAD_REQUEST: return "(RPC) Bad request packet";
+ case FRTE_RPC_NO_SUCH_METHOD: return "(RPC) No such method";
+ case FRTE_RPC_WRONG_PARAMS: return "(RPC) Illegal parameters";
+ case FRTE_RPC_OVERLOAD: return "(RPC) Request dropped due to server overload";
+ case FRTE_RPC_WRONG_RETURN: return "(RPC) Illegal return values";
+ case FRTE_RPC_BAD_REPLY: return "(RPC) Bad reply packet";
+ case FRTE_RPC_METHOD_FAILED: return "(RPC) Method failed";
+ case FRTE_RPC_PERMISSION_DENIED: return "(RPC) Permission denied";
+ default: return "[UNKNOWN RPC ERROR]";
}
}
return "[UNKNOWN ERROR]";
diff --git a/fnet/src/vespa/fnet/frt/error.h b/fnet/src/vespa/fnet/frt/error.h
index c5acfb744f6..7b3cdc7320b 100644
--- a/fnet/src/vespa/fnet/frt/error.h
+++ b/fnet/src/vespa/fnet/frt/error.h
@@ -4,21 +4,22 @@
#include <cstdint>
enum {
- FRTE_NO_ERROR = 0,
- FRTE_RPC_FIRST = 100,
- FRTE_RPC_GENERAL_ERROR = 100,
- FRTE_RPC_NOT_IMPLEMENTED = 101,
- FRTE_RPC_ABORT = 102,
- FRTE_RPC_TIMEOUT = 103,
- FRTE_RPC_CONNECTION = 104,
- FRTE_RPC_BAD_REQUEST = 105,
- FRTE_RPC_NO_SUCH_METHOD = 106,
- FRTE_RPC_WRONG_PARAMS = 107,
- FRTE_RPC_OVERLOAD = 108,
- FRTE_RPC_WRONG_RETURN = 109,
- FRTE_RPC_BAD_REPLY = 110,
- FRTE_RPC_METHOD_FAILED = 111,
- FRTE_RPC_LAST = 199
+ FRTE_NO_ERROR = 0,
+ FRTE_RPC_FIRST = 100,
+ FRTE_RPC_GENERAL_ERROR = 100,
+ FRTE_RPC_NOT_IMPLEMENTED = 101,
+ FRTE_RPC_ABORT = 102,
+ FRTE_RPC_TIMEOUT = 103,
+ FRTE_RPC_CONNECTION = 104,
+ FRTE_RPC_BAD_REQUEST = 105,
+ FRTE_RPC_NO_SUCH_METHOD = 106,
+ FRTE_RPC_WRONG_PARAMS = 107,
+ FRTE_RPC_OVERLOAD = 108,
+ FRTE_RPC_WRONG_RETURN = 109,
+ FRTE_RPC_BAD_REPLY = 110,
+ FRTE_RPC_METHOD_FAILED = 111,
+ FRTE_RPC_PERMISSION_DENIED = 112,
+ FRTE_RPC_LAST = 199
};
const char *FRT_GetErrorCodeName(uint32_t errorCode);
diff --git a/fnet/src/vespa/fnet/frt/invoker.cpp b/fnet/src/vespa/fnet/frt/invoker.cpp
index 85eae6cb41a..f75526d51f1 100644
--- a/fnet/src/vespa/fnet/frt/invoker.cpp
+++ b/fnet/src/vespa/fnet/frt/invoker.cpp
@@ -52,6 +52,7 @@ FRT_RPCInvoker::FRT_RPCInvoker(FRT_Supervisor *supervisor,
std::string methodName(_req->GetMethodName(), _req->GetMethodNameLen());
LOG(debug, "invoke(server) init: '%s'", methodName.c_str());
}
+ req->SetReturnHandler(this); // Must be set prior to any access filter being invoked
if (_method == nullptr) {
if (!req->IsError()) { // may be BAD_REQUEST
req->SetError(FRTE_RPC_NO_SUCH_METHOD);
@@ -60,8 +61,11 @@ FRT_RPCInvoker::FRT_RPCInvoker(FRT_Supervisor *supervisor,
req->GetParamSpec()))
{
req->SetError(FRTE_RPC_WRONG_PARAMS);
+ } else if (_method->GetRequestAccessFilter() &&
+ !_method->GetRequestAccessFilter()->allow(*req))
+ {
+ req->SetError(FRTE_RPC_PERMISSION_DENIED);
}
- req->SetReturnHandler(this);
}
bool FRT_RPCInvoker::Invoke()
diff --git a/fnet/src/vespa/fnet/frt/reflection.cpp b/fnet/src/vespa/fnet/frt/reflection.cpp
index 211e681df94..af7fa069eb9 100644
--- a/fnet/src/vespa/fnet/frt/reflection.cpp
+++ b/fnet/src/vespa/fnet/frt/reflection.cpp
@@ -14,7 +14,8 @@ FRT_Method::FRT_Method(const char * name, const char * paramSpec, const char * r
_returnSpec(returnSpec),
_method(method),
_handler(handler),
- _doc()
+ _doc(),
+ _access_filter()
{
}
@@ -124,6 +125,7 @@ FRT_ReflectionBuilder::Flush()
}
_method->SetDocumentation(_values);
+ _method->SetRequestAccessFilter(std::move(_access_filter)); // May be nullptr
_method = nullptr;
_req->Reset();
}
@@ -142,7 +144,8 @@ FRT_ReflectionBuilder::FRT_ReflectionBuilder(FRT_Supervisor *supervisor)
_arg_name(nullptr),
_arg_desc(nullptr),
_ret_name(nullptr),
- _ret_desc(nullptr)
+ _ret_desc(nullptr),
+ _access_filter()
{
}
@@ -183,6 +186,7 @@ FRT_ReflectionBuilder::DefineMethod(const char *name,
_arg_desc = _values->AddStringArray(_argCnt);
_ret_name = _values->AddStringArray(_retCnt);
_ret_desc = _values->AddStringArray(_retCnt);
+ _access_filter.reset();
}
@@ -224,3 +228,12 @@ FRT_ReflectionBuilder::ReturnDesc(const char *name, const char *desc)
_values->SetString(&_ret_desc[_curRet], desc);
_curRet++;
}
+
+void
+FRT_ReflectionBuilder::RequestAccessFilter(std::unique_ptr<FRT_RequestAccessFilter> access_filter)
+{
+ if (_method == nullptr) {
+ return;
+ }
+ _access_filter = std::move(access_filter);
+}
diff --git a/fnet/src/vespa/fnet/frt/reflection.h b/fnet/src/vespa/fnet/frt/reflection.h
index 6267cafeeb1..3f833d053f1 100644
--- a/fnet/src/vespa/fnet/frt/reflection.h
+++ b/fnet/src/vespa/fnet/frt/reflection.h
@@ -3,6 +3,8 @@
#pragma once
#include "invokable.h"
+#include "request_access_filter.h"
+#include <memory>
#include <string>
#include <vector>
@@ -23,6 +25,7 @@ private:
FRT_METHOD_PT _method; // method pointer
FRT_Invokable *_handler; // method handler
std::vector<char> _doc; // method documentation
+ std::unique_ptr<FRT_RequestAccessFilter> _access_filter; // (optional) access filter
public:
FRT_Method(const FRT_Method &) = delete;
@@ -41,6 +44,10 @@ public:
const char *GetReturnSpec() { return _returnSpec.c_str(); }
FRT_METHOD_PT GetMethod() { return _method; }
FRT_Invokable *GetHandler() { return _handler; }
+ const FRT_RequestAccessFilter* GetRequestAccessFilter() const noexcept { return _access_filter.get(); }
+ void SetRequestAccessFilter(std::unique_ptr<FRT_RequestAccessFilter> access_filter) noexcept {
+ _access_filter = std::move(access_filter);
+ }
void SetDocumentation(FRT_Values *values);
void GetDocumentation(FRT_Values *values);
};
@@ -104,6 +111,7 @@ private:
FRT_StringValue *_arg_desc;
FRT_StringValue *_ret_name;
FRT_StringValue *_ret_desc;
+ std::unique_ptr<FRT_RequestAccessFilter> _access_filter;
FRT_ReflectionBuilder(const FRT_ReflectionBuilder &);
FRT_ReflectionBuilder &operator=(const FRT_ReflectionBuilder &);
@@ -122,5 +130,6 @@ public:
void MethodDesc(const char *desc);
void ParamDesc(const char *name, const char *desc);
void ReturnDesc(const char *name, const char *desc);
+ void RequestAccessFilter(std::unique_ptr<FRT_RequestAccessFilter> access_filter);
};
diff --git a/fnet/src/vespa/fnet/frt/request_access_filter.h b/fnet/src/vespa/fnet/frt/request_access_filter.h
new file mode 100644
index 00000000000..a02dca646f3
--- /dev/null
+++ b/fnet/src/vespa/fnet/frt/request_access_filter.h
@@ -0,0 +1,24 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+class FRT_RPCRequest;
+
+/**
+ * An RPC request access filter will, if provided during method registration, be
+ * invoked _prior_ to any RPC handler callback invocation for that method. It allows
+ * for implementing method-specific authorization handling, logging etc.
+ *
+ * Must be thread safe.
+ */
+class FRT_RequestAccessFilter {
+public:
+ virtual ~FRT_RequestAccessFilter() = default;
+
+ /**
+ * Iff true is returned, the request is allowed through and the RPC callback
+ * will be invoked as usual. If false, the request is immediately failed back
+ * to the caller with an error code.
+ */
+ [[nodiscard]] virtual bool allow(FRT_RPCRequest&) const noexcept = 0;
+};
diff --git a/fnet/src/vespa/fnet/frt/require_capability.cpp b/fnet/src/vespa/fnet/frt/require_capability.cpp
new file mode 100644
index 00000000000..5c64c2bb123
--- /dev/null
+++ b/fnet/src/vespa/fnet/frt/require_capability.cpp
@@ -0,0 +1,13 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "require_capability.h"
+#include "rpcrequest.h"
+#include <vespa/fnet/connection.h>
+#include <vespa/vespalib/net/connection_auth_context.h>
+
+bool
+FRT_RequireCapability::allow(FRT_RPCRequest& req) const noexcept
+{
+ const auto& auth_ctx = req.GetConnection()->auth_context();
+ return auth_ctx.capabilities().contains_all(_required_capabilities);
+}
diff --git a/fnet/src/vespa/fnet/frt/require_capability.h b/fnet/src/vespa/fnet/frt/require_capability.h
new file mode 100644
index 00000000000..c9eaf4937a8
--- /dev/null
+++ b/fnet/src/vespa/fnet/frt/require_capability.h
@@ -0,0 +1,21 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#pragma once
+
+#include "request_access_filter.h"
+#include <vespa/vespalib/net/tls/capability_set.h>
+
+/**
+ * An RPC access filter which verifies that a request is associated with an auth
+ * context that contains, at minimum, a given set of capabilities. If one or more
+ * required capabilities are missing, the request is denied.
+ */
+class FRT_RequireCapability final : public FRT_RequestAccessFilter {
+ vespalib::net::tls::CapabilitySet _required_capabilities;
+public:
+ explicit constexpr FRT_RequireCapability(vespalib::net::tls::CapabilitySet required_capabilities) noexcept
+ : _required_capabilities(required_capabilities)
+ {
+ }
+
+ bool allow(FRT_RPCRequest& req) const noexcept override;
+};