diff options
author | HÃ¥vard Pettersen <3535158+havardpe@users.noreply.github.com> | 2023-01-13 16:12:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-13 16:12:08 +0100 |
commit | 316cdcbf3897fcf00c653eb7f6ea6ef5c39e9642 (patch) | |
tree | e817a5aaf089d6fc04891ac81630b2b65872a757 /vespalib | |
parent | 6ea555de57ad11bae44e7f9abca4d2b06d3863ae (diff) | |
parent | b0280f472f6074ed34bd1da4fc1c6b6aa517a765 (diff) |
Merge pull request #25516 from vespa-engine/havardpe/async-crypto-socket
async crypto socket proof of concept
Diffstat (limited to 'vespalib')
4 files changed, 354 insertions, 25 deletions
diff --git a/vespalib/src/tests/coro/async_io/async_io_test.cpp b/vespalib/src/tests/coro/async_io/async_io_test.cpp index a506a5dd0d4..f5098e30086 100644 --- a/vespalib/src/tests/coro/async_io/async_io_test.cpp +++ b/vespalib/src/tests/coro/async_io/async_io_test.cpp @@ -4,14 +4,22 @@ #include <vespa/vespalib/coro/detached.h> #include <vespa/vespalib/coro/completion.h> #include <vespa/vespalib/coro/async_io.h> +#include <vespa/vespalib/coro/async_crypto_socket.h> #include <vespa/vespalib/net/socket_spec.h> #include <vespa/vespalib/net/server_socket.h> #include <vespa/vespalib/net/socket_handle.h> #include <vespa/vespalib/net/socket_address.h> +#include <vespa/vespalib/net/crypto_engine.h> +#include <vespa/vespalib/util/require.h> +#include <vespa/vespalib/util/classname.h> +#include <vespa/vespalib/test/make_tls_options_for_testing.h> +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> +#include <vespa/vespalib/net/tls/maybe_tls_crypto_engine.h> #include <vespa/vespalib/gtest/gtest.h> using namespace vespalib; using namespace vespalib::coro; +using namespace vespalib::test; Detached self_exiting_run_loop(AsyncIo::SP async) { for (size_t i = 0; co_await async->schedule(); ++i) { @@ -53,11 +61,11 @@ TEST(AsyncIoTest, shutdown_with_self_exiting_coroutine) { f2.wait(); } -Lazy<size_t> write_msg(AsyncIo &async, SocketHandle &socket, const vespalib::string &msg) { +Lazy<size_t> write_msg(AsyncCryptoSocket &socket, const vespalib::string &msg) { size_t written = 0; while (written < msg.size()) { size_t write_size = (msg.size() - written); - ssize_t write_result = co_await async.write(socket, msg.data() + written, write_size); + ssize_t write_result = co_await socket.write(msg.data() + written, write_size); if (write_result <= 0) { co_return written; } @@ -66,12 +74,12 @@ Lazy<size_t> write_msg(AsyncIo &async, SocketHandle &socket, const vespalib::str co_return written; } -Lazy<vespalib::string> read_msg(AsyncIo &async, SocketHandle &socket, size_t wanted_bytes) { +Lazy<vespalib::string> read_msg(AsyncCryptoSocket &socket, size_t wanted_bytes) { char tmp[64]; vespalib::string result; while (result.size() < wanted_bytes) { size_t read_size = std::min(sizeof(tmp), wanted_bytes - result.size()); - ssize_t read_result = co_await async.read(socket, tmp, read_size); + ssize_t read_result = co_await socket.read(tmp, read_size); if (read_result <= 0) { co_return result; } @@ -80,50 +88,78 @@ Lazy<vespalib::string> read_msg(AsyncIo &async, SocketHandle &socket, size_t wan co_return result; } -Work verify_socket_io(AsyncIo &async, SocketHandle &socket, bool is_server) { +Work verify_socket_io(AsyncCryptoSocket &socket, bool is_server) { vespalib::string server_message = "hello, this is the server speaking"; vespalib::string client_message = "please pick up, I need to talk to you"; if (is_server) { - vespalib::string read = co_await read_msg(async, socket, client_message.size()); + vespalib::string read = co_await read_msg(socket, client_message.size()); EXPECT_EQ(client_message, read); - size_t written = co_await write_msg(async, socket, server_message); + size_t written = co_await write_msg(socket, server_message); EXPECT_EQ(written, ssize_t(server_message.size())); } else { - size_t written = co_await write_msg(async, socket, client_message); + size_t written = co_await write_msg(socket, client_message); EXPECT_EQ(written, ssize_t(client_message.size())); - vespalib::string read = co_await read_msg(async, socket, server_message.size()); + vespalib::string read = co_await read_msg(socket, server_message.size()); EXPECT_EQ(server_message, read); } co_return Done{}; } -Work async_server(AsyncIo &async, ServerSocket &server_socket) { +Work async_server(AsyncIo &async, CryptoEngine &engine, ServerSocket &server_socket) { auto server_addr = server_socket.address(); auto server_spec = server_addr.spec(); fprintf(stderr, "listening at '%s' (fd = %d)\n", server_spec.c_str(), server_socket.get_fd()); - auto socket = co_await async.accept(server_socket); - fprintf(stderr, "server fd: %d\n", socket.get()); - co_return co_await verify_socket_io(async, socket, true); + auto raw_socket = co_await async.accept(server_socket); + fprintf(stderr, "server fd: %d\n", raw_socket.get()); + auto socket = co_await AsyncCryptoSocket::accept(async, engine, std::move(raw_socket)); + EXPECT_TRUE(socket); + REQUIRE(socket); + fprintf(stderr, "server socket type: %s\n", getClassName(*socket).c_str()); + co_return co_await verify_socket_io(*socket, true); } -Work async_client(AsyncIo &async, ServerSocket &server_socket) { +Work async_client(AsyncIo &async, CryptoEngine &engine, ServerSocket &server_socket) { auto server_addr = server_socket.address(); - auto server_spec = server_addr.spec(); - fprintf(stderr, "connecting to '%s'\n", server_spec.c_str()); - auto client_addr = SocketSpec(server_spec).client_address(); - auto socket = co_await async.connect(client_addr); - fprintf(stderr, "client fd: %d\n", socket.get()); - co_return co_await verify_socket_io(async, socket, false); + auto server_spec = SocketSpec(server_addr.spec()); + fprintf(stderr, "connecting to '%s'\n", server_spec.spec().c_str()); + auto client_addr = server_spec.client_address(); + auto raw_socket = co_await async.connect(client_addr); + fprintf(stderr, "client fd: %d\n", raw_socket.get()); + auto socket = co_await AsyncCryptoSocket::connect(async, engine, std::move(raw_socket), server_spec); + EXPECT_TRUE(socket); + REQUIRE(socket); + fprintf(stderr, "client socket type: %s\n", getClassName(*socket).c_str()); + co_return co_await verify_socket_io(*socket, false); } -TEST(AsyncIoTest, raw_socket_io) { +void verify_socket_io(CryptoEngine &engine) { ServerSocket server_socket("tcp/0"); server_socket.set_blocking(false); auto async = AsyncIo::create(); - auto f1 = make_future(async_server(async, server_socket)); - auto f2 = make_future(async_client(async, server_socket)); - f1.wait(); - f2.wait(); + auto f1 = make_future(async_server(async, engine, server_socket)); + auto f2 = make_future(async_client(async, engine, server_socket)); + (void) f1.get(); + (void) f2.get(); +} + +TEST(AsyncIoTest, raw_socket_io) { + NullCryptoEngine engine; + verify_socket_io(engine); +} + +TEST(AsyncIoTest, tls_socket_io) { + TlsCryptoEngine engine(make_tls_options_for_testing()); + verify_socket_io(engine); +} + +TEST(AsyncIoTest, maybe_tls_true_socket_io) { + MaybeTlsCryptoEngine engine(std::make_shared<TlsCryptoEngine>(make_tls_options_for_testing()), true); + verify_socket_io(engine); +} + +TEST(AsyncIoTest, maybe_tls_false_socket_io) { + MaybeTlsCryptoEngine engine(std::make_shared<TlsCryptoEngine>(make_tls_options_for_testing()), false); + verify_socket_io(engine); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/coro/CMakeLists.txt b/vespalib/src/vespa/vespalib/coro/CMakeLists.txt index 0fbb94e8255..8a7a0ade049 100644 --- a/vespalib/src/vespa/vespalib/coro/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/coro/CMakeLists.txt @@ -1,6 +1,7 @@ # Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(vespalib_vespalib_coro OBJECT SOURCES + async_crypto_socket.cpp async_io.cpp DEPENDS ) diff --git a/vespalib/src/vespa/vespalib/coro/async_crypto_socket.cpp b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.cpp new file mode 100644 index 00000000000..4f862b48690 --- /dev/null +++ b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.cpp @@ -0,0 +1,261 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "async_crypto_socket.h" +#include <vespa/vespalib/net/tls/protocol_snooping.h> +#include <vespa/vespalib/net/tls/tls_crypto_engine.h> +#include <vespa/vespalib/net/tls/crypto_codec.h> +#include <vespa/vespalib/data/smart_buffer.h> + +namespace vespalib::coro { + +namespace { + +using net::tls::CryptoCodec; +using net::tls::HandshakeResult; +using net::tls::EncodeResult; +using net::tls::DecodeResult; + +struct InvalidSocket : AsyncCryptoSocket { + Lazy<ssize_t> read(char *, size_t) override { co_return -EINVAL; } + Lazy<ssize_t> write(const char *, size_t) override { co_return -EINVAL; } +}; + +struct RawSocket : AsyncCryptoSocket { + AsyncIo::SP async; + SocketHandle handle; + RawSocket(AsyncIo &async_in, SocketHandle handle_in) + : async(async_in.shared_from_this()), handle(std::move(handle_in)) {} + Lazy<ssize_t> read(char *buf, size_t len) override { + return async->read(handle, buf, len); + } + Lazy<ssize_t> write(const char *buf, size_t len) override { + return async->write(handle, buf, len); + } +}; + +struct SnoopedRawSocket : AsyncCryptoSocket { + AsyncIo::SP async; + SocketHandle handle; + SmartBuffer data; + SnoopedRawSocket(AsyncIo &async_in, SocketHandle handle_in) + : async(async_in.shared_from_this()), handle(std::move(handle_in)), data(0) {} + void inject_data(const char *buf, size_t len) { + if (len > 0) { + auto dst = data.reserve(len); + memcpy(dst.data, buf, len); + data.commit(len); + } + } + Lazy<ssize_t> read_from_buffer(char *buf, size_t len) { + auto src = data.obtain(); + size_t frame = std::min(len, src.size); + if (frame > 0) { + memcpy(buf, src.data, frame); + data.evict(frame); + data.drop_if_empty(); + } + co_return frame; + } + Lazy<ssize_t> read(char *buf, size_t len) override { + if (data.empty()) { + return async->read(handle, buf, len); + } else { + return read_from_buffer(buf, len); + } + } + Lazy<ssize_t> write(const char *buf, size_t len) override { + return async->write(handle, buf, len); + } +}; + +struct TlsSocket : AsyncCryptoSocket { + AsyncIo::SP async; + SocketHandle handle; + std::unique_ptr<CryptoCodec> codec; + SmartBuffer app_input; + SmartBuffer enc_input; + SmartBuffer enc_output; + TlsSocket(AsyncIo &async_in, SocketHandle handle_in, std::unique_ptr<CryptoCodec> codec_in) + : async(async_in.shared_from_this()), handle(std::move(handle_in)), codec(std::move(codec_in)), + app_input(0), enc_input(0), enc_output(0) {} + void inject_enc_input(const char *buf, size_t len) { + if (len > 0) { + auto dst = enc_input.reserve(len); + memcpy(dst.data, buf, len); + enc_input.commit(len); + } + } + Lazy<bool> flush_enc_output() { + while (!enc_output.empty()) { + auto pending = enc_output.obtain(); + auto res = co_await async->write(handle, pending.data, pending.size); + if (res > 0) { + enc_output.evict(res); + } else { + co_return false; + } + } + co_return true; + } + Lazy<bool> fill_enc_input() { + auto dst = enc_input.reserve(codec->min_encode_buffer_size()); + ssize_t res = co_await async->read(handle, dst.data, dst.size); + if (res > 0) { + enc_input.commit(res); + co_return true; + } else { + co_return false; + } + } + Lazy<bool> handshake() { + for (;;) { + auto in = enc_input.obtain(); + auto out = enc_output.reserve(codec->min_encode_buffer_size()); + auto hs_res = codec->handshake(in.data, in.size, out.data, out.size); + enc_input.evict(hs_res.bytes_consumed); + enc_output.commit(hs_res.bytes_produced); + switch (hs_res.state) { + case ::vespalib::net::tls::HandshakeResult::State::Failed: co_return false; + case ::vespalib::net::tls::HandshakeResult::State::Done: co_return co_await flush_enc_output(); + case ::vespalib::net::tls::HandshakeResult::State::NeedsWork: + codec->do_handshake_work(); + break; + case ::vespalib::net::tls::HandshakeResult::State::NeedsMorePeerData: + bool flush_ok = co_await flush_enc_output(); + if (!flush_ok) { + co_return false; + } + bool fill_ok = co_await fill_enc_input(); + if (!fill_ok) { + co_return false; + } + } + } + } + Lazy<ssize_t> read(char *buf, size_t len) override { + while (app_input.empty()) { + auto src = enc_input.obtain(); + auto dst = app_input.reserve(codec->min_decode_buffer_size()); + auto res = codec->decode(src.data, src.size, dst.data, dst.size); + app_input.commit(res.bytes_produced); + enc_input.evict(res.bytes_consumed); + if (res.failed()) { + co_return -EIO; + } + if (res.closed()) { + co_return 0; + } + if (app_input.empty()) { + bool fill_ok = co_await fill_enc_input(); + if (!fill_ok) { + co_return -EIO; + } + } + } + auto src = app_input.obtain(); + size_t frame = std::min(len, src.size); + if (frame > 0) { + memcpy(buf, src.data, frame); + app_input.evict(frame); + } + co_return frame; + } + Lazy<ssize_t> write(const char *buf, size_t len) override { + auto dst = enc_output.reserve(codec->min_encode_buffer_size()); + auto res = codec->encode(buf, len, dst.data, dst.size); + if (res.failed) { + co_return -EIO; + } + enc_output.commit(res.bytes_produced); + bool flush_ok = co_await flush_enc_output(); + if (!flush_ok) { + co_return -EIO; + } + co_return res.bytes_consumed; + } +}; + +Lazy<AsyncCryptoSocket::UP> try_handshake(std::unique_ptr<TlsSocket> tls_socket) { + bool hs_ok = co_await tls_socket->handshake(); + if (hs_ok) { + co_return std::move(tls_socket); + } else { + co_return std::make_unique<InvalidSocket>(); + } +} + +Lazy<AsyncCryptoSocket::UP> accept_tls(AsyncIo &async, AbstractTlsCryptoEngine &crypto, SocketHandle handle) { + auto tls_codec = crypto.create_tls_server_crypto_codec(handle); + auto tls_socket = std::make_unique<TlsSocket>(async, std::move(handle), std::move(tls_codec)); + co_return co_await try_handshake(std::move(tls_socket)); +} + +Lazy<AsyncCryptoSocket::UP> accept_maybe_tls(AsyncIo &async, AbstractTlsCryptoEngine &crypto, SocketHandle handle) { + char buf[net::tls::snooping::min_header_bytes_to_observe()]; + memset(buf, 0, sizeof(buf)); + size_t snooped = 0; + while (snooped < sizeof(buf)) { + auto res = co_await async.read(handle, buf + snooped, sizeof(buf) - snooped); + if (res <= 0) { + co_return std::make_unique<InvalidSocket>(); + } + snooped += res; + } + if (net::tls::snooping::snoop_client_hello_header(buf) == net::tls::snooping::TlsSnoopingResult::ProbablyTls) { + auto tls_codec = crypto.create_tls_server_crypto_codec(handle); + auto tls_socket = std::make_unique<TlsSocket>(async, std::move(handle), std::move(tls_codec)); + tls_socket->inject_enc_input(buf, snooped); + co_return co_await try_handshake(std::move(tls_socket)); + } else { + auto plain_socket = std::make_unique<SnoopedRawSocket>(async, std::move(handle)); + plain_socket->inject_data(buf, snooped); + co_return std::move(plain_socket); + } +} + +Lazy<AsyncCryptoSocket::UP> connect_tls(AsyncIo &async, AbstractTlsCryptoEngine &crypto, SocketHandle handle, SocketSpec spec) { + auto tls_codec = crypto.create_tls_client_crypto_codec(handle, spec); + auto tls_socket = std::make_unique<TlsSocket>(async, std::move(handle), std::move(tls_codec)); + co_return co_await try_handshake(std::move(tls_socket)); +} + +} + +AsyncCryptoSocket::~AsyncCryptoSocket() = default; + +Lazy<AsyncCryptoSocket::UP> +AsyncCryptoSocket::accept(AsyncIo &async, CryptoEngine &crypto, + SocketHandle handle) +{ + if (dynamic_cast<NullCryptoEngine*>(&crypto)) { + co_return std::make_unique<RawSocket>(async, std::move(handle)); + } + if (auto *tls_engine = dynamic_cast<AbstractTlsCryptoEngine*>(&crypto)) { + if (tls_engine->always_use_tls_when_server()) { + co_return co_await accept_tls(async, *tls_engine, std::move(handle)); + } else { + co_return co_await accept_maybe_tls(async, *tls_engine, std::move(handle)); + } + } + co_return std::make_unique<InvalidSocket>(); +} + +Lazy<AsyncCryptoSocket::UP> +AsyncCryptoSocket::connect(AsyncIo &async, CryptoEngine &crypto, + SocketHandle handle, SocketSpec spec) +{ + if (dynamic_cast<NullCryptoEngine*>(&crypto)) { + (void) spec; // no SNI for plaintext sockets + co_return std::make_unique<RawSocket>(async, std::move(handle)); + } + if (auto *tls_engine = dynamic_cast<AbstractTlsCryptoEngine*>(&crypto)) { + if (tls_engine->use_tls_when_client()) { + co_return co_await connect_tls(async, *tls_engine, std::move(handle), spec); + } else { + co_return std::make_unique<RawSocket>(async, std::move(handle)); + } + } + co_return std::make_unique<InvalidSocket>(); +} + +} diff --git a/vespalib/src/vespa/vespalib/coro/async_crypto_socket.h b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.h new file mode 100644 index 00000000000..2fcf4efe9b4 --- /dev/null +++ b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.h @@ -0,0 +1,31 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "lazy.h" +#include "async_io.h" + +#include <vespa/vespalib/net/socket_spec.h> +#include <vespa/vespalib/net/socket_handle.h> +#include <vespa/vespalib/net/crypto_engine.h> + +#include <memory> + +namespace vespalib::coro { + +// A socket endpoint supporting async read/write with encryption + +struct AsyncCryptoSocket { + using UP = std::unique_ptr<AsyncCryptoSocket>; + + virtual Lazy<ssize_t> read(char *buf, size_t len) = 0; + virtual Lazy<ssize_t> write(const char *buf, size_t len) = 0; + virtual ~AsyncCryptoSocket(); + + static Lazy<AsyncCryptoSocket::UP> accept(AsyncIo &async, CryptoEngine &crypto, + SocketHandle handle); + static Lazy<AsyncCryptoSocket::UP> connect(AsyncIo &async, CryptoEngine &crypto, + SocketHandle handle, SocketSpec spec); +}; + +} |