Netty简单应用 RPC
目录
netty简单应用
创建server模板
/**
 * Skeleton Netty server: accepts connections on port 8000 and runs each one through a
 * frame decoder, a DEBUG logger and the message codec. No business handler is installed yet.
 */
@Slf4j
public class RpcServer {
    public static void main(String[] args) {
        NioEventLoopGroup bossGroup = new NioEventLoopGroup();
        NioEventLoopGroup workerGroup = new NioEventLoopGroup();
        // Created once and reused by every child channel
        LoggingHandler debugLogger = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable codec = new MessageCodecSharable();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                    .group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    // SO_BACKLOG: length of the queue of connections waiting to be accepted
                    // (not a limit on the total number of connections)
                    .option(ChannelOption.SO_BACKLOG, 2)
                    // Per-connection socket options; 8192 * 128 = 1 MiB buffers
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.SO_RCVBUF, 8192 * 128)
                    .childOption(ChannelOption.SO_SNDBUF, 8192 * 128)
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel ch) throws Exception {
                            // New frame decoder per channel (presumably because it buffers partial
                            // frames and is not @Sharable); logger and codec are shared
                            ch.pipeline()
                                    .addLast(new ProtocolFrameDecoder())
                                    .addLast(debugLogger)
                                    .addLast(codec);
                        }
                    });
            // Bind, then park the main thread until the server channel is closed
            bootstrap.bind(8000).sync()
                    .channel().closeFuture().sync();
        } catch (Exception e) {
            log.error("server error:", e);
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
}
创建client模板
/**
 * Skeleton Netty client: connects to localhost:8000, sends one hand-built RpcRequestMessage,
 * then waits until the connection is closed.
 */
@Slf4j
public class RpcClient {
    public static void main(String[] args) {
        NioEventLoopGroup group = new NioEventLoopGroup();
        LoggingHandler debugLogger = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable codec = new MessageCodecSharable();
        try {
            ChannelFuture connected = new Bootstrap()
                    .group(group)
                    .channel(NioSocketChannel.class)
                    // Give up if the TCP connection cannot be established within 3000 ms
                    .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 3000)
                    .option(ChannelOption.SO_KEEPALIVE, true)
                    .option(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT)
                    .handler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel ch) throws Exception {
                            ch.pipeline()
                                    .addLast(new ProtocolFrameDecoder())
                                    .addLast(debugLogger)
                                    .addLast(codec);
                        }
                    })
                    .connect("localhost", 8000)
                    .sync();
            // Hand-built request: ISimpleRpcFacade.say(String) with argument "test rpc"
            RpcRequestMessage request = new RpcRequestMessage(1,
                    "cn.com.wuhm.netty.netty.rpc.ISimpleRpcFacade",
                    "say",
                    String.class,
                    new Class[]{String.class},
                    new Object[]{"test rpc"});
            connected.channel().writeAndFlush(request);
            // Block until the connection is closed; nothing in this template closes it
            connected.channel().closeFuture().sync();
        } catch (Exception e) {
            log.error("client error: ", e);
        } finally {
            group.shutdownGracefully();
        }
    }
}
简单模拟RPC调用
创建RpcRequestMessage
/**
 * RPC call sent from client to server. It names the target interface and method and
 * carries the argument types and values. The inherited sequenceId correlates it with its response.
 */
@Getter
@ToString(callSuper = true)
public class RpcRequestMessage extends Message {
    /** Fully qualified name of the invoked interface; the server uses it to look up the implementation. */
    private String interfaceName;
    /** Name of the method to invoke on that interface. */
    private String methodName;
    /** Declared return type of the method. */
    private Class<?> returnType;
    /** Declared parameter types, used together with methodName to resolve the method. */
    private Class[] parameterTypes;
    /** Argument values, positionally matching parameterTypes. */
    private Object[] parameterValue;

