Avoid NPE in ForwardClientSocketClientHandler & code refactor (#1150)

This commit is contained in:
华华 2020-05-26 14:55:03 +08:00 committed by GitHub
parent 3a12237d05
commit 8178599836
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,3 @@
package com.alibaba.arthas.tunnel.client; package com.alibaba.arthas.tunnel.client;
import java.net.URI; import java.net.URI;
@ -12,7 +11,6 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup;
@ -30,18 +28,13 @@ import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.GenericFutureListener;
/** /**
*
* @author hengyunabc 2019-08-28 * @author hengyunabc 2019-08-28
*
*/ */
public class ForwardClientSocketClientHandler extends SimpleChannelInboundHandler<WebSocketFrame> { public class ForwardClientSocketClientHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
private final static Logger logger = LoggerFactory.getLogger(ForwardClientSocketClientHandler.class);
private ChannelPromise handshakeFuture; private static final Logger logger = LoggerFactory.getLogger(ForwardClientSocketClientHandler.class);
private Channel localChannel; private final URI localServerURI;
private URI localServerURI;
public ForwardClientSocketClientHandler(URI localServerURI) { public ForwardClientSocketClientHandler(URI localServerURI) {
this.localServerURI = localServerURI; this.localServerURI = localServerURI;
@ -49,7 +42,6 @@ public class ForwardClientSocketClientHandler extends SimpleChannelInboundHandle
@Override @Override
public void channelActive(ChannelHandlerContext ctx) { public void channelActive(ChannelHandlerContext ctx) {
} }
@Override @Override
@ -58,26 +50,30 @@ public class ForwardClientSocketClientHandler extends SimpleChannelInboundHandle
} }
@Override @Override
public void userEventTriggered(final ChannelHandlerContext ctx, Object evt) throws Exception { public void userEventTriggered(final ChannelHandlerContext ctx, Object evt) {
if (evt.equals(ClientHandshakeStateEvent.HANDSHAKE_COMPLETE)) { if (evt.equals(ClientHandshakeStateEvent.HANDSHAKE_COMPLETE)) {
EventLoopGroup group = new NioEventLoopGroup();
try { try {
connectLocalServer(ctx);
} catch (Throwable e) {
logger.error("ForwardClientSocketClientHandler connect local arthas server error", e);
}
} else {
ctx.fireUserEventTriggered(evt);
}
}
private void connectLocalServer(final ChannelHandlerContext ctx) throws InterruptedException {
EventLoopGroup group = new NioEventLoopGroup();
logger.info("ForwardClientSocketClientHandler star connect local arthas server"); logger.info("ForwardClientSocketClientHandler star connect local arthas server");
WebSocketClientHandshaker newHandshaker = WebSocketClientHandshakerFactory.newHandshaker(localServerURI, WebSocketClientHandshaker newHandshaker = WebSocketClientHandshakerFactory.newHandshaker(localServerURI,
WebSocketVersion.V13, null, true, new DefaultHttpHeaders()); WebSocketVersion.V13, null, true, new DefaultHttpHeaders());
final WebSocketClientProtocolHandler websocketClientHandler = new WebSocketClientProtocolHandler( final WebSocketClientProtocolHandler websocketClientHandler = new WebSocketClientProtocolHandler(
newHandshaker); newHandshaker);
final LocalFrameHandler localFrameHandler = new LocalFrameHandler(); final LocalFrameHandler localFrameHandler = new LocalFrameHandler();
Bootstrap b = new Bootstrap(); Bootstrap b = new Bootstrap();
b.group(group).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() { b.group(group).channel(NioSocketChannel.class)
.handler(new ChannelInitializer<SocketChannel>() {
@Override @Override
protected void initChannel(SocketChannel ch) { protected void initChannel(SocketChannel ch) {
ChannelPipeline p = ch.pipeline(); ChannelPipeline p = ch.pipeline();
@ -86,45 +82,30 @@ public class ForwardClientSocketClientHandler extends SimpleChannelInboundHandle
} }
}); });
localChannel = b.connect(localServerURI.getHost(), localServerURI.getPort()).sync().channel(); Channel localChannel = b.connect(localServerURI.getHost(), localServerURI.getPort()).sync().channel();
localFrameHandler.handshakeFuture()
localFrameHandler.handshakeFuture().addListener(new GenericFutureListener<ChannelFuture>() { .addListener(new GenericFutureListener<ChannelFuture>() {
@Override @Override
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
ChannelPipeline pipeline = future.channel().pipeline(); ChannelPipeline pipeline = future.channel().pipeline();
pipeline.remove(localFrameHandler); pipeline.remove(localFrameHandler);
pipeline.addLast(new RelayHandler(ctx.channel())); pipeline.addLast(new RelayHandler(ctx.channel()));
} }
}); });
localFrameHandler.handshakeFuture().sync(); localFrameHandler.handshakeFuture().sync();
ctx.pipeline().remove(ForwardClientSocketClientHandler.this); ctx.pipeline().remove(ForwardClientSocketClientHandler.this);
ctx.pipeline().addLast(new RelayHandler(localChannel)); ctx.pipeline().addLast(new RelayHandler(localChannel));
logger.info("ForwardClientSocketClientHandler connect local arthas server success"); logger.info("ForwardClientSocketClientHandler connect local arthas server success");
} catch (Throwable e) {
logger.error("ForwardClientSocketClientHandler connect local arthas server error", e);
}
} else {
ctx.fireUserEventTriggered(evt);
}
} }
@Override @Override
protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception { protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame msg) {
} }
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace(); cause.printStackTrace();
if (!handshakeFuture.isDone()) {
handshakeFuture.setFailure(cause);
}
ctx.close(); ctx.close();
} }
} }