Netty简单应用 RPC

netty简单应用

创建server模板

@Slf4j
public class RpcServer {

    /**
     * Starts the demo server on port 8000 and blocks until the server channel closes.
     * Both event loop groups are released on exit, whether startup succeeded or not.
     */
    public static void main(String[] args) {
        NioEventLoopGroup bossGroup = new NioEventLoopGroup();
        NioEventLoopGroup workerGroup = new NioEventLoopGroup();
        // Both handlers are @Sharable, so a single instance is reused by every child channel.
        LoggingHandler sharedLogger = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable sharedCodec = new MessageCodecSharable();

        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                    .channel(NioServerSocketChannel.class)
                    .group(bossGroup, workerGroup)
                    // SO_BACKLOG sizes the queue of connections waiting to be accepted;
                    // it is NOT a cap on the total number of connections (2 is a demo value).
                    .option(ChannelOption.SO_BACKLOG, 2)
                    // Per-connection socket options.
                    .childOption(ChannelOption.TCP_NODELAY, true)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.SO_RCVBUF, 8192 * 128)
                    .childOption(ChannelOption.SO_SNDBUF, 8192 * 128)
                    .childHandler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel ch) throws Exception {
                            // Split frames first, then log, then turn bytes into Message objects.
                            ch.pipeline()
                                    .addLast(new ProtocolFrameDecoder())
                                    .addLast(sharedLogger)
                                    .addLast(sharedCodec);
                        }
                    });

            ChannelFuture bindFuture = bootstrap.bind(8000).sync();
            // Park the main thread until the listening channel is closed.
            bindFuture.channel().closeFuture().sync();
        } catch (Exception e) {
            log.error("server error:", e);
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }
}

创建client模板

@Slf4j
public class RpcClient {

    /**
     * Connects to the demo server at localhost:8000, sends one hard-coded RPC request
     * and then blocks until the connection closes. The event loop group is always released.
     */
    public static void main(String[] args) {

        NioEventLoopGroup group = new NioEventLoopGroup();
        MessageCodecSharable codec = new MessageCodecSharable();
        LoggingHandler logging = new LoggingHandler(LogLevel.DEBUG);

        try {
            Bootstrap bootstrap = new Bootstrap()
                    .group(group)
                    .channel(NioSocketChannel.class)
                    // Give up if the TCP connection is not established within 3 seconds.
                    .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 3000)
                    .option(ChannelOption.SO_KEEPALIVE, true)
                    .option(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT)
                    .handler(new ChannelInitializer<NioSocketChannel>() {
                        @Override
                        protected void initChannel(NioSocketChannel ch) throws Exception {
                            // Frame splitting -> logging -> Message codec.
                            ch.pipeline()
                                    .addLast(new ProtocolFrameDecoder())
                                    .addLast(logging)
                                    .addLast(codec);
                        }
                    });

            ChannelFuture connected = bootstrap.connect("localhost", 8000).sync();

            RpcRequestMessage request = new RpcRequestMessage(
                    1,
                    "cn.com.wuhm.netty.netty.rpc.ISimpleRpcFacade",
                    "say",
                    String.class,
                    new Class[]{String.class},
                    new Object[]{"test rpc"});
            connected.channel().writeAndFlush(request);

            // Keep the main thread alive until the connection is closed.
            connected.channel().closeFuture().sync();
        } catch (Exception e) {
            log.error("client error: ", e);
        } finally {
            group.shutdownGracefully();
        }
    }
}

简单模拟RPC调用

创建RpcRequestMessage

@Getter
@ToString(callSuper = true)
public class RpcRequestMessage extends Message {

    /** Fully-qualified name of the target service interface; the server looks the implementation up by it. */
    private String interfaceName;
    /** Name of the method to call on that interface. */
    private String methodName;
    /** Declared return type of the target method. */
    private Class<?> returnType;
    /** Declared parameter types; together with {@link #methodName} they identify the exact overload. */
    private Class[] parameterTypes;
    /** Argument values, positionally matching {@link #parameterTypes}. */
    private Object[] parameterValue;

