summaryrefslogtreecommitdiffstats
path: root/vespalib
diff options
context:
space:
mode:
authorHÃ¥vard Pettersen <3535158+havardpe@users.noreply.github.com>2023-01-13 16:12:08 +0100
committerGitHub <noreply@github.com>2023-01-13 16:12:08 +0100
commit316cdcbf3897fcf00c653eb7f6ea6ef5c39e9642 (patch)
treee817a5aaf089d6fc04891ac81630b2b65872a757 /vespalib
parent6ea555de57ad11bae44e7f9abca4d2b06d3863ae (diff)
parentb0280f472f6074ed34bd1da4fc1c6b6aa517a765 (diff)
Merge pull request #25516 from vespa-engine/havardpe/async-crypto-socket
async crypto socket proof of concept
Diffstat (limited to 'vespalib')
-rw-r--r--vespalib/src/tests/coro/async_io/async_io_test.cpp86
-rw-r--r--vespalib/src/vespa/vespalib/coro/CMakeLists.txt1
-rw-r--r--vespalib/src/vespa/vespalib/coro/async_crypto_socket.cpp261
-rw-r--r--vespalib/src/vespa/vespalib/coro/async_crypto_socket.h31
4 files changed, 354 insertions, 25 deletions
diff --git a/vespalib/src/tests/coro/async_io/async_io_test.cpp b/vespalib/src/tests/coro/async_io/async_io_test.cpp
index a506a5dd0d4..f5098e30086 100644
--- a/vespalib/src/tests/coro/async_io/async_io_test.cpp
+++ b/vespalib/src/tests/coro/async_io/async_io_test.cpp
@@ -4,14 +4,22 @@
#include <vespa/vespalib/coro/detached.h>
#include <vespa/vespalib/coro/completion.h>
#include <vespa/vespalib/coro/async_io.h>
+#include <vespa/vespalib/coro/async_crypto_socket.h>
#include <vespa/vespalib/net/socket_spec.h>
#include <vespa/vespalib/net/server_socket.h>
#include <vespa/vespalib/net/socket_handle.h>
#include <vespa/vespalib/net/socket_address.h>
+#include <vespa/vespalib/net/crypto_engine.h>
+#include <vespa/vespalib/util/require.h>
+#include <vespa/vespalib/util/classname.h>
+#include <vespa/vespalib/test/make_tls_options_for_testing.h>
+#include <vespa/vespalib/net/tls/tls_crypto_engine.h>
+#include <vespa/vespalib/net/tls/maybe_tls_crypto_engine.h>
#include <vespa/vespalib/gtest/gtest.h>
using namespace vespalib;
using namespace vespalib::coro;
+using namespace vespalib::test;
Detached self_exiting_run_loop(AsyncIo::SP async) {
for (size_t i = 0; co_await async->schedule(); ++i) {
@@ -53,11 +61,11 @@ TEST(AsyncIoTest, shutdown_with_self_exiting_coroutine) {
f2.wait();
}
-Lazy<size_t> write_msg(AsyncIo &async, SocketHandle &socket, const vespalib::string &msg) {
+Lazy<size_t> write_msg(AsyncCryptoSocket &socket, const vespalib::string &msg) {
size_t written = 0;
while (written < msg.size()) {
size_t write_size = (msg.size() - written);
- ssize_t write_result = co_await async.write(socket, msg.data() + written, write_size);
+ ssize_t write_result = co_await socket.write(msg.data() + written, write_size);
if (write_result <= 0) {
co_return written;
}
@@ -66,12 +74,12 @@ Lazy<size_t> write_msg(AsyncIo &async, SocketHandle &socket, const vespalib::str
co_return written;
}
-Lazy<vespalib::string> read_msg(AsyncIo &async, SocketHandle &socket, size_t wanted_bytes) {
+Lazy<vespalib::string> read_msg(AsyncCryptoSocket &socket, size_t wanted_bytes) {
char tmp[64];
vespalib::string result;
while (result.size() < wanted_bytes) {
size_t read_size = std::min(sizeof(tmp), wanted_bytes - result.size());
- ssize_t read_result = co_await async.read(socket, tmp, read_size);
+ ssize_t read_result = co_await socket.read(tmp, read_size);
if (read_result <= 0) {
co_return result;
}
@@ -80,50 +88,78 @@ Lazy<vespalib::string> read_msg(AsyncIo &async, SocketHandle &socket, size_t wan
co_return result;
}
-Work verify_socket_io(AsyncIo &async, SocketHandle &socket, bool is_server) {
+Work verify_socket_io(AsyncCryptoSocket &socket, bool is_server) {
vespalib::string server_message = "hello, this is the server speaking";
vespalib::string client_message = "please pick up, I need to talk to you";
if (is_server) {
- vespalib::string read = co_await read_msg(async, socket, client_message.size());
+ vespalib::string read = co_await read_msg(socket, client_message.size());
EXPECT_EQ(client_message, read);
- size_t written = co_await write_msg(async, socket, server_message);
+ size_t written = co_await write_msg(socket, server_message);
EXPECT_EQ(written, ssize_t(server_message.size()));
} else {
- size_t written = co_await write_msg(async, socket, client_message);
+ size_t written = co_await write_msg(socket, client_message);
EXPECT_EQ(written, ssize_t(client_message.size()));
- vespalib::string read = co_await read_msg(async, socket, server_message.size());
+ vespalib::string read = co_await read_msg(socket, server_message.size());
EXPECT_EQ(server_message, read);
}
co_return Done{};
}
-Work async_server(AsyncIo &async, ServerSocket &server_socket) {
+Work async_server(AsyncIo &async, CryptoEngine &engine, ServerSocket &server_socket) {
auto server_addr = server_socket.address();
auto server_spec = server_addr.spec();
fprintf(stderr, "listening at '%s' (fd = %d)\n", server_spec.c_str(), server_socket.get_fd());
- auto socket = co_await async.accept(server_socket);
- fprintf(stderr, "server fd: %d\n", socket.get());
- co_return co_await verify_socket_io(async, socket, true);
+ auto raw_socket = co_await async.accept(server_socket);
+ fprintf(stderr, "server fd: %d\n", raw_socket.get());
+ auto socket = co_await AsyncCryptoSocket::accept(async, engine, std::move(raw_socket));
+ EXPECT_TRUE(socket);
+ REQUIRE(socket);
+ fprintf(stderr, "server socket type: %s\n", getClassName(*socket).c_str());
+ co_return co_await verify_socket_io(*socket, true);
}
-Work async_client(AsyncIo &async, ServerSocket &server_socket) {
+Work async_client(AsyncIo &async, CryptoEngine &engine, ServerSocket &server_socket) {
auto server_addr = server_socket.address();
- auto server_spec = server_addr.spec();
- fprintf(stderr, "connecting to '%s'\n", server_spec.c_str());
- auto client_addr = SocketSpec(server_spec).client_address();
- auto socket = co_await async.connect(client_addr);
- fprintf(stderr, "client fd: %d\n", socket.get());
- co_return co_await verify_socket_io(async, socket, false);
+ auto server_spec = SocketSpec(server_addr.spec());
+ fprintf(stderr, "connecting to '%s'\n", server_spec.spec().c_str());
+ auto client_addr = server_spec.client_address();
+ auto raw_socket = co_await async.connect(client_addr);
+ fprintf(stderr, "client fd: %d\n", raw_socket.get());
+ auto socket = co_await AsyncCryptoSocket::connect(async, engine, std::move(raw_socket), server_spec);
+ EXPECT_TRUE(socket);
+ REQUIRE(socket);
+ fprintf(stderr, "client socket type: %s\n", getClassName(*socket).c_str());
+ co_return co_await verify_socket_io(*socket, false);
}
-TEST(AsyncIoTest, raw_socket_io) {
+void verify_socket_io(CryptoEngine &engine) {
ServerSocket server_socket("tcp/0");
server_socket.set_blocking(false);
auto async = AsyncIo::create();
- auto f1 = make_future(async_server(async, server_socket));
- auto f2 = make_future(async_client(async, server_socket));
- f1.wait();
- f2.wait();
+ auto f1 = make_future(async_server(async, engine, server_socket));
+ auto f2 = make_future(async_client(async, engine, server_socket));
+ (void) f1.get();
+ (void) f2.get();
+}
+
+TEST(AsyncIoTest, raw_socket_io) {
+ NullCryptoEngine engine;
+ verify_socket_io(engine);
+}
+
+TEST(AsyncIoTest, tls_socket_io) {
+ TlsCryptoEngine engine(make_tls_options_for_testing());
+ verify_socket_io(engine);
+}
+
+TEST(AsyncIoTest, maybe_tls_true_socket_io) {
+ MaybeTlsCryptoEngine engine(std::make_shared<TlsCryptoEngine>(make_tls_options_for_testing()), true);
+ verify_socket_io(engine);
+}
+
+TEST(AsyncIoTest, maybe_tls_false_socket_io) {
+ MaybeTlsCryptoEngine engine(std::make_shared<TlsCryptoEngine>(make_tls_options_for_testing()), false);
+ verify_socket_io(engine);
}
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/vespa/vespalib/coro/CMakeLists.txt b/vespalib/src/vespa/vespalib/coro/CMakeLists.txt
index 0fbb94e8255..8a7a0ade049 100644
--- a/vespalib/src/vespa/vespalib/coro/CMakeLists.txt
+++ b/vespalib/src/vespa/vespalib/coro/CMakeLists.txt
@@ -1,6 +1,7 @@
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
vespa_add_library(vespalib_vespalib_coro OBJECT
SOURCES
+ async_crypto_socket.cpp
async_io.cpp
DEPENDS
)
diff --git a/vespalib/src/vespa/vespalib/coro/async_crypto_socket.cpp b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.cpp
new file mode 100644
index 00000000000..4f862b48690
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.cpp
@@ -0,0 +1,261 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "async_crypto_socket.h"
+#include <vespa/vespalib/net/tls/protocol_snooping.h>
+#include <vespa/vespalib/net/tls/tls_crypto_engine.h>
+#include <vespa/vespalib/net/tls/crypto_codec.h>
+#include <vespa/vespalib/data/smart_buffer.h>
+
+namespace vespalib::coro {
+
+namespace {
+
+using net::tls::CryptoCodec;
+using net::tls::HandshakeResult;
+using net::tls::EncodeResult;
+using net::tls::DecodeResult;
+
+struct InvalidSocket : AsyncCryptoSocket {
+ Lazy<ssize_t> read(char *, size_t) override { co_return -EINVAL; }
+ Lazy<ssize_t> write(const char *, size_t) override { co_return -EINVAL; }
+};
+
+struct RawSocket : AsyncCryptoSocket {
+ AsyncIo::SP async;
+ SocketHandle handle;
+ RawSocket(AsyncIo &async_in, SocketHandle handle_in)
+ : async(async_in.shared_from_this()), handle(std::move(handle_in)) {}
+ Lazy<ssize_t> read(char *buf, size_t len) override {
+ return async->read(handle, buf, len);
+ }
+ Lazy<ssize_t> write(const char *buf, size_t len) override {
+ return async->write(handle, buf, len);
+ }
+};
+
+struct SnoopedRawSocket : AsyncCryptoSocket {
+ AsyncIo::SP async;
+ SocketHandle handle;
+ SmartBuffer data;
+ SnoopedRawSocket(AsyncIo &async_in, SocketHandle handle_in)
+ : async(async_in.shared_from_this()), handle(std::move(handle_in)), data(0) {}
+ void inject_data(const char *buf, size_t len) {
+ if (len > 0) {
+ auto dst = data.reserve(len);
+ memcpy(dst.data, buf, len);
+ data.commit(len);
+ }
+ }
+ Lazy<ssize_t> read_from_buffer(char *buf, size_t len) {
+ auto src = data.obtain();
+ size_t frame = std::min(len, src.size);
+ if (frame > 0) {
+ memcpy(buf, src.data, frame);
+ data.evict(frame);
+ data.drop_if_empty();
+ }
+ co_return frame;
+ }
+ Lazy<ssize_t> read(char *buf, size_t len) override {
+ if (data.empty()) {
+ return async->read(handle, buf, len);
+ } else {
+ return read_from_buffer(buf, len);
+ }
+ }
+ Lazy<ssize_t> write(const char *buf, size_t len) override {
+ return async->write(handle, buf, len);
+ }
+};
+
+struct TlsSocket : AsyncCryptoSocket {
+ AsyncIo::SP async;
+ SocketHandle handle;
+ std::unique_ptr<CryptoCodec> codec;
+ SmartBuffer app_input;
+ SmartBuffer enc_input;
+ SmartBuffer enc_output;
+ TlsSocket(AsyncIo &async_in, SocketHandle handle_in, std::unique_ptr<CryptoCodec> codec_in)
+ : async(async_in.shared_from_this()), handle(std::move(handle_in)), codec(std::move(codec_in)),
+ app_input(0), enc_input(0), enc_output(0) {}
+ void inject_enc_input(const char *buf, size_t len) {
+ if (len > 0) {
+ auto dst = enc_input.reserve(len);
+ memcpy(dst.data, buf, len);
+ enc_input.commit(len);
+ }
+ }
+ Lazy<bool> flush_enc_output() {
+ while (!enc_output.empty()) {
+ auto pending = enc_output.obtain();
+ auto res = co_await async->write(handle, pending.data, pending.size);
+ if (res > 0) {
+ enc_output.evict(res);
+ } else {
+ co_return false;
+ }
+ }
+ co_return true;
+ }
+ Lazy<bool> fill_enc_input() {
+ auto dst = enc_input.reserve(codec->min_encode_buffer_size());
+ ssize_t res = co_await async->read(handle, dst.data, dst.size);
+ if (res > 0) {
+ enc_input.commit(res);
+ co_return true;
+ } else {
+ co_return false;
+ }
+ }
+ Lazy<bool> handshake() {
+ for (;;) {
+ auto in = enc_input.obtain();
+ auto out = enc_output.reserve(codec->min_encode_buffer_size());
+ auto hs_res = codec->handshake(in.data, in.size, out.data, out.size);
+ enc_input.evict(hs_res.bytes_consumed);
+ enc_output.commit(hs_res.bytes_produced);
+ switch (hs_res.state) {
+ case ::vespalib::net::tls::HandshakeResult::State::Failed: co_return false;
+ case ::vespalib::net::tls::HandshakeResult::State::Done: co_return co_await flush_enc_output();
+ case ::vespalib::net::tls::HandshakeResult::State::NeedsWork:
+ codec->do_handshake_work();
+ break;
+ case ::vespalib::net::tls::HandshakeResult::State::NeedsMorePeerData:
+ bool flush_ok = co_await flush_enc_output();
+ if (!flush_ok) {
+ co_return false;
+ }
+ bool fill_ok = co_await fill_enc_input();
+ if (!fill_ok) {
+ co_return false;
+ }
+ }
+ }
+ }
+ Lazy<ssize_t> read(char *buf, size_t len) override {
+ while (app_input.empty()) {
+ auto src = enc_input.obtain();
+ auto dst = app_input.reserve(codec->min_decode_buffer_size());
+ auto res = codec->decode(src.data, src.size, dst.data, dst.size);
+ app_input.commit(res.bytes_produced);
+ enc_input.evict(res.bytes_consumed);
+ if (res.failed()) {
+ co_return -EIO;
+ }
+ if (res.closed()) {
+ co_return 0;
+ }
+ if (app_input.empty()) {
+ bool fill_ok = co_await fill_enc_input();
+ if (!fill_ok) {
+ co_return -EIO;
+ }
+ }
+ }
+ auto src = app_input.obtain();
+ size_t frame = std::min(len, src.size);
+ if (frame > 0) {
+ memcpy(buf, src.data, frame);
+ app_input.evict(frame);
+ }
+ co_return frame;
+ }
+ Lazy<ssize_t> write(const char *buf, size_t len) override {
+ auto dst = enc_output.reserve(codec->min_encode_buffer_size());
+ auto res = codec->encode(buf, len, dst.data, dst.size);
+ if (res.failed) {
+ co_return -EIO;
+ }
+ enc_output.commit(res.bytes_produced);
+ bool flush_ok = co_await flush_enc_output();
+ if (!flush_ok) {
+ co_return -EIO;
+ }
+ co_return res.bytes_consumed;
+ }
+};
+
+Lazy<AsyncCryptoSocket::UP> try_handshake(std::unique_ptr<TlsSocket> tls_socket) {
+ bool hs_ok = co_await tls_socket->handshake();
+ if (hs_ok) {
+ co_return std::move(tls_socket);
+ } else {
+ co_return std::make_unique<InvalidSocket>();
+ }
+}
+
+Lazy<AsyncCryptoSocket::UP> accept_tls(AsyncIo &async, AbstractTlsCryptoEngine &crypto, SocketHandle handle) {
+ auto tls_codec = crypto.create_tls_server_crypto_codec(handle);
+ auto tls_socket = std::make_unique<TlsSocket>(async, std::move(handle), std::move(tls_codec));
+ co_return co_await try_handshake(std::move(tls_socket));
+}
+
+Lazy<AsyncCryptoSocket::UP> accept_maybe_tls(AsyncIo &async, AbstractTlsCryptoEngine &crypto, SocketHandle handle) {
+ char buf[net::tls::snooping::min_header_bytes_to_observe()];
+ memset(buf, 0, sizeof(buf));
+ size_t snooped = 0;
+ while (snooped < sizeof(buf)) {
+ auto res = co_await async.read(handle, buf + snooped, sizeof(buf) - snooped);
+ if (res <= 0) {
+ co_return std::make_unique<InvalidSocket>();
+ }
+ snooped += res;
+ }
+ if (net::tls::snooping::snoop_client_hello_header(buf) == net::tls::snooping::TlsSnoopingResult::ProbablyTls) {
+ auto tls_codec = crypto.create_tls_server_crypto_codec(handle);
+ auto tls_socket = std::make_unique<TlsSocket>(async, std::move(handle), std::move(tls_codec));
+ tls_socket->inject_enc_input(buf, snooped);
+ co_return co_await try_handshake(std::move(tls_socket));
+ } else {
+ auto plain_socket = std::make_unique<SnoopedRawSocket>(async, std::move(handle));
+ plain_socket->inject_data(buf, snooped);
+ co_return std::move(plain_socket);
+ }
+}
+
+Lazy<AsyncCryptoSocket::UP> connect_tls(AsyncIo &async, AbstractTlsCryptoEngine &crypto, SocketHandle handle, SocketSpec spec) {
+ auto tls_codec = crypto.create_tls_client_crypto_codec(handle, spec);
+ auto tls_socket = std::make_unique<TlsSocket>(async, std::move(handle), std::move(tls_codec));
+ co_return co_await try_handshake(std::move(tls_socket));
+}
+
+}
+
+AsyncCryptoSocket::~AsyncCryptoSocket() = default;
+
+Lazy<AsyncCryptoSocket::UP>
+AsyncCryptoSocket::accept(AsyncIo &async, CryptoEngine &crypto,
+ SocketHandle handle)
+{
+ if (dynamic_cast<NullCryptoEngine*>(&crypto)) {
+ co_return std::make_unique<RawSocket>(async, std::move(handle));
+ }
+ if (auto *tls_engine = dynamic_cast<AbstractTlsCryptoEngine*>(&crypto)) {
+ if (tls_engine->always_use_tls_when_server()) {
+ co_return co_await accept_tls(async, *tls_engine, std::move(handle));
+ } else {
+ co_return co_await accept_maybe_tls(async, *tls_engine, std::move(handle));
+ }
+ }
+ co_return std::make_unique<InvalidSocket>();
+}
+
+Lazy<AsyncCryptoSocket::UP>
+AsyncCryptoSocket::connect(AsyncIo &async, CryptoEngine &crypto,
+ SocketHandle handle, SocketSpec spec)
+{
+ if (dynamic_cast<NullCryptoEngine*>(&crypto)) {
+ (void) spec; // no SNI for plaintext sockets
+ co_return std::make_unique<RawSocket>(async, std::move(handle));
+ }
+ if (auto *tls_engine = dynamic_cast<AbstractTlsCryptoEngine*>(&crypto)) {
+ if (tls_engine->use_tls_when_client()) {
+ co_return co_await connect_tls(async, *tls_engine, std::move(handle), spec);
+ } else {
+ co_return std::make_unique<RawSocket>(async, std::move(handle));
+ }
+ }
+ co_return std::make_unique<InvalidSocket>();
+}
+
+}
diff --git a/vespalib/src/vespa/vespalib/coro/async_crypto_socket.h b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.h
new file mode 100644
index 00000000000..2fcf4efe9b4
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/coro/async_crypto_socket.h
@@ -0,0 +1,31 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "lazy.h"
+#include "async_io.h"
+
+#include <vespa/vespalib/net/socket_spec.h>
+#include <vespa/vespalib/net/socket_handle.h>
+#include <vespa/vespalib/net/crypto_engine.h>
+
+#include <memory>
+
+namespace vespalib::coro {
+
+// A socket endpoint supporting async read/write with encryption
+
+struct AsyncCryptoSocket {
+ using UP = std::unique_ptr<AsyncCryptoSocket>;
+
+ virtual Lazy<ssize_t> read(char *buf, size_t len) = 0;
+ virtual Lazy<ssize_t> write(const char *buf, size_t len) = 0;
+ virtual ~AsyncCryptoSocket();
+
+ static Lazy<AsyncCryptoSocket::UP> accept(AsyncIo &async, CryptoEngine &crypto,
+ SocketHandle handle);
+ static Lazy<AsyncCryptoSocket::UP> connect(AsyncIo &async, CryptoEngine &crypto,
+ SocketHandle handle, SocketSpec spec);
+};
+
+}