    public RpcRequestMessage(int sequenceId, String interfaceName, String methodName, Class<?> returnType, Class[] parameterTypes, Object[] parameterValue) {
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.returnType = returnType;
        this.parameterTypes = parameterTypes;
        this.parameterValue = parameterValue;
        super.setSequenceId(sequenceId);
    }

    /** @return the RPC request type code */
    @Override
    public int getMessageType() {
        return Message.RPC_MESSAGE_TYPE_REQUEST;
    }
}
创建RpcResponse
/**
 * RPC result sent from server to client, matched to its request by the inherited sequenceId.
 * A null exceptionValue means the call succeeded and returnValue holds its result.
 */
@Data
@ToString(callSuper = true)
public class RpcResponseMessage extends Message {
    /** Value returned by the remote method (may be null). */
    private Object returnValue;
    /** Failure raised while serving the call; null on success. */
    private Exception exceptionValue;

    /** @return the RPC response type code */
    @Override
    public int getMessageType() {
        return Message.RPC_MESSAGE_TYPE_RESPONSE;
    }
}
创建服务端处理请求handler
@Slf4j
@ChannelHandler.Sharable
public class RpcRequestHandle extends SimpleChannelInboundHandler<RpcRequestMessage> {
@Override
protected void channelRead0(ChannelHandlerContext ctx, RpcRequestMessage rpcRequestMessage) throws Exception {
RpcResponseMessage rpcResponseMessage = new RpcResponseMessage();
try {
ISimpleRpcFacade instance = (ISimpleRpcFacade) ServicesFactory.getInstance(Class.forName(rpcRequestMessage.getInterfaceName()));
Method method = instance.getClass().getMethod(rpcRequestMessage.getMethodName(), rpcRequestMessage.getParameterTypes());
Object invoked = method.invoke(instance, rpcRequestMessage.getParameterValue());
rpcResponseMessage.setReturnValue(invoked);
} catch (Exception e) {
rpcResponseMessage.setExceptionValue(e);
throw new RuntimeException(e);
}
rpcResponseMessage.setSequenceId(rpcRequestMessage.getSequenceId()); ctx.writeAndFlush(rpcResponseMessage);
}
}
创建客户端处理响应的handler
/**
 * Client-side handler. It takes the promise registered for the response's sequenceId out of
 * PromiseContainer and completes it with the returned value or the reported exception.
 * Responses with no registered promise are only logged.
 */
@Slf4j
@ChannelHandler.Sharable
public class RpcResponseHandle extends SimpleChannelInboundHandler<RpcResponseMessage> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcResponseMessage msg) throws Exception {
        log.info("response message: {}", msg.toString());
        // remove() both fetches and unregisters the promise, so each promise is completed at most once here
        Promise<Object> pending = PromiseContainer.PROMISE_MAP.remove(msg.getSequenceId());
        if (pending == null) {
            return;
        }
        Exception failure = msg.getExceptionValue();
        if (failure != null) {
            pending.setFailure(failure);
        } else {
            pending.setSuccess(msg.getReturnValue());
        }
    }
}
创建Promise容器
要在不同线程之间传递返回值,需要用到Promise
/**
 * Holds the promises that carry RPC results from the NIO thread back to the calling thread.
 * The caller registers a promise under the request's sequenceId and awaits it.
 * RpcResponseHandle removes it by the response's sequenceId and completes it.
 */
public class PromiseContainer {
    /**
     * sequenceId -> pending promise. The field is final so the map shared by callers and NIO
     * threads cannot be swapped out. It is a ConcurrentHashMap because both sides access it concurrently.
     */
    public static final Map<Integer, Promise<Object>> PROMISE_MAP = new ConcurrentHashMap<>();

    /** Static holder only; not meant to be instantiated. */
    private PromiseContainer() {
    }
}
创建rpc server
/**
 * RPC server. It listens on port 8000, frames and decodes incoming messages, and dispatches
 * RpcRequestMessages to RpcRequestHandle, which invokes the service and writes the response.
 */