    /**
     * Builds a request describing one remote method call.
     *
     * @param sequenceId     id echoed back in the response so the caller can match it
     * @param interfaceName  fully-qualified service interface name
     * @param methodName     method to invoke
     * @param returnType     declared return type of the method
     * @param parameterTypes declared parameter types of the method
     * @param parameterValue actual argument values
     */
    public RpcRequestMessage(int sequenceId, String interfaceName, String methodName, Class<?> returnType, Class[] parameterTypes, Object[] parameterValue) {
        this.interfaceName = interfaceName;
        this.methodName = methodName;
        this.returnType = returnType;
        this.parameterTypes = parameterTypes;
        this.parameterValue = parameterValue;
        setSequenceId(sequenceId);
    }

    /**
     * @return the type code for RPC requests, {@link #RPC_MESSAGE_TYPE_REQUEST}
     */
    @Override
    public int getMessageType() {
        return RPC_MESSAGE_TYPE_REQUEST;
    }
}

创建RpcResponseMessage

@Data
@ToString(callSuper = true)
public class RpcResponseMessage extends Message {

    /** Value returned by the remote method; left {@code null} when the call failed. */
    private Object returnValue;

    /** Failure reported by the server; {@code null} means the call succeeded. */
    private Exception exceptionValue;

    /**
     * @return the type code for RPC responses, {@link #RPC_MESSAGE_TYPE_RESPONSE}
     */
    @Override
    public int getMessageType() {
        return RPC_MESSAGE_TYPE_RESPONSE;
    }
}

创建服务端处理请求handler

@Slf4j
@ChannelHandler.Sharable
public class RpcRequestHandle extends SimpleChannelInboundHandler<RpcRequestMessage> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcRequestMessage rpcRequestMessage) throws Exception {
        RpcResponseMessage rpcResponseMessage = new RpcResponseMessage();
        try {
            ISimpleRpcFacade instance = (ISimpleRpcFacade) ServicesFactory.getInstance(Class.forName(rpcRequestMessage.getInterfaceName()));

            Method method = instance.getClass().getMethod(rpcRequestMessage.getMethodName(), rpcRequestMessage.getParameterTypes());
            Object invoked = method.invoke(instance, rpcRequestMessage.getParameterValue());

            rpcResponseMessage.setReturnValue(invoked);
        } catch (Exception e) {
            rpcResponseMessage.setExceptionValue(e);
            throw new RuntimeException(e);
        }
 rpcResponseMessage.setSequenceId(rpcRequestMessage.getSequenceId());       ctx.writeAndFlush(rpcResponseMessage);
    }
}

创建客户端处理响应的handler

@Slf4j
@ChannelHandler.Sharable
public class RpcResponseHandle extends SimpleChannelInboundHandler<RpcResponseMessage> {

    /**
     * Completes the promise waiting for this response's sequence id.
     * The promise is removed from {@link PromiseContainer#PROMISE_MAP} first, so each one is
     * completed at most once; responses with no waiting promise are only logged.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcResponseMessage msg) throws Exception {
        log.info("response message: {}", msg.toString());

        // remove() both fetches and unregisters the waiting promise.
        Promise<Object> pending = PromiseContainer.PROMISE_MAP.remove(msg.getSequenceId());
        if (pending == null) {
            return;
        }

        Exception failure = msg.getExceptionValue();
        if (failure != null) {
            pending.setFailure(failure);
        } else {
            pending.setSuccess(msg.getReturnValue());
        }
    }
}

创建Promise容器

要在不同线程中进行返回值的传递,需要用到Promise

public class PromiseContainer {

    /**
     * Promises of in-flight RPC calls, keyed by request sequence id.
     * <p>The client proxy puts a promise here for each request it sends, and the response handler
     * removes and completes it when the matching response arrives. These two sides run on
     * different threads, hence the {@link ConcurrentHashMap}. The field is {@code final} so the
     * shared registry can never be replaced while calls are waiting on it.
     */
    public static final Map<Integer, Promise<Object>> PROMISE_MAP = new ConcurrentHashMap<>();

}

创建rpc server

@Slf4j
public class RpcServer {

