Processing Data with a Custom Thread Pool

Goal: insert 1,000,000 rows into a database using a thread pool.

Prerequisites

Create a database

create database flex_test char set utf8mb4;

use flex_test;

CREATE TABLE IF NOT EXISTS `tb_account`
(
    `id`        INTEGER PRIMARY KEY auto_increment,
    `user_name` VARCHAR(100),
    `age`       INTEGER,
    `birthday`  DATETIME
);

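The code generator from my previous blog post can quickly create the matching entity and service classes.

Thread pool parameters

Meaning of the seven ThreadPoolExecutor constructor parameters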

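  • corePoolSize: the number of core threads in the pool. By default core threads stay in the pool even when idle. If allowCoreThreadTimeOut on ThreadPoolExecutor is set to true, core threads that stay idle longer than keepAliveTime are terminated as well.

  • maximumPoolSize: the maximum number of threads the pool may create.

  • keepAliveTime: the maximum time an idle non-core thread is kept alive before it is terminated.

  • unit: the time unit of keepAliveTime.

  • workQueue: the pool's task queue. Tasks submitted through execute() or submit() are placed in this queue once corePoolSize threads are already running.

  • threadFactory: the factory the pool uses to create new threads (typically used to give the threads meaningful names).

  • handler (a RejectedExecutionHandler): the rejection policy applied when a task cannot be accepted, that is, when the queue is full and the pool has already reached maximumPoolSize, or when the pool has been shut down. The built-in policies are:

    • new ThreadPoolExecutor.DiscardPolicy(): silently discards the rejected task.
    • new ThreadPoolExecutor.DiscardOldestPolicy(): discards the oldest task waiting in the queue (the head of the queue), then retries the current task.
    • new ThreadPoolExecutor.AbortPolicy(): throws RejectedExecutionException. This is the default.
    • new ThreadPoolExecutor.CallerRunsPolicy(): neither discards the task nor throws; the task runs in the thread that called execute(), which is not necessarily the main thread.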
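
To make the parameter list concrete, here is a minimal sketch that passes all seven arguments explicitly and counts how many submissions AbortPolicy rejects once both the queue and maximumPoolSize are exhausted. The class name, pool sizes, and the "demo-worker-" prefix are illustrative choices, not part of the project.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class; sizes are deliberately tiny so rejection is easy to trigger.
final class SevenParameterDemo {

    /** Submits blocking tasks and returns how many of them AbortPolicy rejected. */
    static int countRejections(int submissions) {
        AtomicInteger threadIndex = new AtomicInteger();
        // threadFactory: used here only to give worker threads a readable name.
        ThreadFactory namedFactory =
                r -> new Thread(r, "demo-worker-" + threadIndex.incrementAndGet());

        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                1,                                    // corePoolSize
                2,                                    // maximumPoolSize
                5L,                                   // keepAliveTime
                TimeUnit.SECONDS,                     // unit
                new ArrayBlockingQueue<>(1),          // workQueue with capacity 1
                namedFactory,                         // threadFactory
                new ThreadPoolExecutor.AbortPolicy()  // handler
        );

        CountDownLatch release = new CountDownLatch(1);
        int rejected = 0;
        try {
            for (int i = 0; i < submissions; i++) {
                try {
                    // Each task waits on the latch, so no task finishes while we submit.
                    pool.execute(() -> {
                        try {
                            release.await();
                        } catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                        }
                    });
                } catch (RejectedExecutionException e) {
                    rejected++;
                }
            }
        } finally {
            release.countDown(); // let the accepted tasks finish
            pool.shutdown();
        }
        return rejected;
    }

    public static void main(String[] args) {
        // 2 threads + 1 queue slot hold 3 tasks, so 2 of 5 submissions are rejected.
        System.out.println("rejected = " + countRejections(5));
    }
}
```

The order matters: the first task starts a core thread, the second waits in the queue, the third starts a non-core thread, and only then does the handler come into play.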

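How to choose a reasonable corePoolSize

  • Consider the number of CPU cores available on the system.
  • Consider the nature of the tasks: for CPU-bound tasks, a common rule of thumb is (number of CPU cores + 1), which keeps the CPUs busy; for IO-bound tasks, 2 * number of CPU cores is a common starting point, because threads block while waiting on IO and other tasks can run in the meantime.
  • Consider task volume and frequency: if tasks arrive frequently and each one runs briefly, use a larger core pool; if tasks arrive rarely and run for a long time, a smaller core pool is enough.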
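
The two rules of thumb above can be written down as a small helper. This is only a sketch of the heuristics; the class and method names are made up for illustration, and real workloads should be measured rather than sized by formula alone.

```java
// Hypothetical helper that encodes the two sizing rules of thumb.
final class PoolSizeHeuristics {

    /** CPU-bound work: one thread per core plus one spare. */
    static int cpuBoundThreads(int cores) {
        return cores + 1;
    }

    /** IO-bound work: about two threads per core, since threads spend time blocked. */
    static int ioBoundThreads(int cores) {
        return 2 * cores;
    }

    public static void main(String[] args) {
        int cores = Runtime.getRuntime().availableProcessors();
        System.out.println("cores     = " + cores);
        System.out.println("CPU-bound = " + cpuBoundThreads(cores));
        System.out.println("IO-bound  = " + ioBoundThreads(cores));
    }
}
```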

Sizing the core pool adaptively

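    /**
     * Builds a thread pool sized to the machine it runs on.
     * @param threadName custom prefix for thread names
     * @return the configured ThreadPoolExecutor
     */
    public ThreadPoolExecutor buildFixedThreadPoolTaskExecutor(String threadName){
        // Size the pool from the number of processors available to the JVM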
        int processNum = Runtime.getRuntime().availableProcessors();
        int corePoolSize = (int) (processNum / (1 - 0.2));
        int maxPoolSize = (int) (processNum / (1 - 0.5));
        ThreadPoolExecutor exec = new ThreadPoolExecutor(
                corePoolSize,
                maxPoolSize,
                5L,
                TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>(1024));
        if(StringUtils.isNotEmpty(threadName)){
            exec.setThreadFactory(new CustomizableThreadFactory(threadName));
        }
        exec.setRejectedExecutionHandler(new CustomRejectedHandler(exec));
        return exec;
    }
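
The two size expressions in this method look like the widely used rule `threads = cores / (1 - blocking coefficient)`, with coefficients 0.2 for the core size and 0.5 for the maximum. That reading is my interpretation, not something the code states. Under that assumption, the sketch below prints the sizes for a few core counts; `AdaptivePoolSize` and `threadsFor` are hypothetical names.

```java
// Hypothetical helper reproducing the arithmetic of the two size expressions.
final class AdaptivePoolSize {

    /** cores / (1 - blockingCoefficient), truncated toward zero like the (int) cast. */
    static int threadsFor(int cores, double blockingCoefficient) {
        return (int) (cores / (1 - blockingCoefficient));
    }

    public static void main(String[] args) {
        for (int cores : new int[] {2, 8, 16}) {
            System.out.printf("cores=%d corePoolSize=%d maxPoolSize=%d%n",
                    cores, threadsFor(cores, 0.2), threadsFor(cores, 0.5));
        }
    }
}
```

On an 8-core machine this gives a core size of 10 and a maximum of 16. Because the cast truncates, a 2-core machine gets a core size of 2 (2 / 0.8 = 2.5).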

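Setting the queue rejection policy

Most of the four built-in policies either discard the task or throw an exception, which is unacceptable in some business scenarios. In that case we implement RejectedExecutionHandler ourselves and add the rejected task to the queue with a blocking put.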

package cn.com.wuhm.common.thread;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * @Description
 * @Author wuhuaming
 * @Date 2023/8/28
 */
@Slf4j
@RequiredArgsConstructor
public class CustomRejectedHandler implements RejectedExecutionHandler {

    private final ThreadPoolExecutor threadPoolExecutor;
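    @Override
    public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
        if(threadPoolExecutor == null || threadPoolExecutor.isShutdown()){
            // The pool is missing or shut down: the task is dropped
            return;
        }
        try{
            // The queue is full: block the submitting thread until the task can be queued
            threadPoolExecutor.getQueue().put(r);
        }catch (InterruptedException e){
            // Restore the interrupt flag so the submitting thread can react to it
            Thread.currentThread().interrupt();
            log.error("cn.com.wuhm.common.thread.CustomRejectedHandler.rejectedExecution: ",e);
        }
    }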
}
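
To see the effect of a blocking handler in isolation, the following sketch pushes 20 short tasks through a pool with one thread and a one-slot queue. It is a standalone illustration of the same idea, not the CustomRejectedHandler class itself; `BlockingRejectionDemo` and `runAll` are made-up names.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo: a blocking rejection handler on a deliberately tiny pool.
final class BlockingRejectionDemo {

    /** Runs taskCount short tasks and returns how many of them completed. */
    static int runAll(int taskCount) {
        AtomicInteger completed = new AtomicInteger();
        // Wait for queue space instead of dropping the task or throwing.
        RejectedExecutionHandler blockUntilQueued = (task, executor) -> {
            if (executor.isShutdown()) {
                return;
            }
            try {
                executor.getQueue().put(task);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        };
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                1, 1, 0L, TimeUnit.SECONDS, new ArrayBlockingQueue<>(1), blockUntilQueued);
        try {
            for (int i = 0; i < taskCount; i++) {
                pool.execute(() -> {
                    try {
                        Thread.sleep(5); // simulate a little work
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                    completed.incrementAndGet();
                });
            }
        } finally {
            pool.shutdown(); // already queued tasks still run after shutdown()
        }
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return completed.get();
    }

    public static void main(String[] args) {
        System.out.println("completed = " + runAll(20));
    }
}
```

The trade-off is that the submitting thread stalls whenever the queue is full, which acts as back-pressure on the producer. A task rejected because the pool is already shut down is still dropped.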

Data sharding

Define how many records each task processes (100 by default)

    /**
     * Builds the task list: one task per data slice.
     * @param dataList  all records to process
     * @param taskClass task implementation to instantiate for each slice
     * @param paramsMap extra parameters handed to every task
     * @param <T> element type of the data list
     * @param <U> result type returned by each task
     * @return the tasks, one per slice
     */
    private <T, U> List<Callable<U>> buildTaskList(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
        List<Callable<U>> tasks = new ArrayList<>();
        AbstractTaskCallable<T, U> task = null;

        int dataListSize = dataList.size();
        // Total number of tasks
        int taskNum = dataListSize / this.dataMaxSize + 1;
        // Flag: the data size is an exact multiple of dataMaxSize, so the last slice would be empty
        boolean special = dataListSize % this.dataMaxSize == 0;
        // Current slice
        List<T> cutList = null;
        for (int n = 0; n < taskNum; n++) {
            if (n == taskNum - 1) {
                if (special) {
                    break;
                }
                cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, dataListSize));
            } else {
                cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, this.dataMaxSize * (n + 1)));
            }

            final List<T> finalCutList = cutList;
            try {
                // Get the constructors of the task class
                Constructor<?>[] constructors;
                constructors = taskClass.getDeclaredConstructors();
                for(int i=0; i < constructors.length; i++){
                    // Formal parameter types of this constructor
                    Class<?>[] parameterTypes = constructors[i].getParameterTypes();
                    // Actual arguments
                    Object[] parameters = new Object[parameterTypes.length];

                    for (int j = 0; j < parameterTypes.length; j++) {
                        // Look up a Spring bean by the fully qualified name of the parameter type
                        Class<?> forName = Class.forName(parameterTypes[j].getName());
                        Object bean = SpringApplicationContextUtil.getBean(forName);
                        parameters[j] = bean;
                    }
                    // Create the instance through the constructor
                    Object newInstance = constructors[i].newInstance(parameters);

                    if(newInstance instanceof AbstractTaskCallable){
                        AbstractTaskCallable<T, U> taskCallable = (AbstractTaskCallable<T, U>) newInstance;
                        // Hand the slice and the extra parameters to the task
                        taskCallable.setDataList(finalCutList);
                        taskCallable.setParamsMap(paramsMap);
                        task = taskCallable;
                    }
                }

            } catch (Exception e) {
                // Log the cause as well, otherwise the original failure is lost
                log.error("线程创建任务出错", e);
                throw new BasicException(new ResponseCodeImpl("500", "线程创建任务出错"));
            }
            tasks.add(task);
        }
        return tasks;
    }
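
The slicing part of this method can be looked at separately from the reflection part. The sketch below is a simplified stand-in for it, not the author's code: `ListPartitioner` and `partition` are hypothetical names, and the loop steps by chunk size instead of using the taskNum/special bookkeeping.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: splits a list into slices of a fixed maximum size.
final class ListPartitioner {

    /** Splits data into consecutive chunks of at most chunkSize elements. */
    static <T> List<List<T>> partition(List<T> data, int chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive");
        }
        List<List<T>> chunks = new ArrayList<>();
        for (int from = 0; from < data.size(); from += chunkSize) {
            int to = Math.min(from + chunkSize, data.size());
            // Copy the subList view so every task owns an independent list.
            chunks.add(new ArrayList<>(data.subList(from, to)));
        }
        return chunks;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 1; i <= 250; i++) {
            data.add(i);
        }
        List<List<Integer>> chunks = partition(data, 100);
        System.out.println(chunks.size());        // 3
        System.out.println(chunks.get(2).size()); // 50
    }
}
```

With 250 records and a chunk size of 100 the result is three slices of 100, 100, and 50 records. An exact multiple such as 200 yields two slices with no empty trailing one.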

Define the abstract class AbstractTaskCallable

package cn.com.wuhm.common.thread;

import lombok.Data;

import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;

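/**
 * @author wuhuaming
 * @description T: element type of the list; U: result type returned after the task runs
 * @date 2022-05-08 15:24
 **/
@Data
public abstract class AbstractTaskCallable<T, U> implements Callable<U> {

    /**
     * The data this task processes
     */
    private List<T> dataList;


    /**
     * Extra parameters
     */
    private ThreadLocal<Map<String, Object>> paramsMap;

}
  • dataList: the slice of data each thread has to process
  • paramsMap: extra parameters passed into the task, for example the currently selected project ID or the tenant ID of the account, for use during data processing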
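
One caveat worth checking in your own code: a ThreadLocal holds one value per thread, so a value set by the submitting thread is not visible when a pool thread calls get() on the same ThreadLocal object. The sketch below demonstrates this with the JDK alone; the class name and the "tenantId" key are illustrative.

```java
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo of per-thread visibility of ThreadLocal values.
final class ThreadLocalVisibilityDemo {

    /** Returns what a worker thread sees when it reads a ThreadLocal set by the caller. */
    static Map<String, Object> readFromWorker(ThreadLocal<Map<String, Object>> params) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Callable<Map<String, Object>> read = params::get;
            return pool.submit(read).get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        ThreadLocal<Map<String, Object>> params = new ThreadLocal<>();
        params.set(Collections.singletonMap("tenantId", 42));
        System.out.println("caller sees: " + params.get());           // {tenantId=42}
        System.out.println("worker sees: " + readFromWorker(params)); // null
    }
}
```

If the tasks need the caller's values, one option is to read the map on the submitting thread and hand the tasks a plain (ideally immutable) Map instead of the ThreadLocal wrapper.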

Thread utility class

package cn.com.wuhm.common.util;

import cn.com.wuhm.common.exception.BasicException;
import cn.com.wuhm.common.exception.ResponseCodeImpl;
import cn.com.wuhm.common.thread.AbstractTaskCallable;
import cn.com.wuhm.common.thread.CustomRejectedHandler;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;

/**
 * @author wuhuaming
 * @description
 * @date 2022-05-08 15:17
 **/
@Slf4j
public class ThreadUtil {

    /**
     * Number of records each task processes; 100 by default
     */
    private final Integer dataMaxSize;

    public ThreadUtil(){
        this.dataMaxSize = 100;
    }

    public ThreadUtil(Integer dataMaxSize) {
        this.dataMaxSize = dataMaxSize;
    }

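    /**
     * No extra parameters, no return value.
     * @param dataList  all records to process
     * @param taskClass task implementation to instantiate for each slice
     * @param <T> element type of the data list
     * @param <U> result type returned by each task
     */
    public <T, U> void syncThreadExecutor(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass){
        this.syncThreadExecutorWithParams(dataList, taskClass, new ThreadLocal<>());
    }

    /**
     * With extra parameters, no return value.
     * @param dataList  all records to process
     * @param taskClass task implementation to instantiate for each slice
     * @param paramsMap extra parameters handed to every task
     * @param <T> element type of the data list
     * @param <U> result type returned by each task
     */
    public <T, U> void syncThreadExecutorWithParams(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
        this.doSyncThreadExecutor(dataList, taskClass, paramsMap, false);
    }

    /**
     * With extra parameters and return values.
     * @param dataList  all records to process
     * @param taskClass task implementation to instantiate for each slice
     * @param paramsMap extra parameters handed to every task
     * @param <T> element type of the data list
     * @param <U> result type returned by each task
     * @return the results of all tasks
     */
    public <T, U> List<U> syncThreadExecutorWithParamsV2(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
        return this.doSyncThreadExecutor(dataList, taskClass, paramsMap, true);
    }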

    /**
     * Runs the tasks on the thread pool.
     * @param dataList  all records to process
     * @param taskClass task implementation to instantiate for each slice
     * @param paramsMap extra parameters handed to every task
     * @param isReturn  whether to collect the task results
     * @param <T> element type of the data list
     * @param <U> result type returned by each task
     * @return the task results, or an empty list when isReturn is false
     */
    private <T, U> List<U> doSyncThreadExecutor(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap, boolean isReturn){
        long beginTime = System.currentTimeMillis();
        List<U> outputList = new ArrayList<>();
        // Create a thread pool
        ExecutorService exec = this.buildAdaptiveThreadPoolTaskExecutor();
        // Build the tasks, one per data slice
        List<Callable<U>> taskList = this.buildTaskList(dataList, taskClass, paramsMap);
        try{
            if(isReturn){
                List<Future<U>> results = exec.invokeAll(taskList);
                for (Future<U> future : results) {
                    U u = future.get();
                    log.info("future value: {}", u);
                    outputList.add(u);
                }
            }else{
                // Task failures are not inspected in this branch
                exec.invokeAll(taskList);
            }
        }catch (Exception e){
            log.error("线程异常", e);
            if(e instanceof InterruptedException){
                // Restore the interrupt flag for the caller
                Thread.currentThread().interrupt();
            }
            if(e instanceof ExecutionException){
                ExecutionException executionException = (ExecutionException) e;
                log.error("ExecutionException: {}", JSONObject.parseObject(executionException.getMessage()).getString("message"));
                throw new BasicException(new ResponseCodeImpl("500", JSONObject.parseObject(executionException.getMessage()).getString("message")));
            }
        }finally{
            // Shut the pool down even when a task failed, so its threads are released
            exec.shutdown();
        }
        log.info("线程{}任务执行结束", Thread.currentThread().getName());
        log.info("执行任务消耗了 :" + (System.currentTimeMillis() - beginTime) + "毫秒");
        return outputList;
    }

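    /**
     * Builds an adaptive thread pool with the default thread factory.
     * @return the configured ThreadPoolExecutor
     */
    public ThreadPoolExecutor buildAdaptiveThreadPoolTaskExecutor(){
        return this.buildFixedThreadPoolTaskExecutor(null);
    }

    /**
     * Builds a thread pool sized to the machine it runs on.
     * @param threadName custom prefix for thread names
     * @return the configured ThreadPoolExecutor
     */
    public ThreadPoolExecutor buildFixedThreadPoolTaskExecutor(String threadName){
        // Size the pool from the number of processors available to the JVM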
        int processNum = Runtime.getRuntime().availableProcessors();
        int corePoolSize = (int) (processNum / (1 - 0.2));
        int maxPoolSize = (int) (processNum / (1 - 0.5));
        ThreadPoolExecutor exec = new ThreadPoolExecutor(
                corePoolSize,
                maxPoolSize,
                5L,
                TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>(1024));
        if(StringUtils.isNotEmpty(threadName)){
            exec.setThreadFactory(new CustomizableThreadFactory(threadName));
        }
        exec.setRejectedExecutionHandler(new CustomRejectedHandler(exec));
        return exec;
    }

    /**
     * Builds the task list: one task per data slice.
     * @param dataList  all records to process
     * @param taskClass task implementation to instantiate for each slice
     * @param paramsMap extra parameters handed to every task
     * @param <T> element type of the data list
     * @param <U> result type returned by each task
     * @return the tasks, one per slice
     */
    private <T, U> List<Callable<U>> buildTaskList(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
        List<Callable<U>> tasks = new ArrayList<>();
        AbstractTaskCallable<T, U> task = null;

        int dataListSize = dataList.size();
        // Total number of tasks
        int taskNum = dataListSize / this.dataMaxSize + 1;
        // Flag: the data size is an exact multiple of dataMaxSize, so the last slice would be empty
        boolean special = dataListSize % this.dataMaxSize == 0;
        // Current slice
        List<T> cutList = null;
        for (int n = 0; n < taskNum; n++) {
            if (n == taskNum - 1) {
                if (special) {
                    break;
                }
                cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, dataListSize));
            } else {
                cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, this.dataMaxSize * (n + 1)));
            }

            final List<T> finalCutList = cutList;
            try {
                // Get the constructors of the task class
                Constructor<?>[] constructors;
                constructors = taskClass.getDeclaredConstructors();
                for(int i=0; i < constructors.length; i++){
                    // Formal parameter types of this constructor
                    Class<?>[] parameterTypes = constructors[i].getParameterTypes();
                    // Actual arguments
                    Object[] parameters = new Object[parameterTypes.length];

                    for (int j = 0; j < parameterTypes.length; j++) {
                        // Look up a Spring bean by the fully qualified name of the parameter type
                        Class<?> forName = Class.forName(parameterTypes[j].getName());
                        Object bean = SpringApplicationContextUtil.getBean(forName);
                        parameters[j] = bean;
                    }
                    // Create the instance through the constructor
                    Object newInstance = constructors[i].newInstance(parameters);

                    if(newInstance instanceof AbstractTaskCallable){
                        AbstractTaskCallable<T, U> taskCallable = (AbstractTaskCallable<T, U>) newInstance;
                        // Hand the slice and the extra parameters to the task
                        taskCallable.setDataList(finalCutList);
                        taskCallable.setParamsMap(paramsMap);
                        task = taskCallable;
                    }
                }

            } catch (Exception e) {
                // Log the cause as well, otherwise the original failure is lost
                log.error("线程创建任务出错", e);
                throw new BasicException(new ResponseCodeImpl("500", "线程创建任务出错"));
            }
            tasks.add(task);
        }
        return tasks;
    }
}
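
The core of the utility is the invokeAll / Future.get / shutdown sequence. Stripped of Spring and reflection, that pattern can be sketched with the JDK alone as follows. `InvokeAllDemo`, `sumChunks`, and the pool size of 4 are assumptions made for the example and are not part of ThreadUtil.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical demo of the invokeAll pattern on plain integers.
final class InvokeAllDemo {

    /** Sums every chunk on a worker thread; results come back in submission order. */
    static List<Integer> sumChunks(List<List<Integer>> chunks) {
        List<Callable<Integer>> tasks = new ArrayList<>();
        for (List<Integer> chunk : chunks) {
            tasks.add(() -> chunk.stream().mapToInt(Integer::intValue).sum());
        }
        ExecutorService pool = Executors.newFixedThreadPool(4);
        List<Integer> sums = new ArrayList<>();
        try {
            // invokeAll blocks until all tasks are done and preserves the task order.
            for (Future<Integer> future : pool.invokeAll(tasks)) {
                sums.add(future.get());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for tasks", e);
        } catch (ExecutionException e) {
            // The exception thrown inside the task is available as the cause.
            throw new IllegalStateException("task failed", e.getCause());
        } finally {
            pool.shutdown(); // runs on success and on failure
        }
        return sums;
    }

    public static void main(String[] args) {
        List<List<Integer>> chunks = new ArrayList<>();
        chunks.add(Arrays.asList(1, 2, 3));
        chunks.add(Arrays.asList(10, 20));
        System.out.println(sumChunks(chunks)); // [6, 30]
    }
}
```

Calling get() on every Future is what surfaces task failures. If the futures returned by invokeAll are never inspected, an exception thrown inside a task goes unnoticed.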

Using the thread utility class

Create InsertAccountTaskCallable

package com.example.spring.boot.test;

import cn.com.wuhm.common.thread.AbstractTaskCallable;
import com.example.spring.boot.test.entity.Account;
import com.example.spring.boot.test.service.IAccountService;
import lombok.RequiredArgsConstructor;

/**
 * @Description
 * @Author wuhuaming
 * @Date 2023/8/30
 */
@RequiredArgsConstructor
public class InsertAccountTaskCallable extends AbstractTaskCallable<Account, Integer> {

    private final IAccountService accountService;

    @Override
    public Integer call() throws Exception {
        System.out.println(this.getDataList().size());
        accountService.saveBatch(this.getDataList());
        return this.getDataList().size();
    }
}

The data is inserted through the batch-save method of IAccountService.

Write the Spring Boot test class

package com.example.spring.boot.test;

import cn.com.wuhm.common.util.ThreadUtil;
import com.example.spring.boot.test.entity.Account;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;

import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;

@SpringBootTest
class SpringbootTestApplicationTests {

    @Test
    void contextLoads() {
        long startTime = System.currentTimeMillis();
        List<Account> accountList = new ArrayList<>();
        Random random = new Random();
        for(int i = 1; i <= 1000000; i++){
            Account account = new Account();
            account.setAge(random.nextInt(100));
            account.setBirthday(new Date());
            account.setUserName("name" + i);
            account.setId(i);
            accountList.add(account);
        }
        ThreadUtil threadUtil = new ThreadUtil(1000);
        threadUtil.syncThreadExecutor(accountList, InsertAccountTaskCallable.class);
        System.out.println("总时间:" + (System.currentTimeMillis() - startTime));

    }

}

Test results

2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert    debug 135  : ==> Parameters: 999997(Integer), name999997(String), 88(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert    debug 135  : ==> Parameters: 999998(Integer), name999998(String), 56(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert    debug 135  : ==> Parameters: 999999(Integer), name999999(String), 15(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert    debug 135  : ==> Parameters: 1000000(Integer), name1000000(String), 8(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:39.407  INFO 45132 --- [           main] cn.com.wuhm.common.util.ThreadUtil       doSyncThreadExecutor 112  : 线程main任务执行结束
2023-08-30 21:39:39.417  INFO 45132 --- [           main] cn.com.wuhm.common.util.ThreadUtil       doSyncThreadExecutor 113  : 执行任务消耗了 :1267342毫秒
总时间:1267643
2023-08-30 21:39:39.578  INFO 45132 --- [ionShutdownHook] com.zaxxer.hikari.HikariDataSource       close 350  : HikariPool-1 - Shutdown initiated...
2023-08-30 21:39:39.605  INFO 45132 --- [ionShutdownHook] com.zaxxer.hikari.HikariDataSource       close 352  : HikariPool-1 - Shutdown completed.

Process finished with exit code 0

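The database is a MySQL service started locally with Docker. Inserting 1,000,000 rows took about 21 minutes in total (1,267,643 ms according to the 总时间, "total time", line in the log above). I suspect the limiting factor is the performance of my own machine.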