@Slf4j
public class RpcServer {
    public static void main(String[] args) {
        NioEventLoopGroup boss = new NioEventLoopGroup();
        NioEventLoopGroup work = new NioEventLoopGroup();
        // Created once and shared by every child channel (the codec and request handler are @Sharable)
        LoggingHandler loggingHandler = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable messageCodecSharable = new MessageCodecSharable();
        RpcRequestHandle rpcRequestHandle = new RpcRequestHandle();
        try{
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.channel(NioServerSocketChannel.class);
            serverBootstrap.group(boss, work);
            // SO_BACKLOG: length of the queue of connections waiting to be accepted
            // (not a limit on the total number of connections)
            serverBootstrap.option(ChannelOption.SO_BACKLOG, 2);
            // Per-connection socket options (8192 * 128 = 1 MiB send/receive buffers)
            serverBootstrap.childOption(ChannelOption.TCP_NODELAY, true)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.SO_RCVBUF, 8192 * 128)
                    .childOption(ChannelOption.SO_SNDBUF, 8192 * 128);
            serverBootstrap.childHandler(new ChannelInitializer<NioSocketChannel>() {
                @Override
                protected void initChannel(NioSocketChannel ch) throws Exception {
                    // A new frame decoder per channel; the remaining handlers are the shared instances above
                    ch.pipeline().addLast(new ProtocolFrameDecoder())
                            .addLast(loggingHandler)
                            .addLast(messageCodecSharable)
                            .addLast(rpcRequestHandle);
                }
            });
            ChannelFuture channelFuture = serverBootstrap.bind(8000).sync();
            // Park the main thread until the server channel is closed
            channelFuture.channel().closeFuture().sync();
        }catch (Exception e){
            log.error("server error:", e);
        }finally {
            boss.shutdownGracefully();
            work.shutdownGracefully();
        }
    }
}
创建rpc client manager
主要是用来建立连接发起远程调用和接收返回结果
/**
 * Owns the client connection and turns interface calls into remote calls. A JDK dynamic proxy
 * sends an RpcRequestMessage for each call and blocks until RpcResponseHandle completes the
 * promise registered for that request.
 */
@Slf4j
public class RpcClientManager {
    /**
     * Shared connection, created lazily by {@link #getChannel()}. It is volatile because getChannel()
     * reads it outside the lock, so the write in initChannel() must be safely published.
     */
    private static volatile Channel channel = null;
    /** Guards lazy creation of {@link #channel}. */
    private static final Object LOCK = new Object();
    /**
     * Produces the sequenceId for each request. The id is echoed in the response and keys
     * PromiseContainer.PROMISE_MAP.
     */
    private static final AtomicInteger SEQUENCE_ID = new AtomicInteger(0);

    /**
     * Creates a proxy whose method calls are executed remotely.
     *
     * @param serviceClass interface to proxy; its fully qualified name is sent as the target interface
     * @return a proxy implementing {@code serviceClass}. Each call blocks until the response arrives.
     *         A remote failure or a failed send is rethrown as a RuntimeException.
     */
    @SuppressWarnings("unchecked")
    public static <T> T getProxyService(Class<T> serviceClass){
        ClassLoader classLoader = serviceClass.getClassLoader();
        Class<?>[] classes = new Class[]{serviceClass};
        Object o = Proxy.newProxyInstance(classLoader, classes, (proxy, method, args) -> {
            // 1. Convert the method call into a request message
            int id = SEQUENCE_ID.getAndIncrement();
            RpcRequestMessage rpcRequestMessage = new RpcRequestMessage(
                    id,
                    serviceClass.getName(),
                    method.getName(),
                    method.getReturnType(),
                    method.getParameterTypes(),
                    args);
            Channel ch = getChannel();
            if (ch == null) {
                throw new IllegalStateException("rpc channel is not available");
            }
            // 2. Register the promise BEFORE sending. The response is handled on the channel's event
            //    loop and can arrive before this thread continues. If the promise were not in the map
            //    yet, RpcResponseHandle would drop the response and await() below would block forever.
            //    The event loop passed here is the executor used to notify the promise's listeners.
            DefaultPromise<Object> promise = new DefaultPromise<>(ch.eventLoop());
            PromiseContainer.PROMISE_MAP.put(id, promise);
            // 3. Send the request. If the write itself fails no response will ever come, so fail the promise now.
            ch.writeAndFlush(rpcRequestMessage).addListener(future -> {
                if (!future.isSuccess()) {
                    PromiseContainer.PROMISE_MAP.remove(id);
                    promise.tryFailure(future.cause());
                }
            });
            // 4. Block until the response handler (or the write listener above) completes the promise
            try {
                promise.await();
            } finally {
                // No-op normally (the response handler already removed it); avoids a leak if interrupted
                PromiseContainer.PROMISE_MAP.remove(id);
            }
            if (promise.isSuccess()) {
                // Call succeeded
                return promise.getNow();
            }
            // Remote call failed
            throw new RuntimeException(promise.cause());
        });
        return (T) o;
    }