    public static void main(String[] args) {
        NioEventLoopGroup boss = new NioEventLoopGroup();
        NioEventLoopGroup work = new NioEventLoopGroup();
        LoggingHandler loggingHandler = new LoggingHandler(LogLevel.DEBUG);
        MessageCodecSharable messageCodecSharable = new MessageCodecSharable();
        RpcRequestHandle rpcRequestHandle = new RpcRequestHandle();
        try{
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.channel(NioServerSocketChannel.class);
            serverBootstrap.group(boss, work);
            // 配置最大连接数
            serverBootstrap.option(ChannelOption.SO_BACKLOG, 2);

            // 参数配置
            serverBootstrap.childOption(ChannelOption.TCP_NODELAY, true)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childOption(ChannelOption.SO_RCVBUF, 8192 * 128)
                    .childOption(ChannelOption.SO_SNDBUF, 8192 * 128);

            serverBootstrap.childHandler(new ChannelInitializer<NioSocketChannel>() {
                @Override
                protected void initChannel(NioSocketChannel ch) throws Exception {
                    ch.pipeline().addLast(new ProcotolFrameDecoder())
                            .addLast(loggingHandler)
                            .addLast(messageCodecSharable)
                            .addLast(rpcRequestHandle);
                }
            });

            ChannelFuture channelFuture = serverBootstrap.bind(8000).sync();
            channelFuture.channel().closeFuture().sync();

        }catch (Exception e){
            log.error("server error:", e);
        }finally {
            boss.shutdownGracefully();
            work.shutdownGracefully();
        }
    }
}

创建rpc client manager

主要用来建立连接、发起远程调用并接收返回结果

@Slf4j
public class RpcClientManager {

    /**
     * Connection shared by every proxy, created lazily by {@link #getChannel()}.
     * {@code volatile} is required for the double-checked locking in getChannel: without it,
     * a thread on the unsynchronized fast path may see a reference that is not yet safely published.
     */
    private static volatile Channel channel = null;

    /** Guards lazy creation of {@link #channel}. */
    private static final Object LOCK = new Object();

    /**
     * Generates request sequence ids. The response carries the same id, which is how the
     * waiting caller's promise is found.
     */
    private static final AtomicInteger SEQUENCE_ID = new AtomicInteger(0);

    /**
     * Creates a JDK dynamic proxy for {@code serviceClass}. Every method call on the proxy is
     * sent as an {@link RpcRequestMessage}, and the calling thread blocks until the response arrives.
     *
     * @param serviceClass service interface to proxy
     * @return proxy implementing {@code serviceClass}
     * @throws RuntimeException (from proxy calls) wrapping the remote failure or a failed send
     */
    public static <T> T getProxyService(Class<T> serviceClass){
        ClassLoader classLoader = serviceClass.getClassLoader();
        Class<?>[] classes = new Class<?>[]{serviceClass};
        Object o = Proxy.newProxyInstance(classLoader, classes, (proxy, method, args) -> {
            // 1. Turn the method call into a request message.
            int id = SEQUENCE_ID.getAndIncrement();
            RpcRequestMessage rpcRequestMessage = new RpcRequestMessage(
                    id,
                    serviceClass.getName(),
                    method.getName(),
                    method.getReturnType(),
                    method.getParameterTypes(),
                    args);

            // Resolve the connection once, so the promise and the write use the same channel.
            Channel ch = getChannel();

            // 2. Register the promise BEFORE sending. If it were registered after writeAndFlush,
            //    a fast response could reach RpcResponseHandle first, find no promise, and be
            //    dropped, leaving this caller blocked forever.
            DefaultPromise<Object> promise = new DefaultPromise<>(ch.eventLoop());
            PromiseContainer.PROMISE_MAP.put(id, promise);

            // 3. Send. If the write itself fails, no response will ever come, so fail the
            //    promise now instead of leaving the caller in await() indefinitely.
            ch.writeAndFlush(rpcRequestMessage).addListener(future -> {
                if (!future.isSuccess()) {
                    PromiseContainer.PROMISE_MAP.remove(id);
                    promise.tryFailure(future.cause());
                }
            });

            // 4. Block until the response handler (or the failed write) completes the promise.
            promise.await();
            if (promise.isSuccess()) {
                return promise.getNow();
            }
            // Remote call (or send) failed.
            throw new RuntimeException(promise.cause());
        });
        return serviceClass.cast(o);
    }

    /**
     * Returns the shared connection, connecting on first use.
     *
     * @return a connected channel, never {@code null}
     * @throws IllegalStateException if the connection cannot be established
     */
    public static Channel getChannel(){
        // Fast path: a single volatile read, no locking once connected.
        Channel current = channel;
        if (current != null) {
            return current;
        }
        synchronized (LOCK) {
            // Re-check under the lock: another thread may have connected meanwhile.
            if (channel == null) {
                initChannel();
            }
            return channel;
        }
    }

