From a92222c82d37d69c6b96aca759a60e9809050bc9 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Fri, 20 Apr 2018 14:10:51 +0000 Subject: added a test that sends an open socket handle over an ipc connection --- vespalib/CMakeLists.txt | 1 + vespalib/src/tests/net/send_fd/CMakeLists.txt | 8 ++ vespalib/src/tests/net/send_fd/send_fd_test.cpp | 126 ++++++++++++++++++++++++ 3 files changed, 135 insertions(+) create mode 100644 vespalib/src/tests/net/send_fd/CMakeLists.txt create mode 100644 vespalib/src/tests/net/send_fd/send_fd_test.cpp diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 9ef6daea91c..40d0c539c86 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -51,6 +51,7 @@ vespa_define_module( src/tests/memory src/tests/net/async_resolver src/tests/net/selector + src/tests/net/send_fd src/tests/net/socket src/tests/net/socket_spec src/tests/objects/nbostream diff --git a/vespalib/src/tests/net/send_fd/CMakeLists.txt b/vespalib/src/tests/net/send_fd/CMakeLists.txt new file mode 100644 index 00000000000..3875e208793 --- /dev/null +++ b/vespalib/src/tests/net/send_fd/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_send_fd_test_app TEST + SOURCES + send_fd_test.cpp + DEPENDS + vespalib +) +vespa_add_test(NAME vespalib_send_fd_test_app COMMAND vespalib_send_fd_test_app) diff --git a/vespalib/src/tests/net/send_fd/send_fd_test.cpp b/vespalib/src/tests/net/send_fd/send_fd_test.cpp new file mode 100644 index 00000000000..077c7495b77 --- /dev/null +++ b/vespalib/src/tests/net/send_fd/send_fd_test.cpp @@ -0,0 +1,126 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace vespalib; + +vespalib::string read_bytes(SocketHandle &socket, size_t wanted_bytes) { + char tmp[64]; + vespalib::string result; + while (result.size() < wanted_bytes) { + size_t read_size = std::min(sizeof(tmp), wanted_bytes - result.size()); + ssize_t read_result = socket.read(tmp, read_size); + ASSERT_GREATER(read_result, 0); + result.append(tmp, read_result); + } + return result; +} + +void verify_socket_io(bool is_server, SocketHandle &socket) { + vespalib::string server_message = "hello, this is the server speaking"; + vespalib::string client_message = "please pick up, I need to talk to you"; + if(is_server) { + ssize_t written = socket.write(server_message.data(), server_message.size()); + EXPECT_EQUAL(written, ssize_t(server_message.size())); + vespalib::string read = read_bytes(socket, client_message.size()); + EXPECT_EQUAL(client_message, read); + } else { + ssize_t written = socket.write(client_message.data(), client_message.size()); + EXPECT_EQUAL(written, ssize_t(client_message.size())); + vespalib::string read = read_bytes(socket, server_message.size()); + EXPECT_EQUAL(server_message, read); + } +} + +SocketHandle connect(ServerSocket &server_socket) { + auto server = server_socket.address(); + auto spec = server.spec(); + fprintf(stderr, "connecting to '%s'\n", spec.c_str()); + return SocketSpec(spec).client_address().connect(); +} + +SocketHandle accept(ServerSocket &server_socket) { + auto server = server_socket.address(); + auto spec = server.spec(); + fprintf(stderr, "accepting from '%s'\n", spec.c_str()); + return server_socket.accept(); +} + +void send_fd(SocketHandle &socket, SocketHandle fd) { + fprintf(stderr, "sending fd: %d\n", fd.get()); + struct msghdr msg = {}; + char tag = '*'; + struct iovec data; + data.iov_base = &tag; + data.iov_len = 1; + char buf[CMSG_SPACE(sizeof(int))]; + msg.msg_iov = &data; + msg.msg_iovlen = 1; + msg.msg_control = buf; + msg.msg_controllen = sizeof(buf); + struct cmsghdr *hdr = CMSG_FIRSTHDR(&msg); + hdr->cmsg_level = SOL_SOCKET; + hdr->cmsg_type = SCM_RIGHTS; + hdr->cmsg_len = CMSG_LEN(sizeof(int)); + int *fd_dst = (int *) CMSG_DATA(hdr); + fd_dst[0] = fd.get(); + ssize_t res = sendmsg(socket.get(), &msg, 0); + ASSERT_EQUAL(res, 1); +} + +SocketHandle recv_fd(SocketHandle &socket) { + struct msghdr msg = {}; + char tag = '*'; + struct iovec data; + data.iov_base = &tag; + data.iov_len = 1; + char buf[CMSG_SPACE(sizeof(int))]; + msg.msg_iov = &data; + msg.msg_iovlen = 1; + msg.msg_control = buf; + msg.msg_controllen = sizeof(buf); + ssize_t res = recvmsg(socket.get(), &msg, 0); + ASSERT_EQUAL(res, 1); + struct cmsghdr *hdr = CMSG_FIRSTHDR(&msg); + bool type_ok = ((hdr->cmsg_level == SOL_SOCKET) && + (hdr->cmsg_type == SCM_RIGHTS)); + ASSERT_TRUE(type_ok); + int *fd_src = (int *) CMSG_DATA(hdr); + fprintf(stderr, "got fd: %d\n", fd_src[0]); + return SocketHandle(fd_src[0]); +} + +//----------------------------------------------------------------------------- + +TEST_MT_FFF("require that an open socket (handle) can be passed over a unix domain socket", 3, + ServerSocket("tcp/0"), ServerSocket("ipc/file:my_socket"), TimeBomb(60)) +{ + if (thread_id == 0) { // server + SocketHandle socket = accept(f1); + TEST_DO(verify_socket_io(true, socket)); // server side + TEST_BARRIER(); + } else if (thread_id == 1) { // proxy + SocketHandle server_socket = connect(f1); + SocketHandle client_socket = accept(f2); + send_fd(client_socket, std::move(server_socket)); + TEST_BARRIER(); + } else { // client + SocketHandle proxy_socket = connect(f2); + SocketHandle socket = recv_fd(proxy_socket); + TEST_DO(verify_socket_io(false, socket)); // client side + TEST_BARRIER(); + } +} + +TEST_MAIN() { TEST_RUN_ALL(); } -- cgit v1.2.3