    /**
     * Returns the shared channel and connects on first use. It uses double-checked locking on
     * LOCK; the volatile field makes the unsynchronized first check safe.
     *
     * @return the connected channel, or null if the connection attempt failed (the failure is logged)
     */
    public static Channel getChannel(){
        Channel current = channel;
        if(current != null){
            return current;
        }
        synchronized (LOCK) {
            if(channel == null){
                initChannel();
            }
            return channel;
        }
    }

    /**
     * Connects to localhost:8000 and publishes the channel. The event-loop group is released when
     * the connection closes, or immediately if the connection cannot be established.
     */
    private static void initChannel(){
        NioEventLoopGroup work = new NioEventLoopGroup();
        MessageCodecSharable messageCodecSharable = new MessageCodecSharable();
        LoggingHandler loggingHandler = new LoggingHandler(LogLevel.DEBUG);
        RpcResponseHandle rpcResponseHandle = new RpcResponseHandle();
        try{
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(work);
            bootstrap.channel(NioSocketChannel.class);
            // Connect timeout
            bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 3000);
            bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
            bootstrap.option(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
            bootstrap.handler(new ChannelInitializer<NioSocketChannel>() {
                @Override
                protected void initChannel(NioSocketChannel ch) throws Exception {
                    ch.pipeline().addLast(new ProtocolFrameDecoder())
                            .addLast(loggingHandler)
                            .addLast(messageCodecSharable)
                            .addLast(rpcResponseHandle);
                }
            });
            Channel connected = bootstrap.connect("localhost", 8000).sync().channel();
            // Do not block on closeFuture().sync() here: it would only return once the channel is
            // already closed. Release the group asynchronously when the connection closes instead.
            connected.closeFuture().addListener(future -> work.shutdownGracefully());
            // Publish only after setup is complete
            channel = connected;
        }catch (Exception e){
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error("client error: ", e);
            // No connection means the close listener will never run, so release the group here
            work.shutdownGracefully();
        }
    }
}
获得Channel
- 建立连接、获取Channel的操作被封装到了`initChannel`方法中;当连接断开时,通过`closeFuture().addListener`异步关闭EventLoopGroup
- 通过单例模式(加锁的双重检查)创建与获取Channel
远程调用方法
- 为了让方法调用简洁明了,将`RpcRequestMessage`的创建与发送过程交给JDK动态代理来完成
- 调用`getProxyService`时传入被调用服务接口的Class对象,之后直接通过返回的代理对象调用方法即可
远程调用方法返回值获取
- 调用方法的是主线程,处理返回结果的是NIO线程(`RpcResponseHandle`)。要在不同线程之间传递返回值,需要用到Promise
- 在`PromiseContainer`中创建一个Map,供主线程和`RpcResponseHandle`共享:
  - Key为SequenceId
  - Value为对应的Promise