    /**
     * Connects to localhost:8000 and publishes the resulting channel into {@link #channel}.
     * The event loop group is shut down when the channel closes, or immediately if connecting fails.
     *
     * @throws IllegalStateException if connecting fails or is interrupted
     */
    private static void initChannel(){
        NioEventLoopGroup work = new NioEventLoopGroup();
        MessageCodecSharable messageCodecSharable = new MessageCodecSharable();
        LoggingHandler loggingHandler = new LoggingHandler(LogLevel.DEBUG);
        RpcResponseHandle rpcResponseHandle = new RpcResponseHandle();
        try{
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(work);
            bootstrap.channel(NioSocketChannel.class);
            // Give up if the TCP connection is not established within 3 seconds.
            bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 3000);
            bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
            bootstrap.option(ChannelOption.ALLOCATOR, ByteBufAllocator.DEFAULT);
            bootstrap.handler(new ChannelInitializer<NioSocketChannel>() {
                @Override
                protected void initChannel(NioSocketChannel ch) throws Exception {
                    ch.pipeline().addLast(new ProtocolFrameDecoder())
                            .addLast(loggingHandler)
                            .addLast(messageCodecSharable)
                            .addLast(rpcResponseHandle);
                }
            });
            Channel connected = bootstrap.connect("localhost", 8000).sync().channel();
            // Do not block on closeFuture().sync() here: it would only return once the channel
            // is already closed, which is useless. Release the threads asynchronously instead.
            connected.closeFuture().addListener(future -> work.shutdownGracefully());
            channel = connected;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            work.shutdownGracefully();
            throw new IllegalStateException("interrupted while connecting to rpc server", e);
        } catch (Exception e){
            log.error("client error: ", e);
            // Without this, every failed attempt would leak an event loop group, and callers
            // would later fail with a NullPointerException on a null channel.
            work.shutdownGracefully();
            throw new IllegalStateException("failed to connect to rpc server localhost:8000", e);
        }
    }
}

获得Channel

  • 建立连接、获取Channel的操作被封装到了initChannel方法中;当连接断开时,通过closeFuture().addListener方法异步关闭EventLoopGroup
  • 通过单例模式创建与获取Channel

远程调用方法

  • 为了让方法的调用变得简洁明了,将RpcRequestMessage创建与发送过程通过JDK的动态代理来完成
  • 通过返回的代理对象调用方法即可;getProxyService方法的参数为被调用接口的Class对象

远程调用方法返回值获取

  • 调用方法的是主线程,处理返回结果的是NIO线程(RpcResponseHandle)。要在不同线程之间传递返回值,需要用到Promise

  • 在PromiseContainer中创建一个共享的Map(PROMISE_MAP)

    • Key为SequenceId
    • Value为对应的Promise
  • 主线程的代理类在将RpcRequestMessage发送给服务器之前,需要先创建Promise对象,并将其放入PromiseContainer的Map中(如果先发送再放入,响应可能在Promise注册之前就已到达而被丢弃,调用方会一直阻塞)。随后使用await等待结果被放入Promise中。获取结果后,根据结果类型(判断是否成功)来返回结果或抛出异常

    // 3.1 使用Promise来接收结果,getChannel().eventLoop()指定promise异步接收结果的线程(也就是调用listener中的逻辑的时候,由哪个线程来执行),为每个请求初始化一个promise用来接收方法调用的结果
                DefaultPromise<Object> promise = new DefaultPromise<>(getChannel().eventLoop());
                PromiseContainer.PROMISE_MAP.put(id, promise);
    
                // 3.2需要同步阻塞等到响应结果
                promise.await();
                if(promise.isSuccess()){
                    // 调用成功
                    return promise.getNow();
                }else{
                    // 远程调用失败
                    throw new RuntimeException(promise.cause());
                }
  • NIO线程负责通过SequenceId **获取并移除(remove)** 对应的Promise,然后根据RpcResponseMessage中的结果,向Promise中放入不同的值

    • 如果没有异常信息(ExceptionValue),就调用promise.setSuccess(returnValue)放入方法返回值
    • 如果有异常信息,就调用promise.setFailure(exception)放入异常信息
    // 将返回结果放入对应的Promise中,并移除Map中的Promise
    // 当我们填充完消息后需要将这个promise删除掉,删除之前会将这个value返回
            Promise<Object> promise = PromiseContainer.PROMISE_MAP.remove(msg.getSequenceId());
            if(promise != null){
                Object returnValue = msg.getReturnValue();
                Exception exceptionValue = msg.getExceptionValue();
                if(exceptionValue == null){
                    promise.setSuccess(returnValue);
                }else{
                    promise.setFailure(exceptionValue);
                }
            }

