diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-05-09 13:46:51 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-05-13 10:05:24 +0000 |
commit | 032ec0ed6f65a355c5c6402f2e2daae1f6ea5b00 (patch) | |
tree | 4a56cf3323bee758b962d723f8fa0b6c74ead7ed | |
parent | c9f89a485d3dee9ddffb5107b31bf0bae91b18d4 (diff) |
multi-threaded transport for JRT
19 files changed, 604 insertions, 482 deletions
diff --git a/jrt/src/com/yahoo/jrt/Acceptor.java b/jrt/src/com/yahoo/jrt/Acceptor.java index 3da978fb90e..dd4b5a72a4f 100644 --- a/jrt/src/com/yahoo/jrt/Acceptor.java +++ b/jrt/src/com/yahoo/jrt/Acceptor.java @@ -100,8 +100,9 @@ public class Acceptor { private void run() { while (serverChannel.isOpen()) { try { - parent.addConnection(new Connection(parent, owner, serverChannel.accept())); - parent.sync(); + TransportThread tt = parent.selectThread(); + tt.addConnection(new Connection(tt, owner, serverChannel.accept())); + tt.sync(); } catch (ClosedChannelException ignore) { } catch (Exception e) { log.log(Level.WARNING, "Error accepting connection", e); diff --git a/jrt/src/com/yahoo/jrt/Connection.java b/jrt/src/com/yahoo/jrt/Connection.java index e6772f94bb8..9a3a7cd083f 100644 --- a/jrt/src/com/yahoo/jrt/Connection.java +++ b/jrt/src/com/yahoo/jrt/Connection.java @@ -41,7 +41,7 @@ class Connection extends Target { private int activeReqs = 0; private int writeWork = 0; private boolean pendingHandshakeWork = false; - private Transport parent; + private final TransportThread parent; private Supervisor owner; private Spec spec; private CryptoSocket socket; @@ -88,17 +88,17 @@ class Connection extends Target { } } - public Connection(Transport parent, Supervisor owner, + public Connection(TransportThread parent, Supervisor owner, SocketChannel channel) { this.parent = parent; this.owner = owner; - this.socket = parent.createCryptoSocket(channel, true); + this.socket = parent.transport().createCryptoSocket(channel, true); server = true; owner.sessionInit(this); } - public Connection(Transport parent, Supervisor owner, Spec spec, Object context) { + public Connection(TransportThread parent, Supervisor owner, Spec spec, Object context) { super(context); this.parent = parent; this.owner = owner; @@ -115,7 +115,7 @@ class Connection extends Target { maxOutputSize = bytes; } - public Transport transport() { + public TransportThread transportThread() { return parent; } @@ -170,7 +170,7 @@ class Connection extends Target { return this; } try { - socket = parent.createCryptoSocket(SocketChannel.open(spec.address()), false); + socket = parent.transport().createCryptoSocket(SocketChannel.open(spec.address()), false); } catch (Exception e) { setLostReason(e); } @@ -242,7 +242,7 @@ class Connection extends Target { disableRead(); disableWrite(); pendingHandshakeWork = true; - parent.doHandshakeWork(this); + parent.transport().doHandshakeWork(this); break; } } diff --git a/jrt/src/com/yahoo/jrt/Connector.java b/jrt/src/com/yahoo/jrt/Connector.java index a4cbd07d3f8..4c83a2884bd 100644 --- a/jrt/src/com/yahoo/jrt/Connector.java +++ b/jrt/src/com/yahoo/jrt/Connector.java @@ -28,7 +28,7 @@ class Connector { public void connectLater(Connection c) { if ( ! connectQueue.enqueue(c)) { - parent.addConnection(c); + c.transportThread().addConnection(c); } } @@ -36,7 +36,7 @@ class Connector { try { while (true) { Connection conn = (Connection) connectQueue.dequeue(); - parent.addConnection(conn.connect()); + conn.transportThread().addConnection(conn.connect()); } } catch (EndOfQueueException e) {} synchronized (this) { diff --git a/jrt/src/com/yahoo/jrt/InvocationClient.java b/jrt/src/com/yahoo/jrt/InvocationClient.java index 0b01ea0935b..71b68e5a397 100644 --- a/jrt/src/com/yahoo/jrt/InvocationClient.java +++ b/jrt/src/com/yahoo/jrt/InvocationClient.java @@ -21,7 +21,7 @@ class InvocationClient implements ReplyHandler, Runnable { req.clientHandler(this); this.replyKey = conn.allocateKey(); - this.timeoutTask = conn.transport().createTask(this); + this.timeoutTask = conn.transportThread().createTask(this); } public void invoke() { diff --git a/jrt/src/com/yahoo/jrt/MandatoryMethods.java b/jrt/src/com/yahoo/jrt/MandatoryMethods.java index e528dc8197c..1176884eed5 100644 --- a/jrt/src/com/yahoo/jrt/MandatoryMethods.java +++ b/jrt/src/com/yahoo/jrt/MandatoryMethods.java @@ -2,6 +2,7 @@ package com.yahoo.jrt; +import java.util.Collection; import java.util.Iterator; @@ -47,15 +48,14 @@ class MandatoryMethods { } public void getMethodList(Request req) { - int cnt = parent.methodMap().size(); + Collection<Method> methods = parent.methodMap().values(); + int cnt = methods.size(); String[] ret0_names = new String[cnt]; String[] ret1_params = new String[cnt]; String[] ret2_return = new String[cnt]; int i = 0; - Iterator<Method> itr = parent.methodMap().values().iterator(); - while (itr.hasNext()) { - Method m = itr.next(); + for (Method m: methods) { ret0_names[i] = m.name(); ret1_params[i] = m.paramTypes(); ret2_return[i] = m.returnTypes(); diff --git a/jrt/src/com/yahoo/jrt/Supervisor.java b/jrt/src/com/yahoo/jrt/Supervisor.java index 62a2dce7871..14af463d84e 100644 --- a/jrt/src/com/yahoo/jrt/Supervisor.java +++ b/jrt/src/com/yahoo/jrt/Supervisor.java @@ -16,37 +16,10 @@ import java.util.HashMap; **/ public class Supervisor { - private class AddMethod implements Runnable { - private Method method; - AddMethod(Method method) { - this.method = method; - } - public void run() { - methodMap.put(method.name(), method); - } - } - - private class RemoveMethod implements Runnable { - private String methodName; - private Method method = null; - RemoveMethod(String methodName) { - this.methodName = methodName; - } - RemoveMethod(Method method) { - this.methodName = method.name(); - this.method = method; - } - public void run() { - Method m = methodMap.remove(methodName); - if (method != null && m != method) { - methodMap.put(method.name(), method); - } - } - } - private Transport transport; private SessionHandler sessionHandler = null; - private HashMap<String, Method> methodMap = new HashMap<>(); + private final Object methodMapLock = new Object(); + private volatile HashMap<String, Method> methodMap = new HashMap<>(); private int maxInputBufferSize = 0; private int maxOutputBufferSize = 0; @@ -122,7 +95,11 @@ public class Supervisor { * @param method the method to add **/ public void addMethod(Method method) { - transport.perform(new AddMethod(method)); + synchronized (methodMapLock) { + HashMap<String, Method> newMap = new HashMap<>(methodMap); + newMap.put(method.name(), method); + methodMap = newMap; + } } /** @@ -131,7 +108,11 @@ public class Supervisor { * @param methodName name of the method to remove **/ public void removeMethod(String methodName) { - transport.perform(new RemoveMethod(methodName)); + synchronized (methodMapLock) { + HashMap<String, Method> newMap = new HashMap<>(methodMap); + newMap.remove(methodName); + methodMap = newMap; + } } /** @@ -142,7 +123,12 @@ public class Supervisor { * @param method the method to remove **/ public void removeMethod(Method method) { - transport.perform(new RemoveMethod(method)); + synchronized (methodMapLock) { + HashMap<String, Method> newMap = new HashMap<>(methodMap); + if (newMap.remove(method.name()) == method) { + methodMap = newMap; + } + } } /** @@ -154,20 +140,7 @@ public class Supervisor { * @see #connect(com.yahoo.jrt.Spec, java.lang.Object) **/ public Target connect(Spec spec) { - return transport.connect(this, spec, null, false); - } - - /** - * Connect to the given address. The new {@link Target} will be - * associated with this Supervisor. This method will perform a - * synchronous connect in the calling thread. - * - * @return Target representing our end of the connection - * @param spec where to connect - * @see #connectSync(com.yahoo.jrt.Spec, java.lang.Object) - **/ - public Target connectSync(Spec spec) { - return transport.connect(this, spec, null, true); + return transport.connect(this, spec, null); } /** @@ -181,22 +154,7 @@ public class Supervisor { * @see Target#getContext **/ public Target connect(Spec spec, Object context) { - return transport.connect(this, spec, context, false); - } - - /** - * Connect to the given address. The new {@link Target} will be - * associated with this Supervisor and will have 'context' as - * application context. This method will perform a synchronous - * connect in the calling thread. - * - * @return Target representing our end of the connection - * @param spec where to connect - * @param context application context for the Target - * @see Target#getContext - **/ - public Target connectSync(Spec spec, Object context) { - return transport.connect(this, spec, context, true); + return transport.connect(this, spec, context); } /** @@ -219,7 +177,7 @@ public class Supervisor { * @param timeout request timeout in seconds **/ public void invokeBatch(Spec spec, Request req, double timeout) { - Target target = connectSync(spec); + Target target = connect(spec); try { target.invokeSync(req, timeout); } finally { @@ -312,7 +270,7 @@ public class Supervisor { } RequestPacket rp = (RequestPacket) packet; Request req = new Request(rp.methodName(), rp.parameters()); - Method method = methodMap.get(req.methodName()); + Method method = methodMap().get(req.methodName()); new InvocationServer(conn, req, method, packet.requestId(), packet.noReply()).invoke(); diff --git a/jrt/src/com/yahoo/jrt/Task.java b/jrt/src/com/yahoo/jrt/Task.java index 467aa7d6c2c..c5628c3040c 100644 --- a/jrt/src/com/yahoo/jrt/Task.java +++ b/jrt/src/com/yahoo/jrt/Task.java @@ -5,8 +5,8 @@ package com.yahoo.jrt; /** * A Task enables a Runnable to be scheduled for execution in the * transport thread some time in the future. Tasks are used internally - * to handle RPC timeouts. Use the {@link Transport#createTask - * Transport.createTask} method to create a task associated with a + * to handle RPC timeouts. Use the {@link TransportThread#createTask + * TransportThread.createTask} method to create a task associated with a * {@link Transport} object. Note that Task execution is designed to * be low-cost, so do not expect extreme accuracy. Also note that any * tasks that are pending execution when the owning {@link Transport} @@ -66,7 +66,7 @@ public class Task { * loop inside the owning {@link Transport} object checks for * tasks to run. If you have something that is even more urgent, * or something you need to be executed even if the {@link - * Transport} is shut down, use the {@link Transport#perform} + * Transport} is shut down, use the {@link TransportThread#perform} * method instead. * @see #kill **/ diff --git a/jrt/src/com/yahoo/jrt/Transport.java b/jrt/src/com/yahoo/jrt/Transport.java index 33ce6fe6ed0..f4eb1acd096 100644 --- a/jrt/src/com/yahoo/jrt/Transport.java +++ b/jrt/src/com/yahoo/jrt/Transport.java @@ -2,15 +2,14 @@ package com.yahoo.jrt; -import java.io.IOException; -import java.nio.channels.SelectionKey; -import java.nio.channels.Selector; import java.nio.channels.SocketChannel; +import java.util.ArrayList; import java.util.Iterator; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; - /** * The Transport class is the core needed to make your {@link * Supervisor} tick. It implements the reactor pattern to perform @@ -20,159 +19,17 @@ import java.util.logging.Logger; **/ public class Transport { - private static final int OPEN = 1; - private static final int CLOSING = 2; - private static final int CLOSED = 3; - - private class Run implements Runnable { - public void run() { - try { - Transport.this.run(); - } catch (Throwable problem) { - handleFailure(problem, Transport.this); - } - } - } - - private class AddConnectionCmd implements Runnable { - private Connection conn; - AddConnectionCmd(Connection conn) { this.conn = conn; } - public void run() { handleAddConnection(conn); } - } - - private class CloseConnectionCmd implements Runnable { - private Connection conn; - CloseConnectionCmd(Connection conn) { this.conn = conn; } - public void run() { handleCloseConnection(conn); } - } - - private class EnableWriteCmd implements Runnable { - private Connection conn; - EnableWriteCmd(Connection conn) { this.conn = conn; } - public void run() { handleEnableWrite(conn); } - } - - private class HandshakeWorkDoneCmd implements Runnable { - private Connection conn; - HandshakeWorkDoneCmd(Connection conn) { this.conn = conn; } - public void run() { handleHandshakeWorkDone(conn); } - } - - private class SyncCmd implements Runnable { - boolean done = false; - public synchronized void waitDone() { - while (!done) { - try { wait(); } catch (InterruptedException e) {} - } - } - public synchronized void run() { - done = true; - notify(); - } - } - private static Logger log = Logger.getLogger(Transport.class.getName()); - private FatalErrorHandler fatalHandler; // NB: this must be set first - private CryptoEngine cryptoEngine; - private Thread thread; - private Queue queue; - private Queue myQueue; - private Connector connector; - private Worker worker; - private Scheduler scheduler; - private int state; - private Selector selector; - private final TransportMetrics metrics = TransportMetrics.getInstance(); + private final FatalErrorHandler fatalHandler; // NB: this must be set first + private final CryptoEngine cryptoEngine; + private final Connector connector; + private final Worker worker; + private final AtomicInteger runCnt; - private void handleAddConnection(Connection conn) { - if (conn.isClosed()) { - if (conn.hasSocket()) { - worker.closeLater(conn); - } - return; - } - if (!conn.init(selector)) { - handleCloseConnection(conn); - } - } - - private void handleCloseConnection(Connection conn) { - if (conn.isClosed()) { - return; - } - conn.fini(); - if (conn.hasSocket()) { - worker.closeLater(conn); - } - } - - private void handleEnableWrite(Connection conn) { - if (conn.isClosed()) { - return; - } - conn.enableWrite(); - } - - private void handleHandshakeWorkDone(Connection conn) { - if (conn.isClosed()) { - return; - } - try { - conn.handleHandshakeWorkDone(); - } catch (IOException e) { - conn.setLostReason(e); - handleCloseConnection(conn); - } - } - - private boolean postCommand(Runnable cmd) { - boolean wakeup; - synchronized (this) { - if (state == CLOSED) { - return false; - } - wakeup = queue.isEmpty(); - queue.enqueue(cmd); - } - if (wakeup) { - selector.wakeup(); - } - return true; - } - - private void handleEvents() { - synchronized (this) { - queue.flush(myQueue); - } - while (!myQueue.isEmpty()) { - ((Runnable)myQueue.dequeue()).run(); - } - } - - private boolean handleIOEvents(Connection conn, - SelectionKey key) { - if (conn.isClosed()) { - return true; - } - if (key.isReadable()) { - try { - conn.handleReadEvent(); - } catch (IOException e) { - conn.setLostReason(e); - return false; - } - } - if (key.isWritable()) { - try { - conn.handleWriteEvent(); - } catch (IOException e) { - conn.setLostReason(e); - return false; - } - } - return true; - } + private final TransportMetrics metrics = TransportMetrics.getInstance(); + private final ArrayList<TransportThread> threads = new ArrayList<TransportThread>(); + private final Random rnd = new Random(); /** * Create a new Transport object with the given fatal error @@ -182,30 +39,33 @@ public class Transport { * * @param fatalHandler fatal error handler * @param cryptoEngine crypto engine to use + * @param numThreads number of {@link TransportThread}s. **/ - public Transport(FatalErrorHandler fatalHandler, CryptoEngine cryptoEngine) { + public Transport(FatalErrorHandler fatalHandler, CryptoEngine cryptoEngine, int numThreads) { synchronized (this) { this.fatalHandler = fatalHandler; // NB: this must be set first } this.cryptoEngine = cryptoEngine; - thread = new Thread(new Run(), "<jrt-transport>"); - queue = new Queue(); - myQueue = new Queue(); connector = new Connector(this); - worker = new Worker(this); - scheduler = new Scheduler(System.currentTimeMillis()); - state = OPEN; - try { - selector = Selector.open(); - } catch (Exception e) { - throw new Error("Could not open transport selector", e); + worker = new Worker(this); + runCnt = new AtomicInteger(numThreads); + for (int i = 0; i < numThreads; ++i) { + threads.add(new TransportThread(this)); } - thread.setDaemon(true); - thread.start(); } - public Transport(CryptoEngine cryptoEngine) { this(null, cryptoEngine); } - public Transport(FatalErrorHandler fatalHandler) { this(fatalHandler, CryptoEngine.createDefault()); } - public Transport() { this(null, CryptoEngine.createDefault()); } + public Transport(CryptoEngine cryptoEngine, int numThreads) { this(null, cryptoEngine, numThreads); } + public Transport(FatalErrorHandler fatalHandler, int numThreads) { this(fatalHandler, CryptoEngine.createDefault(), numThreads); } + public Transport(int numThreads) { this(null, CryptoEngine.createDefault(), numThreads); } + public Transport() { this(null, CryptoEngine.createDefault(), 1); } + + /** + * Select a random transport thread + * + * @return a random transport thread + **/ + public TransportThread selectThread() { + return threads.get(rnd.nextInt(threads.size())); + } /** * Use the underlying CryptoEngine to create a CryptoSocket. @@ -257,56 +117,15 @@ public class Transport { * @param owner the one calling this method * @param spec the address to connect to * @param context application context for the new connection - * @param sync perform a synchronous connect in the calling thread - * if this flag is set */ - Connection connect(Supervisor owner, Spec spec, Object context, boolean sync) { - Connection conn = new Connection(this, owner, spec, context); - if (sync) { - addConnection(conn.connect()); - } else { - connector.connectLater(conn); - } + Connection connect(Supervisor owner, Spec spec, Object context) { + Connection conn = new Connection(selectThread(), owner, spec, context); + connector.connectLater(conn); return conn; } - /** - * Add a connection to the set of connections handled by this - * Transport. Invoked by the {@link Connector} class. - * - * @param conn the connection to add - **/ - void addConnection(Connection conn) { - if (!postCommand(new AddConnectionCmd(conn))) { - perform(new CloseConnectionCmd(conn)); - } - } - - /** - * Request an asynchronous close of a connection. - * - * @param conn the connection to close - **/ - void closeConnection(Connection conn) { - postCommand(new CloseConnectionCmd(conn)); - } - - /** - * Request an asynchronous enabling of write events for a - * connection. - * - * @param conn the connection to enable write events for - **/ - void enableWrite(Connection conn) { - if (Thread.currentThread() == thread) { - handleEnableWrite(conn); - } else { - postCommand(new EnableWriteCmd(conn)); - } - } - - void handshakeWorkDone(Connection conn) { - postCommand(new HandshakeWorkDoneCmd(conn)); + void closeLater(Connection c) { + worker.closeLater(c); } /** @@ -320,126 +139,50 @@ public class Transport { } /** - * Create a {@link Task} that can be scheduled for execution in - * the transport thread. - * - * @return the newly created Task - * @param cmd what to run when the task is executed - **/ - public Task createTask(Runnable cmd) { - return new Task(scheduler, cmd); - } - - /** - * Perform the given command in such a way that it does not run - * concurrently with the transport thread or other commands - * performed by invoking this method. This method will continue to - * work even after the transport thread has been shut down. - * - * @param cmd the command to perform - **/ - public void perform(Runnable cmd) { - if (Thread.currentThread() == thread) { - cmd.run(); - return; - } - if (!postCommand(cmd)) { - join(); - synchronized (thread) { - cmd.run(); - } - } - } - - /** - * Synchronize with the transport thread. This method will block + * Synchronize with all transport threads. This method will block * until all commands issued before this method was invoked has - * completed. If the transport thread has been shut down (or is in + * completed. If a transport thread has been shut down (or is in * the progress of being shut down) this method will instead wait * for the transport thread to complete, since no more commands * will be performed, and waiting would be forever. Invoking this - * method from the transport thread is not a good idea. + * method from a transport thread is not a good idea. * * @return this object, to enable chaining **/ public Transport sync() { - SyncCmd cmd = new SyncCmd(); - if (postCommand(cmd)) { - cmd.waitDone(); - } else { - join(); + for (TransportThread thread: threads) { + thread.sync(); } return this; } - private void run() { - while (state == OPEN) { - - // perform I/O selection - try { - selector.select(100); - } catch (IOException e) { - log.log(Level.WARNING, "error during select", e); - } - - // handle internal events - handleEvents(); - - // handle I/O events - Iterator<SelectionKey> keys = selector.selectedKeys().iterator(); - while (keys.hasNext()) { - SelectionKey key = keys.next(); - Connection conn = (Connection) key.attachment(); - keys.remove(); - if (!handleIOEvents(conn, key)) { - handleCloseConnection(conn); - } - } - - // check scheduled tasks - scheduler.checkTasks(System.currentTimeMillis()); - } - connector.shutdown().waitDone(); - synchronized (this) { - state = CLOSED; - } - handleEvents(); - Iterator<SelectionKey> keys = selector.keys().iterator(); - while (keys.hasNext()) { - SelectionKey key = keys.next(); - Connection conn = (Connection) key.attachment(); - handleCloseConnection(conn); - } - try { selector.close(); } catch (Exception e) {} - worker.shutdown().join(); - connector.exit().join(); - try { cryptoEngine.close(); } catch (Exception e) {} - } - /** - * Initiate controlled shutdown of the transport thread. + * Initiate controlled shutdown of all transport threads. * * @return this object, to enable chaining with join **/ public Transport shutdown() { - synchronized (this) { - if (state == OPEN) { - state = CLOSING; - selector.wakeup(); - } + connector.shutdown().waitDone(); + for (TransportThread thread: threads) { + thread.shutdown(); } return this; } /** - * Wait for the transport thread to finish. + * Wait for all transport threads to finish. **/ public void join() { - while (true) { - try { - thread.join(); - return; - } catch (InterruptedException e) {} + for (TransportThread thread: threads) { + thread.join(); + } + } + + void notifyDone(TransportThread self) { + if (runCnt.decrementAndGet() == 0) { + worker.shutdown().join(); + connector.exit().join(); + try { cryptoEngine.close(); } catch (Exception e) {} } } diff --git a/jrt/src/com/yahoo/jrt/TransportThread.java b/jrt/src/com/yahoo/jrt/TransportThread.java new file mode 100644 index 00000000000..8f4f49b8888 --- /dev/null +++ b/jrt/src/com/yahoo/jrt/TransportThread.java @@ -0,0 +1,352 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + + +import java.io.IOException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.Iterator; +import java.util.logging.Level; +import java.util.logging.Logger; + + +/** + * A single reactor/scheduler thread inside a potentially + * multi-threaded {@link Transport}. + **/ +public class TransportThread { + + private static final int OPEN = 1; + private static final int CLOSING = 2; + private static final int CLOSED = 3; + + private class Run implements Runnable { + public void run() { + try { + TransportThread.this.run(); + } catch (Throwable problem) { + handleFailure(problem, TransportThread.this); + } + } + } + + private class AddConnectionCmd implements Runnable { + private Connection conn; + AddConnectionCmd(Connection conn) { this.conn = conn; } + public void run() { handleAddConnection(conn); } + } + + private class CloseConnectionCmd implements Runnable { + private Connection conn; + CloseConnectionCmd(Connection conn) { this.conn = conn; } + public void run() { handleCloseConnection(conn); } + } + + private class EnableWriteCmd implements Runnable { + private Connection conn; + EnableWriteCmd(Connection conn) { this.conn = conn; } + public void run() { handleEnableWrite(conn); } + } + + private class HandshakeWorkDoneCmd implements Runnable { + private Connection conn; + HandshakeWorkDoneCmd(Connection conn) { this.conn = conn; } + public void run() { handleHandshakeWorkDone(conn); } + } + + private class SyncCmd implements Runnable { + boolean done = false; + public synchronized void waitDone() { + while (!done) { + try { wait(); } catch (InterruptedException e) {} + } + } + public synchronized void run() { + done = true; + notify(); + } + } + + private static Logger log = Logger.getLogger(TransportThread.class.getName()); + + private final Transport parent; + private final Thread thread; + private final Queue queue; + private final Queue myQueue; + private final Scheduler scheduler; + private int state; + private final Selector selector; + + private void handleAddConnection(Connection conn) { + if (conn.isClosed()) { + if (conn.hasSocket()) { + parent.closeLater(conn); + } + return; + } + if (!conn.init(selector)) { + handleCloseConnection(conn); + } + } + + private void handleCloseConnection(Connection conn) { + if (conn.isClosed()) { + return; + } + conn.fini(); + if (conn.hasSocket()) { + parent.closeLater(conn); + } + } + + private void handleEnableWrite(Connection conn) { + if (conn.isClosed()) { + return; + } + conn.enableWrite(); + } + + private void handleHandshakeWorkDone(Connection conn) { + if (conn.isClosed()) { + return; + } + try { + conn.handleHandshakeWorkDone(); + } catch (IOException e) { + conn.setLostReason(e); + handleCloseConnection(conn); + } + } + + private boolean postCommand(Runnable cmd) { + boolean wakeup; + synchronized (this) { + if (state == CLOSED) { + return false; + } + wakeup = queue.isEmpty(); + queue.enqueue(cmd); + } + if (wakeup) { + selector.wakeup(); + } + return true; + } + + private void handleEvents() { + synchronized (this) { + queue.flush(myQueue); + } + while (!myQueue.isEmpty()) { + ((Runnable)myQueue.dequeue()).run(); + } + } + + private boolean handleIOEvents(Connection conn, + SelectionKey key) { + if (conn.isClosed()) { + return true; + } + if (key.isReadable()) { + try { + conn.handleReadEvent(); + } catch (IOException e) { + conn.setLostReason(e); + return false; + } + } + if (key.isWritable()) { + try { + conn.handleWriteEvent(); + } catch (IOException e) { + conn.setLostReason(e); + return false; + } + } + return true; + } + + TransportThread(Transport transport) { + parent = transport; + thread = new Thread(new Run(), "<jrt-transport>"); + queue = new Queue(); + myQueue = new Queue(); + scheduler = new Scheduler(System.currentTimeMillis()); + state = OPEN; + try { + selector = Selector.open(); + } catch (Exception e) { + throw new Error("Could not open transport selector", e); + } + thread.setDaemon(true); + thread.start(); + } + + public Transport transport() { + return parent; + } + + /** + * Proxy method used to dispatch fatal errors to the enclosing + * Transport. + * + * @param problem the throwable causing the failure + * @param context the object owning the crashing thread + **/ + void handleFailure(Throwable problem, Object context) { + parent.handleFailure(problem, context); + } + + /** + * Add a connection to the set of connections handled by this + * TransportThread. Invoked by the {@link Connector} class. + * + * @param conn the connection to add + **/ + void addConnection(Connection conn) { + if (!postCommand(new AddConnectionCmd(conn))) { + perform(new CloseConnectionCmd(conn)); + } + } + + /** + * Request an asynchronous close of a connection. + * + * @param conn the connection to close + **/ + void closeConnection(Connection conn) { + postCommand(new CloseConnectionCmd(conn)); + } + + /** + * Request an asynchronous enabling of write events for a + * connection. + * + * @param conn the connection to enable write events for + **/ + void enableWrite(Connection conn) { + if (Thread.currentThread() == thread) { + handleEnableWrite(conn); + } else { + postCommand(new EnableWriteCmd(conn)); + } + } + + void handshakeWorkDone(Connection conn) { + postCommand(new HandshakeWorkDoneCmd(conn)); + } + + /** + * Create a {@link Task} that can be scheduled for execution in + * the transport thread. + * + * @return the newly created Task + * @param cmd what to run when the task is executed + **/ + public Task createTask(Runnable cmd) { + return new Task(scheduler, cmd); + } + + /** + * Perform the given command in such a way that it does not run + * concurrently with the transport thread or other commands + * performed by invoking this method. This method will continue to + * work even after the transport thread has been shut down. + * + * @param cmd the command to perform + **/ + public void perform(Runnable cmd) { + if (Thread.currentThread() == thread) { + cmd.run(); + return; + } + if (!postCommand(cmd)) { + join(); + synchronized (thread) { + cmd.run(); + } + } + } + + /** + * Synchronize with the transport thread. This method will block + * until all commands issued before this method was invoked has + * completed. If the transport thread has been shut down (or is in + * the progress of being shut down) this method will instead wait + * for the transport thread to complete, since no more commands + * will be performed, and waiting would be forever. Invoking this + * method from the transport thread is not a good idea. + * + * @return this object, to enable chaining + **/ + public TransportThread sync() { + SyncCmd cmd = new SyncCmd(); + if (postCommand(cmd)) { + cmd.waitDone(); + } else { + join(); + } + return this; + } + + private void run() { + while (state == OPEN) { + + // perform I/O selection + try { + selector.select(100); + } catch (IOException e) { + log.log(Level.WARNING, "error during select", e); + } + + // handle internal events + handleEvents(); + + // handle I/O events + Iterator<SelectionKey> keys = selector.selectedKeys().iterator(); + while (keys.hasNext()) { + SelectionKey key = keys.next(); + Connection conn = (Connection) key.attachment(); + keys.remove(); + if (!handleIOEvents(conn, key)) { + handleCloseConnection(conn); + } + } + + // check scheduled tasks + scheduler.checkTasks(System.currentTimeMillis()); + } + synchronized (this) { + state = CLOSED; + } + handleEvents(); + Iterator<SelectionKey> keys = selector.keys().iterator(); + while (keys.hasNext()) { + SelectionKey key = keys.next(); + Connection conn = (Connection) key.attachment(); + handleCloseConnection(conn); + } + try { selector.close(); } catch (Exception e) {} + parent.notifyDone(this); + } + + TransportThread shutdown() { + synchronized (this) { + if (state == OPEN) { + state = CLOSING; + selector.wakeup(); + } + } + return this; + } + + void join() { + while (true) { + try { + thread.join(); + return; + } catch (InterruptedException e) {} + } + } +} diff --git a/jrt/src/com/yahoo/jrt/Worker.java b/jrt/src/com/yahoo/jrt/Worker.java index 986bca864f3..39c0e6773b2 100644 --- a/jrt/src/com/yahoo/jrt/Worker.java +++ b/jrt/src/com/yahoo/jrt/Worker.java @@ -33,7 +33,7 @@ class Worker { } public void run() { connection.doHandshakeWork(); - connection.transport().handshakeWorkDone(connection); + connection.transportThread().handshakeWorkDone(connection); } } diff --git a/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java b/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java index 22abddfe924..9ed8eafbd30 100644 --- a/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java +++ b/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java @@ -10,6 +10,7 @@ import com.yahoo.jrt.Spec; import com.yahoo.jrt.Supervisor; import com.yahoo.jrt.Target; import com.yahoo.jrt.Task; +import com.yahoo.jrt.TransportThread; import com.yahoo.jrt.Values; import java.util.ArrayList; @@ -40,6 +41,7 @@ public class Mirror implements IMirror { private boolean requestDone = false; private AtomicReference<Entry[]> specs = new AtomicReference<>(new Entry[0]); private int specsGeneration = 0; + private final TransportThread transportThread; private final Task updateTask; private final RequestWaiter reqWait; private Target target = null; @@ -57,7 +59,8 @@ public class Mirror implements IMirror { this.orb = orb; this.slobroks = slobroks; this.backOff = bop; - updateTask = orb.transport().createTask(this::checkForUpdate); + transportThread = orb.transport().selectThread(); + updateTask = transportThread.createTask(this::checkForUpdate); reqWait = new RequestWaiter() { public void handleRequestDone(Request req) { requestDone = true; @@ -84,7 +87,7 @@ public class Mirror implements IMirror { */ public void shutdown() { updateTask.kill(); - orb.transport().perform(this::handleShutdown); + transportThread.perform(this::handleShutdown); } @Override diff --git a/jrt/src/com/yahoo/jrt/slobrok/api/Register.java b/jrt/src/com/yahoo/jrt/slobrok/api/Register.java index 0560510186c..713cecc62d1 100644 --- a/jrt/src/com/yahoo/jrt/slobrok/api/Register.java +++ b/jrt/src/com/yahoo/jrt/slobrok/api/Register.java @@ -12,6 +12,7 @@ import com.yahoo.jrt.StringValue; import com.yahoo.jrt.Supervisor; import com.yahoo.jrt.Target; import com.yahoo.jrt.Task; +import com.yahoo.jrt.TransportThread; import com.yahoo.jrt.Values; import java.util.ArrayList; @@ -45,6 +46,7 @@ public class Register { private List<String> names = new ArrayList<>(); private List<String> pending = new ArrayList<>(); private List<String> unreg = new ArrayList<>(); + private final TransportThread transportThread; private Task updateTask = null; private RequestWaiter reqWait = null; private Target target = null; @@ -79,9 +81,8 @@ public class Register { this.slobroks = slobroks; this.backOff = bop; mySpec = spec.toString(); - updateTask = orb.transport().createTask(new Runnable() { - public void run() { handleUpdate(); } - }); + transportThread = orb.transport().selectThread(); + updateTask = transportThread.createTask(this::handleUpdate); reqWait = new RequestWaiter() { public void handleRequestDone(Request req) { reqDone = true; @@ -142,9 +143,7 @@ public class Register { */ public void shutdown() { updateTask.kill(); - orb.transport().perform(new Runnable() { - public void run() { handleShutdown(); } - }); + transportThread.perform(this::handleShutdown); } /** diff --git a/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java b/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java index 4b967f7f3d7..6ce8f3d1227 100644 --- a/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java +++ b/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java @@ -24,7 +24,8 @@ public class Slobrok { public Slobrok(int port) throws ListenFailedException { - orb = new Supervisor(new Transport()); + // NB: rpc must be single-threaded + orb = new Supervisor(new Transport(1)); registerMethods(); try { listener = orb.listen(new Spec(port)); @@ -245,7 +246,7 @@ public class Slobrok { public FetchMirror(Request req, int timeout) { req.detach(); this.req = req; - task = orb.transport().createTask(this); + task = orb.transport().selectThread().createTask(this); task.schedule(((double)timeout)/1000.0); } public void run() { // timeout diff --git a/jrt/tests/com/yahoo/jrt/EchoTest.java b/jrt/tests/com/yahoo/jrt/EchoTest.java index 16f18afb58c..67544d3f1d4 100644 --- a/jrt/tests/com/yahoo/jrt/EchoTest.java +++ b/jrt/tests/com/yahoo/jrt/EchoTest.java @@ -91,8 +91,8 @@ public class EchoTest { public void setUp() throws ListenFailedException { metrics = TransportMetrics.getInstance(); startSnapshot = metrics.snapshot(); - server = new Supervisor(new Transport(crypto)); - client = new Supervisor(new Transport(crypto)); + server = new Supervisor(new Transport(crypto, 1)); + client = new Supervisor(new Transport(crypto, 1)); acceptor = server.listen(new Spec(0)); target = client.connect(new Spec("localhost", acceptor.port())); server.addMethod(new Method("echo", "*", "*", this, "rpc_echo")); diff --git a/jrt/tests/com/yahoo/jrt/LatencyTest.java b/jrt/tests/com/yahoo/jrt/LatencyTest.java index a1f71bda013..e8cd6cdc017 100644 --- a/jrt/tests/com/yahoo/jrt/LatencyTest.java +++ b/jrt/tests/com/yahoo/jrt/LatencyTest.java @@ -5,6 +5,8 @@ package com.yahoo.jrt; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.logging.Logger; import static com.yahoo.jrt.CryptoUtils.createTestTlsContext; @@ -14,71 +16,159 @@ import static org.junit.Assert.assertTrue; public class LatencyTest { private static final Logger log = Logger.getLogger(LatencyTest.class.getName()); - private static class Server implements AutoCloseable { - private Supervisor orb; - private Acceptor acceptor; - public Server(CryptoEngine crypto) throws ListenFailedException { - orb = new Supervisor(new Transport(crypto)); - acceptor = orb.listen(new Spec(0)); - orb.addMethod(new Method("inc", "i", "i", this, "rpc_inc")); + private static class Network implements AutoCloseable { + private final Supervisor server; + private final Supervisor client; + private final Acceptor acceptor; + public Network(CryptoEngine crypto, int threads) throws ListenFailedException { + server = new Supervisor(new Transport(crypto, threads)); + client = new Supervisor(new Transport(crypto, threads)); + server.addMethod(new Method("inc", "i", "i", this, "rpc_inc")); + acceptor = server.listen(new Spec(0)); } public Target connect() { - return orb.connect(new Spec("localhost", acceptor.port())); + return client.connect(new Spec("localhost", acceptor.port())); } public void rpc_inc(Request req) { req.returnValues().add(new Int32Value(req.parameters().get(0).asInt32() + 1)); } public void close() { acceptor.shutdown().join(); - orb.transport().shutdown().join(); + client.transport().shutdown().join(); + server.transport().shutdown().join(); } } - private void measureLatency(String prefix, Server server, boolean reconnect) { - int value = 100; - List<Double> list = new ArrayList<>(); - Target target = server.connect(); - for (int i = 0; i < 64; ++i) { - long before = System.nanoTime(); - if (reconnect) { + private static class Client { + + public static class Result { + public final double latency; + public final double throughput; + + public Result(double ms, double cnt) { + latency = ms; + throughput = cnt; + } + + public Result(Result[] results) { + double ms = 0.0; + double cnt = 0.0; + for (Result r: results) { + ms += r.latency; + cnt += r.throughput; + } + latency = (ms / results.length); + throughput = cnt; + } + } + + private final boolean reconnect; + private final Network network; + private final CyclicBarrier barrier; + private final CountDownLatch latch; + private final Throwable[] issues; + private final Result[] results; + + private void run(int threadId) { + try { + barrier.await(); + int value = 100; + final int warmupCnt = 10; + final int benchmarkCnt = 50; + final int cooldownCnt = 10; + final int totalReqs = (warmupCnt + benchmarkCnt + cooldownCnt); + long t1 = 0; + long t2 = 0; + List<Double> list = new ArrayList<>(); + Target target = network.connect(); + for (int i = 0; i < totalReqs; ++i) { + long before = System.nanoTime(); + if (i == warmupCnt) { + t1 = before; + } + if (i == (warmupCnt + benchmarkCnt)) { + t2 = before; + } + if (reconnect) { + target.close(); + target = network.connect(); + } + Request req = new Request("inc"); + req.parameters().add(new Int32Value(value)); + target.invokeSync(req, 60.0); + long duration = System.nanoTime() - before; + assertTrue(req.checkReturnTypes("i")); + assertEquals(value + 1, req.returnValues().get(0).asInt32()); + value++; + list.add(duration / 1000000.0); + } target.close(); - target = server.connect(); + Collections.sort(list); + double benchTime = (t2 - t1) / 1000000000.0; + results[threadId] = new Result(list.get(list.size() / 2), benchmarkCnt / benchTime); + } catch (Throwable issue) { + issues[threadId] = issue; + } finally { + latch.countDown(); } - Request req = new Request("inc"); - req.parameters().add(new Int32Value(value)); - target.invokeSync(req, 60.0); - assertTrue(req.checkReturnTypes("i")); - assertEquals(value + 1, req.returnValues().get(0).asInt32()); - value++; - long duration = System.nanoTime() - before; - list.add(duration / 1000000.0); } - target.close(); - Collections.sort(list); - log.info(prefix + "invocation latency: " + list.get(list.size() / 2) + " ms"); + + public Client(boolean reconnect, Network network, int numThreads) { + this.reconnect = reconnect; + this.network = network; + this.barrier = new CyclicBarrier(numThreads); + this.latch = new CountDownLatch(numThreads); + this.issues = new Throwable[numThreads]; + this.results = new Result[numThreads]; + } + + public void measureLatency(String prefix) throws Throwable { + for (int i = 0; i < results.length; ++i) { + final int threadId = i; + new Thread(()->run(threadId)).start(); + } + latch.await(); + for (Throwable issue: issues) { + if (issue != null) { + throw(issue); + } + } + Result result = new Result(results); + log.info(prefix + "latency: " + result.latency + " ms, throughput: " + result.throughput + " req/s"); + } } @org.junit.Test - public void testNullCryptoLatency() throws ListenFailedException { - try (Server server = new Server(new NullCryptoEngine())) { - measureLatency("[null crypto, no reconnect] ", server, false); - measureLatency("[null crypto, reconnect] ", server, true); + public void testNullCryptoLatency() throws Throwable { + try (Network network = new Network(new NullCryptoEngine(), 1)) { + new Client(false, network, 1).measureLatency("[null crypto, no reconnect] "); + new Client(true, network, 1).measureLatency("[null crypto, reconnect] "); } } @org.junit.Test - public void testXorCryptoLatency() throws ListenFailedException { - try (Server server = new Server(new XorCryptoEngine())) { - measureLatency("[xor crypto, no reconnect] ", server, false); - measureLatency("[xor crypto, reconnect] ", server, true); + public void testXorCryptoLatency() throws Throwable { + try (Network network = new Network(new XorCryptoEngine(), 1)) { + new Client(false, network, 1).measureLatency("[xor crypto, no reconnect] "); + new Client(true, network, 1).measureLatency("[xor crypto, reconnect] "); } } @org.junit.Test - public void testTlsCryptoLatency() throws ListenFailedException { - try (Server server = new Server(new TlsCryptoEngine(createTestTlsContext()))) { - measureLatency("[tls crypto, no reconnect] ", server, false); - measureLatency("[tls crypto, reconnect] ", server, true); + public void testTlsCryptoLatency() throws Throwable { + try (Network network = new Network(new TlsCryptoEngine(createTestTlsContext()), 1)) { + new Client(false, network, 1).measureLatency("[tls crypto, no reconnect] "); + new Client(true, network, 1).measureLatency("[tls crypto, reconnect] "); + } + } + + @org.junit.Test + public void testTransportThreadScaling() throws Throwable { + try (Network network = new Network(new NullCryptoEngine(), 1)) { + new Client(false, network, 64).measureLatency("[64 clients, 1/1 transport] "); + } + try (Network network = new Network(new NullCryptoEngine(), 4)) { + new Client(false, network, 64).measureLatency("[64 clients, 4/4 transport] "); } } } diff --git a/jrt/tests/com/yahoo/jrt/SessionTest.java b/jrt/tests/com/yahoo/jrt/SessionTest.java index 6f070959d7a..dc33af96e44 100644 --- a/jrt/tests/com/yahoo/jrt/SessionTest.java +++ b/jrt/tests/com/yahoo/jrt/SessionTest.java @@ -122,9 +122,9 @@ public class SessionTest implements SessionHandler { @Before public void setUp() throws ListenFailedException { Session.reset(); - server = new Test.Orb(new Transport(crypto)); + server = new Test.Orb(new Transport(crypto, 1)); server.setSessionHandler(this); - client = new Test.Orb(new Transport(crypto)); + client = new Test.Orb(new Transport(crypto, 1)); client.setSessionHandler(this); acceptor = server.listen(new Spec(0)); target = client.connect(new Spec("localhost", acceptor.port()), diff --git a/jrt_test/src/tests/mandatory-methods/ref.txt b/jrt_test/src/tests/mandatory-methods/ref.txt index 785181bca6d..7a4943b8edf 100644 --- a/jrt_test/src/tests/mandatory-methods/ref.txt +++ b/jrt_test/src/tests/mandatory-methods/ref.txt @@ -1,3 +1,7 @@ +METHOD frt.rpc.ping + DESCRIPTION: + Method that may be used to check if the server is online + METHOD frt.rpc.getMethodInfo DESCRIPTION: Obtain detailed information about a single method @@ -20,7 +24,3 @@ METHOD frt.rpc.getMethodList [S][params] Method parameter types [S][return] Method return types -METHOD frt.rpc.ping - DESCRIPTION: - Method that may be used to check if the server is online - diff --git a/logserver/src/test/java/ai/vespa/logserver/protocol/ArchiveLogMessagesMethodTest.java b/logserver/src/test/java/ai/vespa/logserver/protocol/ArchiveLogMessagesMethodTest.java index 847975bf2d9..93160c2a7a2 100644 --- a/logserver/src/test/java/ai/vespa/logserver/protocol/ArchiveLogMessagesMethodTest.java +++ b/logserver/src/test/java/ai/vespa/logserver/protocol/ArchiveLogMessagesMethodTest.java @@ -57,7 +57,7 @@ public class ArchiveLogMessagesMethodTest { TestClient(int logserverPort) { this.supervisor = new Supervisor(new Transport()); - this.target = supervisor.connectSync(new Spec(logserverPort)); + this.target = supervisor.connect(new Spec(logserverPort)); } void logMessages(List<LogMessage> messages) { diff --git a/messagebus/src/main/java/com/yahoo/messagebus/network/rpc/RPCNetwork.java b/messagebus/src/main/java/com/yahoo/messagebus/network/rpc/RPCNetwork.java index 4ba8f6f0312..6b206435fa7 100644 --- a/messagebus/src/main/java/com/yahoo/messagebus/network/rpc/RPCNetwork.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/network/rpc/RPCNetwork.java @@ -197,9 +197,7 @@ public class RPCNetwork implements Network, MethodHandler { @Override public void sync() { - SyncTask sh = new SyncTask(); - orb.transport().perform(sh); - sh.await(); + orb.transport().sync(); } @Override @@ -446,29 +444,6 @@ public class RPCNetwork implements Network, MethodHandler { } /** - * Implements a helper class for {@link RPCNetwork#sync()}. It provides a blocking method {@link #await()} that will - * wait until the internal state of this object is set to 'done'. By scheduling this task in the network thread and - * then calling this method, we achieve handshaking with the network thread. - */ - private static class SyncTask implements Runnable { - - final CountDownLatch latch = new CountDownLatch(1); - - @Override - public void run() { - latch.countDown(); - } - - public void await() { - try { - latch.await(); - } catch (InterruptedException e) { - // ignore - } - } - } - - /** * Implements a helper class for {@link RPCNetwork#send(com.yahoo.messagebus.Message, java.util.List)}. It works by * encapsulating all the data required for sending a message, but postponing the call to {@link * RPCNetwork#send(com.yahoo.messagebus.network.rpc.RPCNetwork.SendContext)} until the version of all targets have @@ -523,7 +498,7 @@ public class RPCNetwork implements Network, MethodHandler { TargetPoolTask(RPCTargetPool pool, Supervisor orb) { this.pool = pool; - this.jrtTask = orb.transport().createTask(this); + this.jrtTask = orb.transport().selectThread().createTask(this); this.jrtTask.schedule(1.0); } |