- 主线程的代理对象为每次请求创建一个Promise对象,以SequenceId为key放入`PromiseContainer`的Map中,并将`RpcRequestMessage`发送给服务器。Promise应在发送请求之前放入Map,否则当响应先于Promise到达时,NIO线程会找不到对应的Promise。之后使用`await`等待结果被放入Promise中;获取结果后,根据是否成功来返回结果或抛出异常
// 3.1 使用Promise来接收结果,getChannel().eventLoop()指定promise异步接收结果的线程(也就是调用listener中的逻辑的时候,由哪个线程来执行),为每个请求初始化一个promise用来接收方法调用的结果 DefaultPromise<Object> promise = new DefaultPromise<>(getChannel().eventLoop()); PromiseContainer.PROMISE_MAP.put(id, promise); // 3.2需要同步阻塞等到响应结果 promise.await(); if(promise.isSuccess()){ // 调用成功 return promise.getNow(); }else{ // 远程调用失败 throw new RuntimeException(promise.cause()); }
-
NIO线程负责通过SequenceId**获取并移除(remove)**对应的Promise,然后根据RpcResponseMessage中的结果,向Promise中放入不同的值:
- 如果没有异常信息(ExceptionValue),就调用`promise.setSuccess(returnValue)`放入方法返回值
- 如果有异常信息,就调用`promise.setFailure(exceptionValue)`放入异常信息
// 将返回结果放入对应的Promise中,并移除Map中的Promise // 当我们填充完消息后需要将这个promise删除掉,删除之前会将这个value返回 Promise<Object> promise = PromiseContainer.PROMISE_MAP.remove(msg.getSequenceId()); if(promise != null){ Object returnValue = msg.getReturnValue(); Exception exceptionValue = msg.getExceptionValue(); if(exceptionValue == null){ promise.setSuccess(returnValue); }else{ promise.setFailure(exceptionValue); } }
- 如果没有异常信息(ExceptionValue),就调用
创建一个测试接口
/**
 * Sample service contract exposed over RPC. Clients call it through a proxy, and the server
 * dispatches the call to a registered implementation.
 */
public interface ISimpleRpcFacade {
    /**
     * Handles a message and produces a reply.
     *
     * @param message text sent by the caller
     * @return the implementation's reply
     */
    String say(String message);
}
实现类:
/**
 * Server-side implementation of ISimpleRpcFacade that echoes the message behind a fixed prefix.
 */
public class SimpleRpcFacadeImpl implements ISimpleRpcFacade {
    /** Prefix placed in front of every echoed message. */
    private static final String PREFIX = "rpc response msg: ";

    /**
     * @param msg message from the caller (null is rendered as "null")
     * @return the prefix followed by {@code msg}
     */
    @Override
    public String say(String msg) {
        return new StringBuilder(PREFIX).append(msg).toString();
    }
}
通过本地模拟spring环境
/**
 * Minimal local stand-in for a Spring container. It maps a service interface to a single,
 * lazily created implementation instance.
 */
public class ServicesFactory {
    /** Interface fully qualified name -> implementation fully qualified name. */
    private static final java.util.Map<String, String> IMPLEMENTATIONS =
            new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * Interface class -> implementation instance. It is concurrent because getInstance may be
     * called from several server event-loop threads at once (the request handler is shared).
     */
    static final java.util.Map<Class<?>, Object> map = new java.util.concurrent.ConcurrentHashMap<>(16);

    static {
        IMPLEMENTATIONS.put("cn.com.wuhm.netty.netty.rpc.ISimpleRpcFacade",
                "cn.com.wuhm.netty.netty.rpc.impl.SimpleRpcFacadeImpl");
    }

    /**
     * Registers a ready-made implementation for an interface. Any instance previously cached for
     * that interface is replaced.
     *
     * @param interfaceClass service interface
     * @param instance implementation; must be an instance of {@code interfaceClass}
     * @throws IllegalArgumentException if {@code instance} does not implement {@code interfaceClass}
     */
    public static void register(Class<?> interfaceClass, Object instance) {
        if (!interfaceClass.isInstance(instance)) {
            throw new IllegalArgumentException(instance + " is not an instance of " + interfaceClass.getName());
        }
        map.put(interfaceClass, instance);
    }