创建一个测试接口

/**
 * Demo service contract, called remotely through the RPC proxy.
 */
public interface ISimpleRpcFacade {

    /**
     * Sends a message to the service and returns its reply.
     *
     * @param msg message to send
     * @return the service's reply
     */
    String say(String msg);
}

实现类:

/**
 * Server-side implementation of {@link ISimpleRpcFacade} used by the demo.
 */
public class SimpleRpcFacadeImpl implements ISimpleRpcFacade {

    /**
     * Prefixes the incoming message with {@code "rpc response msg: "}.
     *
     * @param msg message received from the client ({@code null} is rendered as "null")
     * @return the prefixed reply
     */
    @Override
    public String say(String msg) {
        return new StringBuilder("rpc response msg: ").append(msg).toString();
    }
}

在本地模拟Spring容器环境(根据接口获取对应的实现类实例)

/**
 * Minimal stand-in for a Spring container: maps a service interface to a singleton implementation.
 */
public class ServicesFactory {

    /**
     * Registered interface -> implementation instances. The map is concurrent because the shared
     * ({@code @Sharable}) request handler may call {@link #getInstance} from several Netty worker
     * threads at once; concurrent writes to a plain HashMap can corrupt it.
     */
    static final java.util.Map<Class<?>, Object> map = new java.util.concurrent.ConcurrentHashMap<>(16);

    /** Demo service interface registered by default. */
    private static final String DEFAULT_INTERFACE = "cn.com.wuhm.netty.netty.rpc.ISimpleRpcFacade";

    /** Implementation bound to {@link #DEFAULT_INTERFACE}. */
    private static final String DEFAULT_IMPLEMENTATION = "cn.com.wuhm.netty.netty.rpc.impl.SimpleRpcFacadeImpl";

    /**
     * Returns the implementation registered for {@code interfaceClass}.
     * The default demo binding is created lazily, once, and then reused. Previously a new
     * instance was built and re-put on every call, and the argument played no part in it.
     *
     * @param interfaceClass service interface to look up; must not be {@code null}
     * @return the registered implementation, or {@code null} if none is registered
     * @throws ClassNotFoundException if the default interface or implementation class is missing
     * @throws IllegalAccessException if the default implementation's constructor is inaccessible
     * @throws InstantiationException if the default implementation cannot be instantiated
     */
    public static Object getInstance(Class<?> interfaceClass) throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        java.util.Objects.requireNonNull(interfaceClass, "interfaceClass");
        Object instance = map.get(interfaceClass);
        if (instance == null) {
            // Failures now propagate as the declared checked exceptions instead of being
            // printed and swallowed (which surfaced later as a confusing null).
            registerDefaults();
            instance = map.get(interfaceClass);
        }
        return instance;
    }

    /** Registers the default demo binding if it is not registered yet. */
    private static synchronized void registerDefaults() throws ClassNotFoundException, IllegalAccessException, InstantiationException {
        Class<?> clazz = Class.forName(DEFAULT_INTERFACE);
        if (map.containsKey(clazz)) {
            return;
        }
        Object instance = Class.forName(DEFAULT_IMPLEMENTATION).newInstance();
        // Put InterfaceClass -> InstanceObject mapping.
        map.put(clazz, instance);
    }
}

创建rpc client 测试消息

@Slf4j
public class RpcClient {

    /**
     * Calls the remote {@link ISimpleRpcFacade#say} twice through the client proxy
     * and prints each reply.
     */
    public static void main(String[] args) {
        // Each call on the proxy becomes a remote request.
        ISimpleRpcFacade facade = RpcClientManager.getProxyService(ISimpleRpcFacade.class);
        String[] messages = {"netty", "rpc test"};
        for (String message : messages) {
            System.out.println(facade.say(message));
        }
    }
}

结果

11:33:35 [INFO ] [nioEventLoopGroup-2-1] c.c.w.n.n.r.h.RpcResponseHandle - response message: RpcResponseMessage(super=Message(sequenceId=0, messageType=102), returnValue=rpc response msg: netty, exceptionValue=null)
rpc response msg: netty
11:33:35 [INFO ] [nioEventLoopGroup-2-1] c.c.w.n.n.r.h.RpcResponseHandle - response message: RpcResponseMessage(super=Message(sequenceId=1, messageType=102), returnValue=rpc response msg: rpc test, exceptionValue=null)
rpc response msg: rpc test

