diff options
Diffstat (limited to 'jrt')
29 files changed, 279 insertions, 149 deletions
diff --git a/jrt/src/com/yahoo/jrt/Acceptor.java b/jrt/src/com/yahoo/jrt/Acceptor.java index 1cbfd36e8c5..14b35c5893f 100644 --- a/jrt/src/com/yahoo/jrt/Acceptor.java +++ b/jrt/src/com/yahoo/jrt/Acceptor.java @@ -1,14 +1,12 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jrt; - import java.nio.channels.ClosedChannelException; import java.nio.channels.ServerSocketChannel; import java.util.concurrent.CountDownLatch; import java.util.logging.Level; import java.util.logging.Logger; - /** * A class used to listen on a network socket. A separate thread is * used to accept connections and register them with the underlying diff --git a/jrt/src/com/yahoo/jrt/Connection.java b/jrt/src/com/yahoo/jrt/Connection.java index 00aceb7e352..8a185907aae 100644 --- a/jrt/src/com/yahoo/jrt/Connection.java +++ b/jrt/src/com/yahoo/jrt/Connection.java @@ -1,7 +1,10 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jrt; +import com.yahoo.security.tls.ConnectionAuthContext; + import java.io.IOException; +import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; @@ -9,7 +12,6 @@ import java.nio.channels.SocketChannel; import java.util.HashMap; import java.util.IdentityHashMap; import java.util.Map; -import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -436,9 +438,16 @@ class Connection extends Target { } @Override - public Optional<SecurityContext> getSecurityContext() { - return Optional.ofNullable(socket) - .flatMap(CryptoSocket::getSecurityContext); + public ConnectionAuthContext connectionAuthContext() { + if (socket == null) throw new IllegalStateException("Not connected"); + return socket.connectionAuthContext(); + } + + @Override + public Spec peerSpec() { + if (socket == null) throw new IllegalStateException("Not connected"); + InetSocketAddress addr = (InetSocketAddress) socket.channel().socket().getRemoteSocketAddress(); + return new Spec(addr.getHostString(), addr.getPort()); } public boolean isClient() { @@ -455,8 +464,7 @@ class Connection extends Target { waiter.waitDone(); } - public void invokeAsync(Request req, double timeout, - RequestWaiter waiter) { + public void invokeAsync(Request req, double timeout, RequestWaiter waiter) { if (timeout < 0.0) { timeout = 0.0; } diff --git a/jrt/src/com/yahoo/jrt/CryptoSocket.java b/jrt/src/com/yahoo/jrt/CryptoSocket.java index 78308b76624..e30579d2bdc 100644 --- a/jrt/src/com/yahoo/jrt/CryptoSocket.java +++ b/jrt/src/com/yahoo/jrt/CryptoSocket.java @@ -2,10 +2,11 @@ package com.yahoo.jrt; +import com.yahoo.security.tls.ConnectionAuthContext; + import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; -import java.util.Optional; /** @@ -102,11 +103,6 @@ public interface CryptoSocket { **/ public void dropEmptyBuffers(); - /** - * Returns the security context for the current connection (given handshake completed), - * or empty if the current connection is not secure. - */ - default public Optional<SecurityContext> getSecurityContext() { - return Optional.empty(); - } + /** Returns the auth context for the current connection (given handshake completed) */ + default ConnectionAuthContext connectionAuthContext() { return ConnectionAuthContext.defaultAllCapabilities(); } } diff --git a/jrt/src/com/yahoo/jrt/ErrorCode.java b/jrt/src/com/yahoo/jrt/ErrorCode.java index beaabcea316..8e129cfef98 100644 --- a/jrt/src/com/yahoo/jrt/ErrorCode.java +++ b/jrt/src/com/yahoo/jrt/ErrorCode.java @@ -49,4 +49,7 @@ public class ErrorCode /** Method failed (111) **/ public static final int METHOD_FAILED = 111; + + /** Permission denied (112) **/ + public static final int PERMISSION_DENIED = 112; } diff --git a/jrt/src/com/yahoo/jrt/InvocationServer.java b/jrt/src/com/yahoo/jrt/InvocationServer.java index 9df92eb20a6..7704c0019ed 100644 --- a/jrt/src/com/yahoo/jrt/InvocationServer.java +++ b/jrt/src/com/yahoo/jrt/InvocationServer.java @@ -31,7 +31,11 @@ class InvocationServer { public void invoke() { if (method != null) { if (method.checkParameters(request)) { - method.invoke(request); + if (method.requestAccessFilter().allow(request)) { + method.invoke(request); + } else { + request.setError(ErrorCode.PERMISSION_DENIED, "Permission denied"); + } } else { request.setError(ErrorCode.WRONG_PARAMS, "Parameters in " + request + " does not match " + method); } diff --git a/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java index df01f4f2fa7..ab9d78d2676 100644 --- a/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java +++ b/jrt/src/com/yahoo/jrt/MaybeTlsCryptoSocket.java @@ -1,10 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jrt; +import com.yahoo.security.tls.ConnectionAuthContext; + import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.SocketChannel; -import java.util.Optional; /** * A crypto socket for the server side of a connection that @@ -130,5 +131,5 @@ public class MaybeTlsCryptoSocket implements CryptoSocket { @Override public int write(ByteBuffer src) throws IOException { return socket.write(src); } @Override public FlushResult flush() throws IOException { return socket.flush(); } @Override public void dropEmptyBuffers() { socket.dropEmptyBuffers(); } - @Override public Optional<SecurityContext> getSecurityContext() { return Optional.ofNullable(socket).flatMap(CryptoSocket::getSecurityContext); } + @Override public ConnectionAuthContext connectionAuthContext() { return socket.connectionAuthContext(); } } diff --git a/jrt/src/com/yahoo/jrt/Method.java b/jrt/src/com/yahoo/jrt/Method.java index 4fc9f0714da..89c66747e0b 100644 --- a/jrt/src/com/yahoo/jrt/Method.java +++ b/jrt/src/com/yahoo/jrt/Method.java @@ -40,6 +40,8 @@ public class Method { private String[] returnName; private String[] returnDesc; + private RequestAccessFilter filter = RequestAccessFilter.ALLOW_ALL; + private static final String undocumented = "???"; @@ -147,6 +149,10 @@ public class Method { return this; } + public Method requestAccessFilter(RequestAccessFilter filter) { this.filter = filter; return this; } + + public RequestAccessFilter requestAccessFilter() { return filter; } + /** * Obtain the name of a parameter * diff --git a/jrt/src/com/yahoo/jrt/RequestAccessFilter.java b/jrt/src/com/yahoo/jrt/RequestAccessFilter.java new file mode 100644 index 00000000000..6701436d6ce --- /dev/null +++ b/jrt/src/com/yahoo/jrt/RequestAccessFilter.java @@ -0,0 +1,17 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +/** + * Request access filter is invoked before any call to {@link Method#invoke(Request)}. + * If {@link #allow(Request)} returns false, the method is not invoked, and the request is failed with error + * {@link ErrorCode#PERMISSION_DENIED}. + * + * @author bjorncs + */ +public interface RequestAccessFilter { + + RequestAccessFilter ALLOW_ALL = __ -> true; + + boolean allow(Request r); + +} diff --git a/jrt/src/com/yahoo/jrt/RequireCapabilitiesFilter.java b/jrt/src/com/yahoo/jrt/RequireCapabilitiesFilter.java new file mode 100644 index 00000000000..9bb497e96ed --- /dev/null +++ b/jrt/src/com/yahoo/jrt/RequireCapabilitiesFilter.java @@ -0,0 +1,34 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import com.yahoo.security.tls.Capability; +import com.yahoo.security.tls.CapabilitySet; +import com.yahoo.security.tls.MissingCapabilitiesException; + +/** + * @author bjorncs + */ +public class RequireCapabilitiesFilter implements RequestAccessFilter { + + private final CapabilitySet requiredCapabilities; + + public RequireCapabilitiesFilter(CapabilitySet requiredCapabilities) { + this.requiredCapabilities = requiredCapabilities; + } + + public RequireCapabilitiesFilter(Capability... requiredCapabilities) { + this(CapabilitySet.from(requiredCapabilities)); + } + + @Override + public boolean allow(Request r) { + try { + r.target().connectionAuthContext() + .verifyCapabilities(requiredCapabilities, "RPC", r.methodName(), r.target().peerSpec().toString()); + return true; + } catch (MissingCapabilitiesException e) { + return false; + } + } + +} diff --git a/jrt/src/com/yahoo/jrt/SecurityContext.java b/jrt/src/com/yahoo/jrt/SecurityContext.java deleted file mode 100644 index 4eef99cb93f..00000000000 --- a/jrt/src/com/yahoo/jrt/SecurityContext.java +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jrt; - -import java.security.cert.X509Certificate; -import java.util.List; - -/** - * @author bjorncs - */ -public class SecurityContext { - - private final List<X509Certificate> peerCertificateChain; - - public SecurityContext(List<X509Certificate> peerCertificateChain) { - this.peerCertificateChain = peerCertificateChain; - } - - /** - * @return the peer certificate chain if the peer was authenticated, empty list if not. - */ - public List<X509Certificate> peerCertificateChain() { - return peerCertificateChain; - } -} diff --git a/jrt/src/com/yahoo/jrt/Target.java b/jrt/src/com/yahoo/jrt/Target.java index a59aa341fe0..6cb9d432e03 100644 --- a/jrt/src/com/yahoo/jrt/Target.java +++ b/jrt/src/com/yahoo/jrt/Target.java @@ -1,7 +1,9 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jrt; -import java.util.Optional; +import com.yahoo.security.tls.ConnectionAuthContext; + +import java.time.Duration; /** * A Target represents a connection endpoint with RPC @@ -69,9 +71,13 @@ public abstract class Target { public Exception getConnectionLostReason() { return null; } /** - * Returns the security context associated with this target, or empty if no connection or is insecure. + * Returns the connection auth context associated with this target. */ - public abstract Optional<SecurityContext> getSecurityContext(); + public abstract ConnectionAuthContext connectionAuthContext(); + + + /** @return address spec of socket peer */ + public abstract Spec peerSpec(); /** * Check if this target represents the client side of a @@ -97,6 +103,10 @@ public abstract class Target { */ public abstract void invokeSync(Request req, double timeout); + public void invokeSync(Request req, Duration timeout) { + invokeSync(req, toSeconds(timeout)); + } + /** * Invoke a request on this target and let the completion be * signalled with a callback. @@ -105,8 +115,15 @@ public abstract class Target { * @param timeout timeout in seconds * @param waiter callback handler */ - public abstract void invokeAsync(Request req, double timeout, - RequestWaiter waiter); + public abstract void invokeAsync(Request req, double timeout, RequestWaiter waiter); + + public void invokeAsync(Request req, Duration timeout, RequestWaiter waiter) { + invokeAsync(req, toSeconds(timeout), waiter); + } + + private static double toSeconds(Duration duration) { + return ((double)duration.toMillis())/1000.0; + } /** * Invoke a request on this target, but ignore the return diff --git a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java index a899938dd45..d83c1ee8baa 100644 --- a/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java +++ b/jrt/src/com/yahoo/jrt/TlsCryptoSocket.java @@ -1,27 +1,23 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jrt; -import com.yahoo.security.tls.authz.AuthorizationResult; -import com.yahoo.security.tls.authz.PeerAuthorizerTrustManager; +import com.yahoo.security.tls.ConnectionAuthContext; +import com.yahoo.security.tls.PeerAuthorizationFailedException; +import com.yahoo.security.tls.TransportSecurityUtils; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; -import java.security.cert.X509Certificate; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; +import java.util.Objects; import java.util.logging.Logger; -import static java.util.stream.Collectors.toList; import static javax.net.ssl.SSLEngineResult.Status; /** @@ -46,7 +42,7 @@ public class TlsCryptoSocket implements CryptoSocket { private int sessionApplicationBufferSize; private ByteBuffer handshakeDummyBuffer; private HandshakeState handshakeState; - private AuthorizationResult authorizationResult; + private ConnectionAuthContext authContext; public TlsCryptoSocket(SocketChannel channel, SSLEngine sslEngine) { this.channel = channel; @@ -102,15 +98,6 @@ public class TlsCryptoSocket implements CryptoSocket { channelRead(); break; case NEED_WORK: - if (authorizationResult == null) { - PeerAuthorizerTrustManager.getAuthorizationResult(sslEngine) // only available during handshake - .ifPresent(result -> { - if (!result.succeeded()) { - metrics.incrementPeerAuthorizationFailures(); - } - authorizationResult = result; - }); - } break; case COMPLETED: return HandshakeState.COMPLETED; @@ -127,6 +114,10 @@ public class TlsCryptoSocket implements CryptoSocket { SSLSession session = sslEngine.getSession(); sessionApplicationBufferSize = session.getApplicationBufferSize(); sessionPacketBufferSize = session.getPacketBufferSize(); + authContext = TransportSecurityUtils.getConnectionAuthContext(session).orElseThrow(); + if (!authContext.authorized()) { + metrics.incrementPeerAuthorizationFailures(); + } log.fine(() -> String.format("Handshake complete: protocol=%s, cipherSuite=%s", session.getProtocol(), session.getCipherSuite())); if (sslEngine.getUseClientMode()) { metrics.incrementClientTlsConnectionsEstablished(); @@ -148,8 +139,7 @@ public class TlsCryptoSocket implements CryptoSocket { } } } catch (SSLHandshakeException e) { - // sslEngine.getDelegatedTask().run() and handshakeWrap() may throw SSLHandshakeException, potentially handshakeUnwrap() and sslEngine.beginHandshake() as well. - if (authorizationResult == null || authorizationResult.succeeded()) { // don't include handshake failures due from PeerAuthorizerTrustManager + if (!(e.getCause() instanceof PeerAuthorizationFailedException)) { metrics.incrementTlsCertificateVerificationFailures(); } throw e; @@ -224,19 +214,9 @@ public class TlsCryptoSocket implements CryptoSocket { } @Override - public Optional<SecurityContext> getSecurityContext() { - try { - if (handshakeState != HandshakeState.COMPLETED) { - return Optional.empty(); - } - List<X509Certificate> peerCertificateChain = - Arrays.stream(sslEngine.getSession().getPeerCertificates()) - .map(X509Certificate.class::cast) - .collect(toList()); - return Optional.of(new SecurityContext(peerCertificateChain)); - } catch (SSLPeerUnverifiedException e) { // unverified peer: non-certificate based ciphers or peer did not provide a certificate - return Optional.of(new SecurityContext(List.of())); // secure connection, but peer does not have a certificate chain. - } + public ConnectionAuthContext connectionAuthContext() { + if (handshakeState != HandshakeState.COMPLETED) throw new IllegalStateException("Handshake not complete"); + return Objects.requireNonNull(authContext); } private boolean handshakeWrap() throws IOException { diff --git a/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java b/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java index b37e78490f1..dd2fd5a8242 100644 --- a/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java +++ b/jrt/src/com/yahoo/jrt/slobrok/api/Mirror.java @@ -11,6 +11,7 @@ import com.yahoo.jrt.Task; import com.yahoo.jrt.TransportThread; import com.yahoo.jrt.Values; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -203,7 +204,7 @@ public class Mirror implements IMirror { req = new Request("slobrok.incremental.fetch"); req.parameters().add(new Int32Value(specsGeneration)); // gencnt req.parameters().add(new Int32Value(5000)); // mstimeout - target.invokeAsync(req, 40.0, reqWait); + target.invokeAsync(req, Duration.ofSeconds(40), reqWait); } private void handleUpdate() { diff --git a/jrt/src/com/yahoo/jrt/slobrok/api/Register.java b/jrt/src/com/yahoo/jrt/slobrok/api/Register.java index 14afea396bf..e529dea2eff 100644 --- a/jrt/src/com/yahoo/jrt/slobrok/api/Register.java +++ b/jrt/src/com/yahoo/jrt/slobrok/api/Register.java @@ -15,6 +15,7 @@ import com.yahoo.jrt.Task; import com.yahoo.jrt.TransportThread; import com.yahoo.jrt.Values; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -279,7 +280,7 @@ public class Register { req.parameters().add(new StringValue(name)); req.parameters().add(new StringValue(mySpec)); log.log(Level.FINE, logMessagePrefix() + " now"); - target.invokeAsync(req, 35.0, reqWait); + target.invokeAsync(req, Duration.ofSeconds(35), reqWait); } private String logMessagePrefix() { diff --git a/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java b/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java index 24ab63c1d2f..5fd8beb3cc7 100644 --- a/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java +++ b/jrt/src/com/yahoo/jrt/slobrok/server/Slobrok.java @@ -17,6 +17,7 @@ import com.yahoo.jrt.TargetWatcher; import com.yahoo.jrt.Task; import com.yahoo.jrt.Transport; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -225,7 +226,7 @@ public class Slobrok { this.spec = spec; target = orb.connect(new Spec(spec)); Request cbReq = new Request("slobrok.callback.listNamesServed"); - target.invokeAsync(cbReq, 5.0, this); + target.invokeAsync(cbReq, Duration.ofSeconds(5), this); } @Override diff --git a/jrt/src/com/yahoo/jrt/tool/RpcInvoker.java b/jrt/src/com/yahoo/jrt/tool/RpcInvoker.java index 71049673d90..67933cfafde 100644 --- a/jrt/src/com/yahoo/jrt/tool/RpcInvoker.java +++ b/jrt/src/com/yahoo/jrt/tool/RpcInvoker.java @@ -16,6 +16,7 @@ import com.yahoo.jrt.Transport; import com.yahoo.jrt.Value; import com.yahoo.jrt.Values; +import java.time.Duration; import java.util.Arrays; import java.util.List; import java.util.ArrayList; @@ -80,7 +81,7 @@ public class RpcInvoker { supervisor = new Supervisor(new Transport("invoker")); target = supervisor.connect(new Spec(connectspec)); Request request = createRequest(method,arguments); - target.invokeSync(request,10.0); + target.invokeSync(request, Duration.ofSeconds(10)); if (request.isError()) { System.err.println("error(" + request.errorCode() + "): " + request.errorMessage()); return; diff --git a/jrt/tests/com/yahoo/jrt/AbortTest.java b/jrt/tests/com/yahoo/jrt/AbortTest.java index 2f31b3a52f6..df1f207458e 100644 --- a/jrt/tests/com/yahoo/jrt/AbortTest.java +++ b/jrt/tests/com/yahoo/jrt/AbortTest.java @@ -4,6 +4,8 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -44,7 +46,7 @@ public class AbortTest { Test.Waiter w = new Test.Waiter(); Request req = new Request("test"); req.parameters().add(new Int32Value(20)); - target.invokeAsync(req, 5.0, w); + target.invokeAsync(req, Duration.ofSeconds(5), w); req.abort(); barrier.breakIt(); w.waitDone(); @@ -54,7 +56,7 @@ public class AbortTest { Request req2 = new Request("test"); req2.parameters().add(new Int32Value(30)); - target.invokeSync(req2, 5.0); + target.invokeSync(req2, Duration.ofSeconds(5)); assertTrue(!req2.isError()); assertEquals(1, req2.returnValues().size()); assertEquals(30, req2.returnValues().get(0).asInt32()); diff --git a/jrt/tests/com/yahoo/jrt/BackTargetTest.java b/jrt/tests/com/yahoo/jrt/BackTargetTest.java index a55a6d7f474..5b9e7ccb157 100644 --- a/jrt/tests/com/yahoo/jrt/BackTargetTest.java +++ b/jrt/tests/com/yahoo/jrt/BackTargetTest.java @@ -4,6 +4,8 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -81,24 +83,24 @@ public class BackTargetTest { @org.junit.Test public void testBackTarget() { checkTargets(false, false); - target.invokeSync(new Request("sample_target"), 5.0); + target.invokeSync(new Request("sample_target"), Duration.ofSeconds(5)); checkTargets(true, false); - serverBackTarget.invokeSync(new Request("sample_target"), 5.0); + serverBackTarget.invokeSync(new Request("sample_target"), Duration.ofSeconds(5)); checkTargets(true, true); checkValues(0, 0); - target.invokeSync(new Request("inc"), 5.0); + target.invokeSync(new Request("inc"), Duration.ofSeconds(5)); checkValues(1, 0); - serverBackTarget.invokeSync(new Request("inc"), 5.0); + serverBackTarget.invokeSync(new Request("inc"), Duration.ofSeconds(5)); checkValues(1, 1); - clientBackTarget.invokeSync(new Request("inc"), 5.0); + clientBackTarget.invokeSync(new Request("inc"), Duration.ofSeconds(5)); checkValues(2, 1); - target.invokeSync(new Request("back_inc"), 5.0); + target.invokeSync(new Request("back_inc"), Duration.ofSeconds(5)); checkValues(2, 2); - serverBackTarget.invokeSync(new Request("back_inc"), 5.0); + serverBackTarget.invokeSync(new Request("back_inc"), Duration.ofSeconds(5)); checkValues(3, 2); - clientBackTarget.invokeSync(new Request("back_inc"), 5.0); + clientBackTarget.invokeSync(new Request("back_inc"), Duration.ofSeconds(5)); checkValues(3, 3); } diff --git a/jrt/tests/com/yahoo/jrt/CryptoUtils.java b/jrt/tests/com/yahoo/jrt/CryptoUtils.java index f1672f86e9b..cef138ffba1 100644 --- a/jrt/tests/com/yahoo/jrt/CryptoUtils.java +++ b/jrt/tests/com/yahoo/jrt/CryptoUtils.java @@ -4,15 +4,14 @@ package com.yahoo.jrt; import com.yahoo.security.KeyUtils; import com.yahoo.security.X509CertificateBuilder; import com.yahoo.security.tls.AuthorizationMode; +import com.yahoo.security.tls.AuthorizedPeers; import com.yahoo.security.tls.DefaultTlsContext; import com.yahoo.security.tls.HostnameVerification; import com.yahoo.security.tls.PeerAuthentication; +import com.yahoo.security.tls.PeerPolicy; +import com.yahoo.security.tls.RequiredPeerCredential; +import com.yahoo.security.tls.RequiredPeerCredential.Field; import com.yahoo.security.tls.TlsContext; -import com.yahoo.security.tls.policy.AuthorizedPeers; -import com.yahoo.security.tls.policy.PeerPolicy; -import com.yahoo.security.tls.policy.RequiredPeerCredential; -import com.yahoo.security.tls.policy.RequiredPeerCredential.Field; -import com.yahoo.security.tls.policy.Role; import javax.security.auth.x500.X500Principal; import java.security.KeyPair; @@ -42,8 +41,6 @@ class CryptoUtils { singleton( new PeerPolicy( "localhost-policy", - singleton( - new Role("localhost-role")), singletonList( RequiredPeerCredential.of(Field.CN, "localhost"))))); diff --git a/jrt/tests/com/yahoo/jrt/DetachTest.java b/jrt/tests/com/yahoo/jrt/DetachTest.java index 3c3356b53e2..4c3ee085913 100644 --- a/jrt/tests/com/yahoo/jrt/DetachTest.java +++ b/jrt/tests/com/yahoo/jrt/DetachTest.java @@ -4,6 +4,8 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -66,11 +68,11 @@ public class DetachTest { Test.Waiter w1 = new Test.Waiter(); Request req1 = new Request("d_inc"); req1.parameters().add(new Int32Value(50)); - target.invokeAsync(req1, 5.0, w1); + target.invokeAsync(req1, Duration.ofSeconds(5), w1); Request req2 = new Request("d_inc_r"); req2.parameters().add(new Int32Value(60)); - target.invokeSync(req2, 5.0); + target.invokeSync(req2, Duration.ofSeconds(5)); assertTrue(!req2.isError()); assertEquals(1, req2.returnValues().size()); @@ -123,7 +125,7 @@ public class DetachTest { Test.Waiter w = new Test.Waiter(); Request req3 = new Request("inc_b"); req3.parameters().add(new Int32Value(100)); - target.invokeAsync(req3, 5.0, w); + target.invokeAsync(req3, Duration.ofSeconds(5), w); Request blocked = (Request) receptor.get(); try { blocked.returnRequest(); diff --git a/jrt/tests/com/yahoo/jrt/EchoTest.java b/jrt/tests/com/yahoo/jrt/EchoTest.java index 26d4315fad6..11742fa42e2 100644 --- a/jrt/tests/com/yahoo/jrt/EchoTest.java +++ b/jrt/tests/com/yahoo/jrt/EchoTest.java @@ -2,6 +2,7 @@ package com.yahoo.jrt; +import com.yahoo.security.tls.ConnectionAuthContext; import org.junit.After; import org.junit.Before; import org.junit.runner.RunWith; @@ -10,12 +11,12 @@ import org.junit.runners.Parameterized.Parameter; import org.junit.runners.Parameterized.Parameters; import java.security.cert.X509Certificate; +import java.time.Duration; import java.util.List; import static com.yahoo.jrt.CryptoUtils.createTestTlsContext; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; @RunWith(Parameterized.class) @@ -28,19 +29,19 @@ public class EchoTest { Supervisor client; Target target; Values refValues; - SecurityContext securityContext; + ConnectionAuthContext connAuthCtx; private interface MetricsAssertions { void assertMetrics(TransportMetrics.Snapshot snapshot) throws AssertionError; } - private interface SecurityContextAssertion { - void assertSecurityContext(SecurityContext securityContext) throws AssertionError; + private interface ConnectionAuthContextAssertion { + void assertConnectionAuthContext(ConnectionAuthContext authContext) throws AssertionError; } @Parameter(value = 0) public CryptoEngine crypto; @Parameter(value = 1) public MetricsAssertions metricsAssertions; - @Parameter(value = 2) public SecurityContextAssertion securityContextAssertion; + @Parameter(value = 2) public ConnectionAuthContextAssertion connAuthCtxAssertion; @Parameters(name = "{0}") public static Object[] engines() { @@ -62,7 +63,7 @@ public class EchoTest { assertEquals(1, metrics.serverTlsConnectionsEstablished()); assertEquals(1, metrics.clientTlsConnectionsEstablished()); }, - (SecurityContextAssertion) context -> { + (ConnectionAuthContextAssertion) context -> { List<X509Certificate> chain = context.peerCertificateChain(); assertEquals(1, chain.size()); assertEquals(CryptoUtils.certificate, chain.get(0)); @@ -80,7 +81,7 @@ public class EchoTest { assertEquals(1, metrics.serverTlsConnectionsEstablished()); assertEquals(1, metrics.clientTlsConnectionsEstablished()); }, - (SecurityContextAssertion) context -> { + (ConnectionAuthContextAssertion) context -> { List<X509Certificate> chain = context.peerCertificateChain(); assertEquals(1, chain.size()); assertEquals(CryptoUtils.certificate, chain.get(0)); @@ -146,7 +147,7 @@ public class EchoTest { for (int i = 0; i < p.size(); i++) { r.add(p.get(i)); } - securityContext = req.target().getSecurityContext().orElse(null); + connAuthCtx = req.target().connectionAuthContext(); } @org.junit.Test @@ -156,7 +157,7 @@ public class EchoTest { for (int i = 0; i < refValues.size(); i++) { p.add(refValues.get(i)); } - target.invokeSync(req, 60.0); + target.invokeSync(req, Duration.ofSeconds(60)); assertTrue(req.checkReturnTypes("bBhHiIlLfFdDxXsS")); assertTrue(Test.equals(req.returnValues(), req.parameters())); assertTrue(Test.equals(req.returnValues(), refValues)); @@ -164,11 +165,9 @@ public class EchoTest { if (metricsAssertions != null) { metricsAssertions.assertMetrics(metrics.snapshot().changesSince(startSnapshot)); } - if (securityContextAssertion != null) { - assertNotNull(securityContext); - securityContextAssertion.assertSecurityContext(securityContext); - } else { - assertNull(securityContext); + if (connAuthCtxAssertion != null) { + assertNotNull(connAuthCtx); + connAuthCtxAssertion.assertConnectionAuthContext(connAuthCtx); } } } diff --git a/jrt/tests/com/yahoo/jrt/InvokeAsyncTest.java b/jrt/tests/com/yahoo/jrt/InvokeAsyncTest.java index 5e9f426bb17..e17b6c0cfdd 100644 --- a/jrt/tests/com/yahoo/jrt/InvokeAsyncTest.java +++ b/jrt/tests/com/yahoo/jrt/InvokeAsyncTest.java @@ -5,6 +5,8 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -16,6 +18,7 @@ public class InvokeAsyncTest { Supervisor client; Target target; Test.Barrier barrier; + SimpleRequestAccessFilter filter; @Before public void setUp() throws ListenFailedException { @@ -23,11 +26,13 @@ public class InvokeAsyncTest { client = new Supervisor(new Transport()); acceptor = server.listen(new Spec(0)); target = client.connect(new Spec("localhost", acceptor.port())); + filter = new SimpleRequestAccessFilter(); server.addMethod(new Method("concat", "ss", "s", this::rpc_concat) .methodDesc("Concatenate 2 strings") .paramDesc(0, "str1", "a string") .paramDesc(1, "str2", "another string") - .returnDesc(0, "ret", "str1 followed by str2")); + .returnDesc(0, "ret", "str1 followed by str2") + .requestAccessFilter(filter)); barrier = new Test.Barrier(); } @@ -54,7 +59,7 @@ public class InvokeAsyncTest { req.parameters().add(new StringValue("def")); Test.Waiter w = new Test.Waiter(); - target.invokeAsync(req, 5.0, w); + target.invokeAsync(req, Duration.ofSeconds(5), w); assertFalse(w.isDone()); barrier.breakIt(); w.waitDone(); @@ -65,4 +70,21 @@ public class InvokeAsyncTest { assertEquals("abcdef", req.returnValues().get(0).asString()); } + @org.junit.Test + public void testFilterIsInvoked() { + Request req = new Request("concat"); + req.parameters().add(new StringValue("abc")); + req.parameters().add(new StringValue("def")); + assertFalse(filter.invoked); + Test.Waiter w = new Test.Waiter(); + target.invokeAsync(req, Duration.ofSeconds(10), w); + assertFalse(w.isDone()); + barrier.breakIt(); + w.waitDone(); + assertTrue(w.isDone()); + assertFalse(req.isError()); + assertEquals("abcdef", req.returnValues().get(0).asString()); + assertTrue(filter.invoked); + } + } diff --git a/jrt/tests/com/yahoo/jrt/InvokeErrorTest.java b/jrt/tests/com/yahoo/jrt/InvokeErrorTest.java index a9a0b18b5a1..0b75fe713c2 100644 --- a/jrt/tests/com/yahoo/jrt/InvokeErrorTest.java +++ b/jrt/tests/com/yahoo/jrt/InvokeErrorTest.java @@ -5,17 +5,22 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; + import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; public class InvokeErrorTest { - final double timeout=60.0; + final Duration timeout = Duration.ofSeconds(60); Supervisor server; Acceptor acceptor; Supervisor client; Target target; Test.Barrier barrier; + SimpleRequestAccessFilter filter; + RpcTestMethod testMethod; @Before public void setUp() throws ListenFailedException { @@ -23,7 +28,9 @@ public class InvokeErrorTest { client = new Supervisor(new Transport()); acceptor = server.listen(new Spec(0)); target = client.connect(new Spec("localhost", acceptor.port())); - server.addMethod(new Method("test", "iib", "i", this::rpc_test)); + filter = new SimpleRequestAccessFilter(); + testMethod = new RpcTestMethod(); + server.addMethod(new Method("test", "iib", "i", testMethod).requestAccessFilter(filter)); server.addMethod(new Method("test_barrier", "iib", "i", this::rpc_test_barrier)); barrier = new Test.Barrier(); } @@ -36,22 +43,8 @@ public class InvokeErrorTest { server.transport().shutdown().join(); } - private void rpc_test(Request req) { - int value = req.parameters().get(0).asInt32(); - int error = req.parameters().get(1).asInt32(); - int extra = req.parameters().get(2).asInt8(); - - req.returnValues().add(new Int32Value(value)); - if (extra != 0) { - req.returnValues().add(new Int32Value(value)); - } - if (error != 0) { - req.setError(error, "Custom error"); - } - } - private void rpc_test_barrier(Request req) { - rpc_test(req); + testMethod.invoke(req); barrier.waitFor(); } @@ -157,4 +150,40 @@ public class InvokeErrorTest { assertEquals(ErrorCode.CONNECTION, req1.errorCode()); } + @org.junit.Test + public void testFilterFailsRequest() { + Request r = new Request("test"); + r.parameters().add(new Int32Value(42)); + r.parameters().add(new Int32Value(0)); + r.parameters().add(new Int8Value((byte)0)); + filter.allowed = false; + assertFalse(filter.invoked); + target.invokeSync(r, timeout); + assertTrue(r.isError()); + assertTrue(filter.invoked); + assertFalse(testMethod.invoked); + assertEquals(ErrorCode.PERMISSION_DENIED, r.errorCode()); + assertEquals("Permission denied", r.errorMessage()); + } + + private static class RpcTestMethod implements MethodHandler { + boolean invoked = false; + + @Override public void invoke(Request req) { invoked = true; rpc_test(req); } + + void rpc_test(Request req) { + int value = req.parameters().get(0).asInt32(); + int error = req.parameters().get(1).asInt32(); + int extra = req.parameters().get(2).asInt8(); + + req.returnValues().add(new Int32Value(value)); + if (extra != 0) { + req.returnValues().add(new Int32Value(value)); + } + if (error != 0) { + req.setError(error, "Custom error"); + } + } + } + } diff --git a/jrt/tests/com/yahoo/jrt/InvokeSyncTest.java b/jrt/tests/com/yahoo/jrt/InvokeSyncTest.java index ca7d0db129d..ff44017e1bc 100644 --- a/jrt/tests/com/yahoo/jrt/InvokeSyncTest.java +++ b/jrt/tests/com/yahoo/jrt/InvokeSyncTest.java @@ -10,8 +10,10 @@ import java.io.FileDescriptor; import java.io.FileOutputStream; import java.io.IOException; import java.io.PrintStream; +import java.time.Duration; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -21,6 +23,7 @@ public class InvokeSyncTest { Acceptor acceptor; Supervisor client; Target target; + SimpleRequestAccessFilter filter; @Before public void setUp() throws ListenFailedException { @@ -28,11 +31,13 @@ public class InvokeSyncTest { client = new Supervisor(new Transport()); acceptor = server.listen(new Spec(0)); target = client.connect(new Spec("localhost", acceptor.port())); + filter = new SimpleRequestAccessFilter(); server.addMethod(new Method("concat", "ss", "s", this::rpc_concat) .methodDesc("Concatenate 2 strings") .paramDesc(0, "str1", "a string") .paramDesc(1, "str2", "another string") - .returnDesc(0, "ret", "str1 followed by str2")); + .returnDesc(0, "ret", "str1 followed by str2") + .requestAccessFilter(filter)); server.addMethod(new Method("alltypes", "bhilfds", "s", this::rpc_alltypes) .methodDesc("Method taking all types of params")); } @@ -63,7 +68,7 @@ public class InvokeSyncTest { req.parameters().add(new StringValue("abc")); req.parameters().add(new StringValue("def")); - target.invokeSync(req, 5.0); + target.invokeSync(req, Duration.ofSeconds(5)); assertTrue(!req.isError()); assertEquals(1, req.returnValues().size()); @@ -84,4 +89,17 @@ public class InvokeSyncTest { assertEquals(baos.toString(), "This was alltypes. The string param was: baz\n"); } + @org.junit.Test + public void testFilterIsInvoked() { + Request req = new Request("concat"); + req.parameters().add(new StringValue("abc")); + req.parameters().add(new StringValue("def")); + assertFalse(filter.invoked); + target.invokeSync(req, Duration.ofSeconds(10)); + assertFalse(req.isError()); + assertEquals("abcdef", req.returnValues().get(0).asString()); + assertTrue(filter.invoked); + } + + } diff --git a/jrt/tests/com/yahoo/jrt/InvokeVoidTest.java b/jrt/tests/com/yahoo/jrt/InvokeVoidTest.java index 8b674136fe2..64c3bc91371 100644 --- a/jrt/tests/com/yahoo/jrt/InvokeVoidTest.java +++ b/jrt/tests/com/yahoo/jrt/InvokeVoidTest.java @@ -5,6 +5,8 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -56,7 +58,7 @@ public class InvokeVoidTest { public void testInvokeVoid() { Request req = new Request("set"); req.parameters().add(new Int32Value(40)); - target.invokeSync(req, 5.0); + target.invokeSync(req, Duration.ofSeconds(5)); assertTrue(!req.isError()); assertEquals(0, req.returnValues().size()); @@ -64,7 +66,7 @@ public class InvokeVoidTest { target.invokeVoid(new Request("inc")); req = new Request("get"); - target.invokeSync(req, 5.0); + target.invokeSync(req, Duration.ofSeconds(5)); assertTrue(!req.isError()); assertEquals(42, req.returnValues().get(0).asInt32()); diff --git a/jrt/tests/com/yahoo/jrt/LatencyTest.java b/jrt/tests/com/yahoo/jrt/LatencyTest.java index 0df15ed400b..945833e51a8 100644 --- a/jrt/tests/com/yahoo/jrt/LatencyTest.java +++ b/jrt/tests/com/yahoo/jrt/LatencyTest.java @@ -2,6 +2,7 @@ package com.yahoo.jrt; +import java.time.Duration; import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.logging.Logger; @@ -116,7 +117,7 @@ public class LatencyTest { } Request req = new Request("inc"); req.parameters().add(new Int32Value(value)); - target.invokeSync(req, 60.0); + target.invokeSync(req, Duration.ofSeconds(60)); long duration = System.nanoTime() - t; assertTrue(req.checkReturnTypes("i")); assertEquals(value + 1, req.returnValues().get(0).asInt32()); diff --git a/jrt/tests/com/yahoo/jrt/MandatoryMethodsTest.java b/jrt/tests/com/yahoo/jrt/MandatoryMethodsTest.java index 212447dd6da..c0ef9606b1f 100644 --- a/jrt/tests/com/yahoo/jrt/MandatoryMethodsTest.java +++ b/jrt/tests/com/yahoo/jrt/MandatoryMethodsTest.java @@ -5,6 +5,7 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; import java.util.HashSet; import static org.junit.Assert.assertEquals; @@ -38,7 +39,7 @@ public class MandatoryMethodsTest { @org.junit.Test public void testPing() { Request req = new Request("frt.rpc.ping"); - target.invokeSync(req, 5.0); + target.invokeSync(req, Duration.ofSeconds(5)); assertFalse(req.isError()); assertEquals(0, req.returnValues().size()); @@ -47,7 +48,7 @@ public class MandatoryMethodsTest { @org.junit.Test public void testGetMethodList() { Request req = new Request("frt.rpc.getMethodList"); - target.invokeSync(req, 5.0); + target.invokeSync(req, Duration.ofSeconds(5)); assertFalse(req.isError()); assertTrue(req.checkReturnTypes("SSS")); @@ -81,7 +82,7 @@ public class MandatoryMethodsTest { public void testGetMethodInfo() { Request req = new Request("frt.rpc.getMethodInfo"); req.parameters().add(new StringValue("frt.rpc.getMethodInfo")); - target.invokeSync(req, 5.0); + target.invokeSync(req, Duration.ofSeconds(5)); assertFalse(req.isError()); assertTrue(req.checkReturnTypes("sssSSSS")); diff --git a/jrt/tests/com/yahoo/jrt/SimpleRequestAccessFilter.java b/jrt/tests/com/yahoo/jrt/SimpleRequestAccessFilter.java new file mode 100644 index 00000000000..38d59720848 --- /dev/null +++ b/jrt/tests/com/yahoo/jrt/SimpleRequestAccessFilter.java @@ -0,0 +1,9 @@ +package com.yahoo.jrt;// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +/** + * @author bjorncs + */ +class SimpleRequestAccessFilter implements RequestAccessFilter { + volatile boolean invoked = false, allowed = true; + @Override public boolean allow(Request r) { invoked = true; return allowed; } +} diff --git a/jrt/tests/com/yahoo/jrt/TimeoutTest.java b/jrt/tests/com/yahoo/jrt/TimeoutTest.java index 0366020b221..1a802758e60 100644 --- a/jrt/tests/com/yahoo/jrt/TimeoutTest.java +++ b/jrt/tests/com/yahoo/jrt/TimeoutTest.java @@ -5,6 +5,8 @@ package com.yahoo.jrt; import org.junit.After; import org.junit.Before; +import java.time.Duration; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -52,11 +54,11 @@ public class TimeoutTest { req.parameters().add(new StringValue("abc")); req.parameters().add(new StringValue("def")); - target.invokeSync(req, 0.1); + target.invokeSync(req, Duration.ofMillis(100)); barrier.breakIt(); Request flush = new Request("frt.rpc.ping"); - target.invokeSync(flush, 5.0); + target.invokeSync(flush, Duration.ofSeconds(5)); assertTrue(!flush.isError()); assertTrue(req.isError()); @@ -72,7 +74,7 @@ public class TimeoutTest { req.parameters().add(new StringValue("def")); Test.Waiter w = new Test.Waiter(); - target.invokeAsync(req, 30.0, w); + target.invokeAsync(req, Duration.ofSeconds(30), w); try { Thread.sleep(2500); } catch (InterruptedException e) {} barrier.breakIt(); w.waitDone(); |