    /**
     * Returns the implementation for {@code interfaceClass}, creating and caching it on first use.
     *
     * @param interfaceClass service interface to look up
     * @return the implementation, or null if no implementation is known for the interface
     * @throws ClassNotFoundException if the configured implementation class cannot be loaded
     * @throws IllegalAccessException if its no-arg constructor is not accessible
     * @throws InstantiationException if it cannot be instantiated (including a missing no-arg
     *         constructor or a constructor that throws)
     */
    public static Object getInstance(Class<?> interfaceClass) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        Object cached = map.get(interfaceClass);
        if (cached != null) {
            return cached;
        }
        String implName = IMPLEMENTATIONS.get(interfaceClass.getName());
        if (implName == null) {
            return null;
        }
        Object created = newInstance(Class.forName(implName));
        // Another thread may have created one concurrently; keep whichever was stored first
        Object previous = map.putIfAbsent(interfaceClass, created);
        return previous != null ? previous : created;
    }

    /** Instantiates {@code implClass} via its no-arg constructor, mapping reflection failures to the declared exceptions. */
    private static Object newInstance(Class<?> implClass) throws IllegalAccessException, InstantiationException {
        try {
            return implClass.getDeclaredConstructor().newInstance();
        } catch (NoSuchMethodException | java.lang.reflect.InvocationTargetException e) {
            InstantiationException wrapped = new InstantiationException(
                    "cannot instantiate " + implClass.getName() + ": " + e);
            wrapped.initCause(e);
            throw wrapped;
        }
    }
}
创建rpc client 测试消息
/**
 * Demo client: calls ISimpleRpcFacade.say twice through the RPC proxy and prints the results.
 */
@Slf4j
public class RpcClient {
    public static void main(String[] args) {
        // Create the proxy
        ISimpleRpcFacade service = RpcClientManager.getProxyService(ISimpleRpcFacade.class);
        try {
            // Invoke methods through the proxy
            System.out.println(service.say("netty"));
            System.out.println(service.say("rpc test"));
        } finally {
            // Close the connection. Its close listener shuts down the event-loop group, whose
            // non-daemon threads would otherwise keep the JVM alive after main returns.
            if (RpcClientManager.getChannel() != null) {
                RpcClientManager.getChannel().close();
            }
        }
    }
}
结果
11:33:35 [INFO ] [nioEventLoopGroup-2-1] c.c.w.n.n.r.h.RpcResponseHandle - response message: RpcResponseMessage(super=Message(sequenceId=0, messageType=102), returnValue=rpc response msg: netty, exceptionValue=null)
rpc response msg: netty
11:33:35 [INFO ] [nioEventLoopGroup-2-1] c.c.w.n.n.r.h.RpcResponseHandle - response message: RpcResponseMessage(super=Message(sequenceId=1, messageType=102), returnValue=rpc response msg: rpc test, exceptionValue=null)
rpc response msg: rpc test
其他类:
Message
/**
 * Base class of every message exchanged over the custom protocol.
 * Subclasses report their type code via getMessageType(). The codec writes that code into the
 * frame header as a single byte and Java-serializes the whole object, hence Serializable.
 * NOTE(review): each int constant below has the same name as a message class (e.g. LoginRequestMessage).
 * In the static block, X.class resolves to the class and a bare X resolves to the constant.
 * NOTE(review): the messageType field is never assigned in the code shown. The code actually sent
 * is whatever the abstract getMessageType() returns; confirm the field is still needed.
 */