其他类:

Message

@Data
public abstract class Message implements Serializable {

    // ---- message type codes ----
    public static final int LoginRequestMessage = 0;
    public static final int LoginResponseMessage = 1;
    public static final int ChatRequestMessage = 2;
    public static final int ChatResponseMessage = 3;
    public static final int GroupCreateRequestMessage = 4;
    public static final int GroupCreateResponseMessage = 5;
    public static final int GroupJoinRequestMessage = 6;
    public static final int GroupJoinResponseMessage = 7;
    public static final int GroupQuitRequestMessage = 8;
    public static final int GroupQuitResponseMessage = 9;
    public static final int GroupChatRequestMessage = 10;
    public static final int GroupChatResponseMessage = 11;
    public static final int GroupMembersRequestMessage = 12;
    public static final int GroupMembersResponseMessage = 13;
    public static final int PingMessage = 14;
    public static final int PongMessage = 15;
    /** Type code of an RPC request. */
    public static final int RPC_MESSAGE_TYPE_REQUEST = 101;
    /** Type code of an RPC response. */
    public static final int RPC_MESSAGE_TYPE_RESPONSE = 102;

    /** Type code -> concrete message class, filled once by the static initializer below. */
    private static final Map<Integer, Class<? extends Message>> messageClasses = new HashMap<>();

    static {
        register(LoginRequestMessage, LoginRequestMessage.class);
        register(LoginResponseMessage, LoginResponseMessage.class);
        register(ChatRequestMessage, ChatRequestMessage.class);
        register(ChatResponseMessage, ChatResponseMessage.class);
        register(GroupCreateRequestMessage, GroupCreateRequestMessage.class);
        register(GroupCreateResponseMessage, GroupCreateResponseMessage.class);
        register(GroupJoinRequestMessage, GroupJoinRequestMessage.class);
        register(GroupJoinResponseMessage, GroupJoinResponseMessage.class);
        register(GroupQuitRequestMessage, GroupQuitRequestMessage.class);
        register(GroupQuitResponseMessage, GroupQuitResponseMessage.class);
        register(GroupChatRequestMessage, GroupChatRequestMessage.class);
        register(GroupChatResponseMessage, GroupChatResponseMessage.class);
        register(GroupMembersRequestMessage, GroupMembersRequestMessage.class);
        register(GroupMembersResponseMessage, GroupMembersResponseMessage.class);
        register(RPC_MESSAGE_TYPE_REQUEST, RpcRequestMessage.class);
        register(RPC_MESSAGE_TYPE_RESPONSE, RpcResponseMessage.class);
    }

    /** Id used to pair a request with its response. */
    private int sequenceId;

    private int messageType;

    /**
     * Looks up the concrete message class for a type code.
     *
     * @param messageType type code
     * @return the registered class, or {@code null} if the code is unknown
     */
    public static Class<? extends Message> getMessageClass(int messageType) {
        return messageClasses.get(messageType);
    }

    /**
     * @return this message's type code; each concrete subclass supplies its own
     */
    public abstract int getMessageType();

    /** Adds one entry to the type-code lookup table. */
    private static void register(int type, Class<? extends Message> messageClass) {
        messageClasses.put(type, messageClass);
    }
}

MessageCodecSharable

@Slf4j
@ChannelHandler.Sharable
public class MessageCodecSharable extends MessageToMessageCodec<ByteBuf, Message> {

    /** Magic number opening every frame: the bytes 1, 2, 3, 4. */
    private static final int MAGIC_NUMBER = 0x01020304;
    /** Protocol version written into every frame. */
    private static final byte VERSION = 1;
    /** Serializer id: 0 = JDK serialization, the only one implemented here. */
    private static final byte SERIALIZER_JDK = 0;
    /** Padding byte that brings the fixed header to 12 bytes. */
    private static final int PADDING = 0xff;
    /** Fixed 12-byte header plus the 4-byte body length. */
    private static final int HEADER_AND_LENGTH_BYTES = 16;

