diff --git a/reactor-netty-core/src/main/java/reactor/netty/channel/AbstractChannelMetricsHandler.java b/reactor-netty-core/src/main/java/reactor/netty/channel/AbstractChannelMetricsHandler.java index ee14dbf8df..ce2b44e405 100644 --- a/reactor-netty-core/src/main/java/reactor/netty/channel/AbstractChannelMetricsHandler.java +++ b/reactor-netty-core/src/main/java/reactor/netty/channel/AbstractChannelMetricsHandler.java @@ -21,6 +21,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.socket.DatagramPacket; +import io.netty.handler.ssl.AbstractSniHandler; +import io.netty.handler.ssl.SslHandler; import reactor.netty.NettyPipeline; import reactor.util.Logger; import reactor.util.Loggers; @@ -95,12 +97,19 @@ public void channelRegistered(ChannelHandlerContext ctx) { NettyPipeline.ConnectMetricsHandler, connectMetricsHandler()); } - if (ctx.pipeline().get(NettyPipeline.SslHandler) != null) { + ChannelHandler sslHandler = ctx.pipeline().get(NettyPipeline.SslHandler); + if (sslHandler instanceof SslHandler) { ctx.pipeline() .addBefore(NettyPipeline.SslHandler, NettyPipeline.TlsMetricsHandler, tlsMetricsHandler()); } + else if (sslHandler instanceof AbstractSniHandler) { + ctx.pipeline() + .addAfter(NettyPipeline.SslHandler, + NettyPipeline.TlsMetricsHandler, + tlsMetricsHandler()); + } ctx.fireChannelRegistered(); } diff --git a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelMetricsHandler.java b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelMetricsHandler.java index 9b424ff9dc..e708de1dd7 100644 --- a/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelMetricsHandler.java +++ b/reactor-netty-core/src/main/java/reactor/netty/channel/ChannelMetricsHandler.java @@ -20,6 +20,7 @@ import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; +import io.netty.handler.ssl.SniCompletionEvent; import io.netty.handler.ssl.SslHandler; import reactor.util.annotation.Nullable; @@ -86,28 +87,46 @@ static class TlsMetricsHandler extends ChannelInboundHandlerAdapter { protected final ChannelMetricsRecorder recorder; + boolean listenerAdded; + TlsMetricsHandler(ChannelMetricsRecorder recorder) { this.recorder = recorder; } @Override public void channelActive(ChannelHandlerContext ctx) { - long tlsHandshakeTimeStart = System.nanoTime(); - ctx.pipeline() - .get(SslHandler.class) - .handshakeFuture() - .addListener(f -> { - ctx.pipeline().remove(this); - recordTlsHandshakeTime(ctx, tlsHandshakeTimeStart, f.isSuccess() ? SUCCESS : ERROR); - }); + addListener(ctx); ctx.fireChannelActive(); } + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + addListener(ctx); + } + ctx.fireUserEventTriggered(evt); + } + protected void recordTlsHandshakeTime(ChannelHandlerContext ctx, long tlsHandshakeTimeStart, String status) { recorder.recordTlsHandshakeTime( ctx.channel().remoteAddress(), Duration.ofNanos(System.nanoTime() - tlsHandshakeTimeStart), status); } + + private void addListener(ChannelHandlerContext ctx) { + if (!listenerAdded) { + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler != null) { + listenerAdded = true; + long tlsHandshakeTimeStart = System.nanoTime(); + sslHandler.handshakeFuture() + .addListener(f -> { + ctx.pipeline().remove(this); + recordTlsHandshakeTime(ctx, tlsHandshakeTimeStart, f.isSuccess() ? SUCCESS : ERROR); + }); + } + } + } } } diff --git a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java index 30f2850c75..30138d7c75 100644 --- a/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java +++ b/reactor-netty-http/src/test/java/reactor/netty/http/server/HttpServerTests.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011-2023 VMware, Inc. or its affiliates, All Rights Reserved. + * Copyright (c) 2011-2024 VMware, Inc. or its affiliates, All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -136,6 +136,7 @@ import reactor.netty.http.Http2SslContextSpec; import reactor.netty.http.HttpProtocol; import reactor.netty.http.client.HttpClient; +import reactor.netty.http.client.HttpClientMetricsRecorder; import reactor.netty.http.client.HttpClientRequest; import reactor.netty.http.client.PrematureCloseException; import reactor.netty.http.logging.ReactorNettyHttpMessageLogFactory; @@ -2124,6 +2125,21 @@ void testHang() { @Test void testSniSupport() throws Exception { + doTestSniSupport(Function.identity(), Function.identity()); + } + + @Test + void testIssue3022() throws Exception { + TestHttpClientMetricsRecorder clientMetricsRecorder = new TestHttpClientMetricsRecorder(); + TestHttpServerMetricsRecorder serverMetricsRecorder = new TestHttpServerMetricsRecorder(); + doTestSniSupport(server -> server.metrics(true, () -> serverMetricsRecorder, Function.identity()), + client -> client.metrics(true, () -> clientMetricsRecorder, Function.identity())); + assertThat(clientMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO); + assertThat(serverMetricsRecorder.tlsHandshakeTime).isNotNull().isGreaterThan(Duration.ZERO); + } + + private void doTestSniSupport(Function serverCustomizer, + Function clientCustomizer) throws Exception { SelfSignedCertificate defaultCert = new SelfSignedCertificate("default"); Http11SslContextSpec defaultSslContextBuilder = Http11SslContextSpec.forServer(defaultCert.certificate(), defaultCert.privateKey()); @@ -2138,7 +2154,7 @@ void testSniSupport() throws Exception { AtomicReference hostname = new AtomicReference<>(); disposableServer = - createServer() + serverCustomizer.apply(createServer()) .secure(spec -> spec.sslContext(defaultSslContextBuilder) .addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(testSslContextBuilder))) .doOnChannelInit((obs, channel, remoteAddress) -> @@ -2155,7 +2171,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { .handle((req, res) -> res.sendString(Mono.just("testSniSupport"))) .bindNow(); - createClient(disposableServer::address) + clientCustomizer.apply(createClient(disposableServer::address)) .secure(spec -> spec.sslContext(clientSslContextBuilder) .serverNames(new SNIHostName("test.com"))) .get() @@ -3569,4 +3585,112 @@ private void testIssue2927(Function serverCustomizer, Fu .expectErrorMatches(t -> t instanceof PrematureCloseException && t.getCause() instanceof Http2Exception.HeaderListSizeException) .verify(Duration.ofSeconds(30)); } + + static final class TestHttpServerMetricsRecorder implements HttpServerMetricsRecorder { + + Duration tlsHandshakeTime; + + @Override + public void recordDataReceived(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress) { + } + + @Override + public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) { + tlsHandshakeTime = time; + } + + @Override + public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) { + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress, String uri) { + } + + @Override + public void recordDataReceivedTime(String uri, String method, Duration time) { + } + + @Override + public void recordDataSentTime(String uri, String method, String status, Duration time) { + } + + @Override + public void recordResponseTime(String uri, String method, String status, Duration time) { + } + } + + static final class TestHttpClientMetricsRecorder implements HttpClientMetricsRecorder { + + Duration tlsHandshakeTime; + + @Override + public void recordDataReceived(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, long bytes) { + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress) { + } + + @Override + public void recordTlsHandshakeTime(SocketAddress remoteAddress, Duration time, String status) { + tlsHandshakeTime = time; + } + + @Override + public void recordConnectTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordResolveAddressTime(SocketAddress remoteAddress, Duration time, String status) { + } + + @Override + public void recordDataReceived(SocketAddress remoteAddress, String uri, long bytes) { + } + + @Override + public void recordDataSent(SocketAddress remoteAddress, String uri, long bytes) { + } + + @Override + public void incrementErrorsCount(SocketAddress remoteAddress, String uri) { + } + + @Override + public void recordDataReceivedTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + + @Override + public void recordDataSentTime(SocketAddress remoteAddress, String uri, String method, Duration time) { + } + + @Override + public void recordResponseTime(SocketAddress remoteAddress, String uri, String method, String status, Duration time) { + } + } }