@Data
public abstract class Message implements Serializable {
/**
 * Looks up the message class registered for a message-type code.
 * @param messageType message-type code (one of the constants below)
 * @return the registered message class, or null if the code is not registered
 */
public static Class<? extends Message> getMessageClass(int messageType) {
return messageClasses.get(messageType);
}
// Correlation id: a response carries the sequenceId of the request it answers
private int sequenceId;
private int messageType;
/**
 * @return this message's type code (written to the frame header by the codec)
 */
public abstract int getMessageType();
// Type codes of the chat-application messages
public static final int LoginRequestMessage = 0;
public static final int LoginResponseMessage = 1;
public static final int ChatRequestMessage = 2;
public static final int ChatResponseMessage = 3;
public static final int GroupCreateRequestMessage = 4;
public static final int GroupCreateResponseMessage = 5;
public static final int GroupJoinRequestMessage = 6;
public static final int GroupJoinResponseMessage = 7;
public static final int GroupQuitRequestMessage = 8;
public static final int GroupQuitResponseMessage = 9;
public static final int GroupChatRequestMessage = 10;
public static final int GroupChatResponseMessage = 11;
public static final int GroupMembersRequestMessage = 12;
public static final int GroupMembersResponseMessage = 13;
// NOTE(review): PingMessage and PongMessage have codes but no entry in messageClasses below,
// so getMessageClass returns null for them
public static final int PingMessage = 14;
public static final int PongMessage = 15;
/**
 * Type code of RPC request messages
 */
public static final int RPC_MESSAGE_TYPE_REQUEST = 101;
/**
 * Type code of RPC response messages
 */
public static final int RPC_MESSAGE_TYPE_RESPONSE = 102;
// Type code -> message class, filled once in the static initializer below
private static final Map<Integer, Class<? extends Message>> messageClasses = new HashMap<>();
static {
messageClasses.put(LoginRequestMessage, LoginRequestMessage.class);
messageClasses.put(LoginResponseMessage, LoginResponseMessage.class);
messageClasses.put(ChatRequestMessage, ChatRequestMessage.class);
messageClasses.put(ChatResponseMessage, ChatResponseMessage.class);
messageClasses.put(GroupCreateRequestMessage, GroupCreateRequestMessage.class);
messageClasses.put(GroupCreateResponseMessage, GroupCreateResponseMessage.class);
messageClasses.put(GroupJoinRequestMessage, GroupJoinRequestMessage.class);
messageClasses.put(GroupJoinResponseMessage, GroupJoinResponseMessage.class);
messageClasses.put(GroupQuitRequestMessage, GroupQuitRequestMessage.class);
messageClasses.put(GroupQuitResponseMessage, GroupQuitResponseMessage.class);
messageClasses.put(GroupChatRequestMessage, GroupChatRequestMessage.class);
messageClasses.put(GroupChatResponseMessage, GroupChatResponseMessage.class);
messageClasses.put(GroupMembersRequestMessage, GroupMembersRequestMessage.class);
messageClasses.put(GroupMembersResponseMessage, GroupMembersResponseMessage.class);
messageClasses.put(RPC_MESSAGE_TYPE_REQUEST, RpcRequestMessage.class);
messageClasses.put(RPC_MESSAGE_TYPE_RESPONSE, RpcResponseMessage.class);
}
}
MessageCodecSharable
/**
 * Converts between Message objects and frames of the custom protocol.
 *
 * <p>Frame layout (16-byte header followed by the body):
 * <pre>
 *  0..3    magic number, bytes 1,2,3,4
 *  4       version (1)
 *  5       serializer (0 = JDK, 1 = JSON; only 0 is produced here)
 *  6       message type (low byte of getMessageType())
 *  7..10   sequenceId
 *  11      padding (0xff)
 *  12..15  body length
 *  16..    body: JDK-serialized Message
 * </pre>
 * decode expects exactly one complete frame. The pipelines place ProtocolFrameDecoder in front
 * of this codec, and its default constructor splits frames on the length field at offset 12.
 */