    /**
     * Encodes a message as:
     * magic(4) | version(1) | serializer(1) | type(1) | sequenceId(4) | padding(1) | length(4) | body.
     */
    @Override
    protected void encode(ChannelHandlerContext ctx, Message msg, List<Object> out) throws Exception {
        // Serialize BEFORE allocating: if serialization fails (e.g. a non-serializable field),
        // no ByteBuf has been allocated yet, so nothing is leaked.
        byte[] body = serialize(msg);

        ByteBuf byteBuf = ctx.alloc().buffer(HEADER_AND_LENGTH_BYTES + body.length);
        byteBuf.writeInt(MAGIC_NUMBER);          // 1. 4-byte magic number
        byteBuf.writeByte(VERSION);              // 2. 1-byte version
        byteBuf.writeByte(SERIALIZER_JDK);       // 3. 1-byte serializer id
        byteBuf.writeByte(msg.getMessageType()); // 4. 1-byte message type
        byteBuf.writeInt(msg.getSequenceId());   // 5. 4-byte sequence id
        byteBuf.writeByte(PADDING);              //    1 padding byte
        byteBuf.writeInt(body.length);           // 6. 4-byte body length
        byteBuf.writeBytes(body);                // 7. body
        out.add(byteBuf);
    }

    /**
     * Decodes one complete frame produced by {@link #encode}.
     *
     * @throws IllegalStateException if the magic number, version or serializer id is not supported
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf byteBuf, List<Object> out) throws Exception {
        int magicNum = byteBuf.readInt();
        if (magicNum != MAGIC_NUMBER) {
            throw new IllegalStateException("invalid magic number: 0x" + Integer.toHexString(magicNum));
        }
        byte version = byteBuf.readByte();
        if (version != VERSION) {
            throw new IllegalStateException("unsupported protocol version: " + version);
        }
        byte serializer = byteBuf.readByte();
        if (serializer != SERIALIZER_JDK) {
            throw new IllegalStateException("unsupported serializer: " + serializer);
        }
        // Skip message type (1), sequence id (4) and padding (1): the serialized object carries them itself.
        byteBuf.skipBytes(1 + 4 + 1);

        int length = byteBuf.readInt();
        byte[] bytes = new byte[length];
        byteBuf.readBytes(bytes, 0, length);
        out.add(deserialize(bytes));
    }

    /** Serializes a message with JDK serialization, closing the stream afterwards. */
    private static byte[] serialize(Message msg) throws java.io.IOException {
        ByteArrayOutputStream bs = new ByteArrayOutputStream();
        try (ObjectOutputStream os = new ObjectOutputStream(bs)) {
            os.writeObject(msg);
        }
        return bs.toByteArray();
    }

    /** Deserializes a message body written by {@link #serialize}. */
    private static Message deserialize(byte[] bytes) throws java.io.IOException, ClassNotFoundException {
        // NOTE(security): JDK deserialization of bytes received from the network is dangerous
        // if a peer can be untrusted. Restrict the accepted classes (ObjectInputFilter, Java 9+)
        // or switch to a data-only format such as JSON before exposing this beyond trusted hosts.
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return (Message) ois.readObject();
        }
    }
}

ProtocolFrameDecoder

/**
 * Splits the inbound byte stream into complete protocol frames, using the 4-byte body-length field.
 */
public class ProtocolFrameDecoder extends LengthFieldBasedFrameDecoder {

    /** Largest accepted frame: 8 MiB. */
    private static final int MAX_FRAME_LENGTH = 8 * 1024 * 1024;
    /**
     * The length field follows the 12-byte fixed header written by MessageCodecSharable.encode:
     * magic(4) + version(1) + serializer(1) + type(1) + sequenceId(4) + padding(1).
     */
    private static final int LENGTH_FIELD_OFFSET = 12;
    /** The body length is a 4-byte int. */
    private static final int LENGTH_FIELD_LENGTH = 4;
    /** The length field counts only the body, so no adjustment is needed. */
    private static final int LENGTH_ADJUSTMENT = 0;
    /** Keep the whole frame, header included, for the codec. */
    private static final int INITIAL_BYTES_TO_STRIP = 0;

    /** Creates a decoder configured for this project's protocol header. */
    public ProtocolFrameDecoder() {
        this(MAX_FRAME_LENGTH, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH, LENGTH_ADJUSTMENT, INITIAL_BYTES_TO_STRIP);
    }

    /** Creates a decoder with explicit {@link LengthFieldBasedFrameDecoder} parameters. */
    public ProtocolFrameDecoder(int maxFrameLength, int lengthFieldOffset, int lengthFieldLength, int lengthAdjustment, int initialBytesToStrip) {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength, lengthAdjustment, initialBytesToStrip);
    }
}
0%