@Slf4j
@ChannelHandler.Sharable
public class MessageCodecSharable extends MessageToMessageCodec<ByteBuf, Message> {
    /** The magic bytes 1,2,3,4 read as one big-endian int. */
    private static final int MAGIC_NUMBER = 0x01020304;
    private static final int VERSION = 1;
    private static final int SERIALIZER_JDK = 0;
    private static final int PADDING = 0xff;
    /** Bytes in the fixed header, including the 4-byte length field. */
    private static final int HEADER_LENGTH = 16;

    @Override
    protected void encode(ChannelHandlerContext ctx, Message msg, List<Object> out) throws Exception {
        // Serialize first: if writeObject fails (e.g. a non-Serializable return value inside the
        // message), no ByteBuf has been allocated yet, so nothing leaks.
        ByteArrayOutputStream bs = new ByteArrayOutputStream();
        try (ObjectOutputStream os = new ObjectOutputStream(bs)) {
            os.writeObject(msg);
        }
        byte[] bytes = bs.toByteArray();

        ByteBuf byteBuf = ctx.alloc().buffer(HEADER_LENGTH + bytes.length);
        // 1. 4-byte magic number
        byteBuf.writeInt(MAGIC_NUMBER);
        // 2. 1-byte version
        byteBuf.writeByte(VERSION);
        // 3. 1-byte serializer: 0 = JDK, 1 = JSON
        byteBuf.writeByte(SERIALIZER_JDK);
        // 4. 1-byte message type
        byteBuf.writeByte(msg.getMessageType());
        // 5. 4-byte sequence id
        byteBuf.writeInt(msg.getSequenceId());
        // Padding byte so the header is 16 bytes
        byteBuf.writeByte(PADDING);
        // 6. 4-byte body length
        byteBuf.writeInt(bytes.length);
        // 7. Body
        byteBuf.writeBytes(bytes);
        out.add(byteBuf);
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out) throws Exception {
        int magicNum = byteBuf.readInt();
        if (magicNum != MAGIC_NUMBER) {
            // Not our protocol: reject instead of deserializing arbitrary bytes
            throw new IllegalArgumentException("invalid magic number: 0x" + Integer.toHexString(magicNum));
        }
        byte version = byteBuf.readByte();
        byte serializer = byteBuf.readByte();
        byte messageType = byteBuf.readByte();
        int sequenceId = byteBuf.readInt();
        // Skip the padding byte
        byteBuf.readByte();
        int length = byteBuf.readInt();
        byte[] bytes = new byte[length];
        byteBuf.readBytes(bytes);
        // SECURITY(review): native Java deserialization of bytes received from the network can run
        // gadget chains from any class on the classpath. Use only with trusted peers, or add an
        // ObjectInputFilter / switch to a data-only format (the header reserves serializer 1 = JSON).
        Message message;
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            message = (Message) ois.readObject();
        }
        log.debug("{} {} {} {} {} {}", magicNum, version, serializer, messageType, sequenceId, length);
        out.add(message);
    }
}
ProtocolFrameDecoder
/**
 * Splits the inbound byte stream into complete protocol frames by reading each frame's 4-byte
 * body-length field. The header written by MessageCodecSharable is 4 (magic) + 1 (version)
 * + 1 (serializer) + 1 (type) + 4 (sequenceId) + 1 (padding) = 12 bytes before that field.
 */
public class ProtocolFrameDecoder extends LengthFieldBasedFrameDecoder {
    /** Largest accepted frame: 8 MiB. */
    private static final int MAX_FRAME_LENGTH = 8 * 1024 * 1024;
    /** Offset of the body-length field within the header. */
    private static final int LENGTH_FIELD_OFFSET = 12;
    /** Width of the body-length field in bytes. */
    private static final int LENGTH_FIELD_LENGTH = 4;

    /** Decoder configured for the default protocol layout; emitted frames keep their full header. */
    public ProtocolFrameDecoder(){
        this(MAX_FRAME_LENGTH, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH, 0, 0);
    }

    /** Fully configurable variant; the arguments are passed straight to LengthFieldBasedFrameDecoder. */
    public ProtocolFrameDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength, int lengthAdjustment, int initialBytesToStrip) {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
    }
}