Processing Data with a Custom Thread Pool
Inserting 1 Million Rows with a Thread Pool
Prerequisites
Create a database
create database flex_test char set utf8mb4;
use flex_test;
CREATE TABLE IF NOT EXISTS `tb_account`
(
`id` INTEGER PRIMARY KEY auto_increment,
`user_name` VARCHAR(100),
`age` INTEGER,
`birthday` DATETIME
);
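The code generator from my previous blog post can quickly generate the corresponding entity and service.
Thread pool parameters
The meaning of the 7 parameters of a custom thread pool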
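- corePoolSize: the number of core threads in the pool. By default, core threads stay in the pool even when they are idle. If the allowCoreThreadTimeOut property of ThreadPoolExecutor is set to true, a core thread is also terminated once it has been idle for longer than keepAliveTime.
- maximumPoolSize: the maximum number of threads the pool may create.
- keepAliveTime: the maximum time an idle non-core thread stays alive before it is terminated.
- unit: the time unit of keepAliveTime.
- workQueue: the task queue of the pool. Tasks submitted through execute() and submit() are placed in this queue once all core threads are busy.
- threadFactory: the factory the pool uses to create new threads (typically used to give the threads a name).
- handler (a RejectedExecutionHandler): the rejection policy applied when a task cannot be accepted, that is, when the queue is full and the pool has already reached maximumPoolSize, or when the pool has been shut down. The JDK provides four policies:
  - new ThreadPoolExecutor.DiscardPolicy(): silently drops the rejected task.
  - new ThreadPoolExecutor.DiscardOldestPolicy(): drops the oldest task waiting in the queue (the one at its head), then retries the current task.
  - new ThreadPoolExecutor.AbortPolicy(): throws a RejectedExecutionException. This is the default policy.
  - new ThreadPoolExecutor.CallerRunsPolicy(): neither drops the task nor throws; the task runs in the thread that called execute().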
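To make the four policies concrete, here is a minimal sketch that saturates a deliberately tiny pool (one thread, one queue slot) and records what happens to the third task. The class name RejectionPolicyDemo and the task labels A, B and C are made up for illustration; only the JDK classes are real.

```java
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

class RejectionPolicyDemo {

    /** Submits tasks A, B and C to a saturated pool and returns what happened, in order. */
    static List<String> run(RejectedExecutionHandler handler) {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                1,                               // corePoolSize
                1,                               // maximumPoolSize
                5L, TimeUnit.SECONDS,            // keepAliveTime and unit
                new ArrayBlockingQueue<>(1),     // workQueue with a single slot
                Executors.defaultThreadFactory(),
                handler);
        Thread caller = Thread.currentThread();
        CountDownLatch gate = new CountDownLatch(1);
        List<String> events = new CopyOnWriteArrayList<>();
        try {
            // A occupies the only worker thread until the gate opens
            pool.execute(() -> {
                try {
                    gate.await();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                events.add("A");
            });
            // B waits in the single queue slot
            pool.execute(() -> events.add("B"));
            // C finds the worker busy and the queue full, so the handler decides its fate
            pool.execute(() -> events.add(
                    Thread.currentThread() == caller ? "C (caller)" : "C (pool)"));
        } catch (RejectedExecutionException e) {
            events.add("C rejected");
        }
        gate.countDown();
        pool.shutdown();
        try {
            pool.awaitTermination(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return events;
    }

    public static void main(String[] args) {
        System.out.println(run(new ThreadPoolExecutor.AbortPolicy()));
        System.out.println(run(new ThreadPoolExecutor.CallerRunsPolicy()));
        System.out.println(run(new ThreadPoolExecutor.DiscardPolicy()));
        System.out.println(run(new ThreadPoolExecutor.DiscardOldestPolicy()));
    }
}
```

With DiscardOldestPolicy it is the queued task B that is lost, not the newly submitted C, which is easy to get backwards.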
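How to choose a reasonable corePoolSize
- Consider the number of CPU cores available on the system.
- Consider the nature of the tasks: for CPU-bound tasks, (number of CPU cores + 1) is a common choice that keeps CPU utilization high; for IO-bound tasks, 2 * number of CPU cores is a reasonable starting point, because a thread blocks while it performs IO, which leaves the CPU free for other tasks.
- Consider the number and frequency of tasks: if tasks are submitted frequently and each one is short, a larger core pool size helps; if tasks are submitted rarely and each one runs for a long time, a smaller core pool size is enough.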
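The two rules of thumb above can be written down as a small helper. This is only a sketch of those starting points, with a made-up class name (PoolSizing); real workloads should still be measured and tuned.

```java
class PoolSizing {

    /** CPU-bound rule of thumb from the list above: number of cores + 1. */
    static int cpuBound(int cores) {
        return cores + 1;
    }

    /** IO-bound rule of thumb from the list above: 2 * number of cores. */
    static int ioBound(int cores) {
        return 2 * cores;
    }

    public static void main(String[] args) {
        // availableProcessors() reports the processors visible to this JVM
        int cores = Runtime.getRuntime().availableProcessors();
        System.out.println("cores=" + cores
                + ", cpuBound=" + cpuBound(cores)
                + ", ioBound=" + ioBound(cores));
    }
}
```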
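Adaptive core pool sizing
/**
* Builds a thread pool sized from the number of available processors.
* @param threadName prefix for the names of the pool's threads; may be null or empty
* @return the configured ThreadPoolExecutor
*/
public ThreadPoolExecutor buildFixedThreadPoolTaskExecutor(String threadName){
// Size the pool from the number of processors available to the JVM
int processNum = Runtime.getRuntime().availableProcessors();
// Core size: about 1.25x the processor count
int corePoolSize = (int) (processNum / (1 - 0.2));
// Maximum size: 2x the processor count
int maxPoolSize = (int) (processNum / (1 - 0.5));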
ThreadPoolExecutor exec = new ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
5L,
TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(1024));
if(StringUtils.isNotEmpty(threadName)){
exec.setThreadFactory(new CustomizableThreadFactory(threadName));
}
exec.setRejectedExecutionHandler(new CustomRejectedHandler(exec));
return exec;
}
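The method above names its threads through Spring's CustomizableThreadFactory. For readers without Spring on the classpath, a JDK-only stand-in might look like the sketch below. PrefixedThreadFactory and workerName are hypothetical names, and the numbering scheme (prefix followed by a counter starting at 1) is an assumption of this sketch.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class PrefixedThreadFactory implements ThreadFactory {

    private final String prefix;
    private final AtomicInteger counter = new AtomicInteger(1);

    PrefixedThreadFactory(String prefix) {
        this.prefix = prefix;
    }

    @Override
    public Thread newThread(Runnable r) {
        // Each new worker gets the prefix plus a running number
        return new Thread(r, prefix + counter.getAndIncrement());
    }

    /** Runs one task on a pool built with this factory and returns the worker's name. */
    static String workerName(String prefix) {
        ExecutorService pool = new ThreadPoolExecutor(
                1, 1, 5L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(16),
                new PrefixedThreadFactory(prefix));
        try {
            Callable<String> task = () -> Thread.currentThread().getName();
            return pool.submit(task).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(workerName("insert-account-"));
    }
}
```

Named threads make log lines such as the ones in the test results much easier to attribute to a particular pool.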
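Setting the queue rejection policy
Most of the four built-in policies either drop the task or throw an exception, which is unacceptable in some business scenarios. In those cases we need a custom RejectedExecutionHandler that adds the task to the task queue in a blocking manner.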
package cn.com.wuhm.common.thread;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
/**
* @Description
* @Author wuhuaming
* @Date 2023/8/28
*/
@Slf4j
@RequiredArgsConstructor
public class CustomRejectedHandler implements RejectedExecutionHandler {
private final ThreadPoolExecutor threadPoolExecutor;
@Override
public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
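if(threadPoolExecutor == null || threadPoolExecutor.isShutdown()){
// The pool is unavailable: the task is dropped, but leave a trace in the log
log.warn("Task {} rejected because the thread pool is shut down", r);
return;
}
try{
// The queue is full: block the submitting thread until the queue has room for the task
threadPoolExecutor.getQueue().put(r);
}catch (InterruptedException e){
// Restore the interrupt flag so the submitting thread can react to the interruption
Thread.currentThread().interrupt();
log.error("cn.com.wuhm.common.thread.CustomRejectedHandler.rejectedExecution: ",e);
}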
}
}
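To see the effect of a blocking handler, the sketch below floods a one-thread pool with a two-slot queue and checks that no task is lost. BlockingPutHandler is a simplified, JDK-only stand-in for CustomRejectedHandler (it uses the executor passed to rejectedExecution and throws instead of logging); both class names here are made up for illustration.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class BlockingHandlerDemo {

    /** Blocks the submitting thread until the queue has room for the rejected task. */
    static final class BlockingPutHandler implements RejectedExecutionHandler {
        @Override
        public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
            if (executor.isShutdown()) {
                throw new RejectedExecutionException("pool is shut down");
            }
            try {
                executor.getQueue().put(r);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RejectedExecutionException("interrupted while waiting for queue space", e);
            }
        }
    }

    /** Submits taskCount tasks to a tiny pool and returns how many of them actually ran. */
    static int runAll(int taskCount) {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                1, 1, 5L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(2),
                new BlockingPutHandler());
        AtomicInteger completed = new AtomicInteger();
        for (int i = 0; i < taskCount; i++) {
            // When the queue is full, execute() blocks inside the handler instead of failing
            pool.execute(completed::incrementAndGet);
        }
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return completed.get();
    }

    public static void main(String[] args) {
        System.out.println(runAll(50));
    }
}
```

The trade-off is back-pressure: the submitting thread slows down to the pace of the workers, which is usually what a bulk insert wants.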
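Data sharding
Define how many records each task processes (100 by default).
/**
* Builds the task list: splits dataList into slices and creates one task per slice.
* @param dataList the full data set to process
* @param taskClass the task implementation to instantiate for each slice
* @param paramsMap extra parameters handed to every task
* @param <T> element type of dataList
* @param <U> result type returned by each task
* @return one Callable per slice
*/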
private <T, U> List<Callable<U>> buildTaskList(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
List<Callable<U>> tasks = new ArrayList<>();
AbstractTaskCallable<T, U> task = null;
int dataListSize = dataList.size();
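// Total number of tasks (the number of full slices plus one)
int taskNum = dataListSize / this.dataMaxSize + 1;
// True when the data divides evenly into slices, so the extra last task would be empty
boolean special = dataListSize % this.dataMaxSize == 0;
// The slice handed to the current task
List<T> cutList = null;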
for (int n = 0; n < taskNum; n++) {
if (n == taskNum - 1) {
if (special) {
break;
}
cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, dataListSize));
} else {
cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, this.dataMaxSize * (n + 1)));
}
final List<T> finalCutList = cutList;
try {
// Get the constructors declared by the task class
Constructor<?>[] constructors;
constructors = taskClass.getDeclaredConstructors();
for(int i=0; i < constructors.length; i++){
// Parameter types of the constructor
Class<?>[] parameterTypes = constructors[i].getParameterTypes();
// Arguments to pass to the constructor
Object[] parameters = new Object[parameterTypes.length];
for (int j = 0; j < parameterTypes.length; j++) {
// Resolve each constructor parameter to a Spring bean of that type
Object bean = SpringApplicationContextUtil.getBean(parameterTypes[j]);
parameters[j] = bean;
}
// Create the task instance through the constructor
Object newInstance = constructors[i].newInstance(parameters);
if(newInstance instanceof AbstractTaskCallable){
AbstractTaskCallable<T, U> taskCallable = (AbstractTaskCallable<T, U>) newInstance;
// Hand the slice and the extra parameters to the task
taskCallable.setDataList(finalCutList);
taskCallable.setParamsMap(paramsMap);
task = taskCallable;
}
}
} catch (Exception e) {
log.error("Failed to create the task for a slice", e);
throw new BasicException(new ResponseCodeImpl("500", "Failed to create the task for a slice"));
}
tasks.add(task);
}
return tasks;
}
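Stripped of the reflection and Spring lookups, the slicing step on its own can be sketched as follows. ListPartitioner is a hypothetical helper, not part of the original utility; stepping through the list in increments of the slice size avoids the special case for evenly divisible sizes.

```java
import java.util.ArrayList;
import java.util.List;

class ListPartitioner {

    /** Splits data into consecutive slices of at most sliceSize elements. */
    static <T> List<List<T>> partition(List<T> data, int sliceSize) {
        if (sliceSize <= 0) {
            throw new IllegalArgumentException("sliceSize must be positive");
        }
        List<List<T>> slices = new ArrayList<>();
        for (int from = 0; from < data.size(); from += sliceSize) {
            int to = Math.min(from + sliceSize, data.size());
            // Copy the view so each slice is independent of the source list
            slices.add(new ArrayList<>(data.subList(from, to)));
        }
        return slices;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            data.add(i);
        }
        System.out.println(partition(data, 3));
    }
}
```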
Define the abstract class AbstractTaskCallable
package cn.com.wuhm.common.thread;
import lombok.Data;
import org.apache.commons.collections4.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
/**
* @author wuhuaming
* @description T is the element type of the list; U is the type of the result returned by the task
* @date 2022-05-08 15:24
**/
@Data
public abstract class AbstractTaskCallable<T, U> implements Callable<U> {
/**
* The slice of data processed by this task
*/
private List<T> dataList;
/**
* Extra parameters
*/
private ThreadLocal<Map<String, Object>> paramsMap;
}
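- dataList: the slice of data that each task has to process.
- paramsMap: extra parameters passed into the task, such as the currently selected project ID or the tenant ID of the account, for use during data processing. Note that a value set on a ThreadLocal by the submitting thread is not visible from a pool thread, so it has to be read on the submitting thread (or passed as a plain Map) before the task runs.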
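Because paramsMap is a ThreadLocal, it is worth checking what a worker thread actually sees. The sketch below (hypothetical class ThreadLocalVisibilityDemo, with a made-up tenantId entry) sets a value on the submitting thread and then reads it from a pool thread, once through the ThreadLocal and once through a plain Map captured beforehand.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class ThreadLocalVisibilityDemo {

    /** Returns {seen through the ThreadLocal, seen through the plain Map} from a worker thread. */
    static boolean[] seenByWorker() {
        Map<String, Object> params = new HashMap<>();
        params.put("tenantId", "t-1");
        ThreadLocal<Map<String, Object>> local = new ThreadLocal<>();
        local.set(params);
        // Read the value while still on the submitting thread
        Map<String, Object> snapshot = local.get();
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            // The worker has its own (empty) slot for this ThreadLocal
            Callable<Boolean> viaThreadLocal = () -> local.get() != null;
            // The captured Map is an ordinary object and is shared with the worker
            Callable<Boolean> viaSnapshot = () -> snapshot.containsKey("tenantId");
            return new boolean[] {
                pool.submit(viaThreadLocal).get(),
                pool.submit(viaSnapshot).get()
            };
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
            local.remove();
        }
    }

    public static void main(String[] args) {
        boolean[] seen = seenByWorker();
        System.out.println("via ThreadLocal: " + seen[0] + ", via Map: " + seen[1]);
    }
}
```

The worker only sees the plain Map, which suggests that a task should receive the resolved Map rather than the ThreadLocal itself.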
Thread utility class
package cn.com.wuhm.common.util;
import cn.com.wuhm.common.exception.BasicException;
import cn.com.wuhm.common.exception.ResponseCodeImpl;
import cn.com.wuhm.common.thread.AbstractTaskCallable;
import cn.com.wuhm.common.thread.CustomRejectedHandler;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
/**
* @author wuhuaming
* @description
* @date 2022-05-08 15:17
**/
@Slf4j
public class ThreadUtil {
/**
* Number of records each task processes: about 100 by default
*/
private final Integer dataMaxSize;
public ThreadUtil(){
this.dataMaxSize = 100;
}
public ThreadUtil(Integer dataMaxSize) {
this.dataMaxSize = dataMaxSize;
}
/**
* Runs the tasks without extra parameters and without return values
* @param dataList
* @param taskClass
* @param <T>
* @param <U>
*/
public <T, U> void syncThreadExecutor(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass){
this.syncThreadExecutorWithParams(dataList, taskClass, new ThreadLocal<>());
}
/**
* Runs the tasks with extra parameters and without return values
* @param dataList
* @param taskClass
* @param <T>
* @param <U>
*/
public <T, U> void syncThreadExecutorWithParams(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
this.doSyncThreadExecutor(dataList, taskClass, paramsMap, false);
}
/**
* Runs the tasks with extra parameters and returns their results
* @param dataList
* @param taskClass
* @param paramsMap
* @param <T>
* @param <U>
* @return
*/
public <T, U> List<U> syncThreadExecutorWithParamsV2(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
return this.doSyncThreadExecutor(dataList, taskClass, paramsMap, true);
}
/**
* Executes the tasks on the thread pool
* @param dataList
* @param taskClass
* @param paramsMap
* @param isReturn
* @param <T>
* @param <U>
* @return
*/
private <T, U> List<U> doSyncThreadExecutor(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap, boolean isReturn){
long beginTime = System.currentTimeMillis();
List<U> outputList = new ArrayList<>();
// Create a thread pool
ExecutorService exec = this.buildAdaptiveThreadPoolTaskExecutor();
// Build the list of tasks
List<Callable<U>> taskList = this.buildTaskList(dataList, taskClass, paramsMap);
try{
if(isReturn){
List<Future<U>> results = exec.invokeAll(taskList);
for (Future<U> future : results) {
U u = future.get();
log.info("future value: {}", u);
outputList.add(u);
}
}else{
exec.invokeAll(taskList);
}
}catch (ExecutionException e){
// A task threw: report the underlying cause rather than the wrapper's message
Throwable cause = e.getCause() != null ? e.getCause() : e;
log.error("Task execution failed", cause);
throw new BasicException(new ResponseCodeImpl("500", String.valueOf(cause.getMessage())));
}catch (InterruptedException e){
// Restore the interrupt flag so callers can react to the interruption
Thread.currentThread().interrupt();
log.error("Interrupted while waiting for the tasks to finish", e);
}finally{
// Always shut the pool down, even when a task fails, so its threads are released
exec.shutdown();
}
log.info("Thread {} finished executing tasks", Thread.currentThread().getName());
log.info("Task execution took: {} ms", System.currentTimeMillis() - beginTime);
return outputList;
}
/**
* Builds a thread pool sized from the number of available processors, with default thread names
* @return
*/
public ThreadPoolExecutor buildAdaptiveThreadPoolTaskExecutor(){
return this.buildFixedThreadPoolTaskExecutor(null);
}
/**
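* Builds a thread pool sized from the number of available processors.
* @param threadName prefix for the names of the pool's threads; may be null or empty
* @return the configured ThreadPoolExecutor
*/
public ThreadPoolExecutor buildFixedThreadPoolTaskExecutor(String threadName){
// Size the pool from the number of processors available to the JVM
int processNum = Runtime.getRuntime().availableProcessors();
// Core size: about 1.25x the processor count
int corePoolSize = (int) (processNum / (1 - 0.2));
// Maximum size: 2x the processor count
int maxPoolSize = (int) (processNum / (1 - 0.5));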
ThreadPoolExecutor exec = new ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
5L,
TimeUnit.SECONDS,
new LinkedBlockingQueue<Runnable>(1024));
if(StringUtils.isNotEmpty(threadName)){
exec.setThreadFactory(new CustomizableThreadFactory(threadName));
}
exec.setRejectedExecutionHandler(new CustomRejectedHandler(exec));
return exec;
}
/**
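* Builds the task list: splits dataList into slices and creates one task per slice.
* @param dataList the full data set to process
* @param taskClass the task implementation to instantiate for each slice
* @param paramsMap extra parameters handed to every task
* @param <T> element type of dataList
* @param <U> result type returned by each task
* @return one Callable per slice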
*/
private <T, U> List<Callable<U>> buildTaskList(List<T> dataList, Class< ? extends AbstractTaskCallable<T, U>> taskClass, ThreadLocal<Map<String,Object>> paramsMap){
List<Callable<U>> tasks = new ArrayList<>();
AbstractTaskCallable<T, U> task = null;
int dataListSize = dataList.size();
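// Total number of tasks (the number of full slices plus one)
int taskNum = dataListSize / this.dataMaxSize + 1;
// True when the data divides evenly into slices, so the extra last task would be empty
boolean special = dataListSize % this.dataMaxSize == 0;
// The slice handed to the current task
List<T> cutList = null;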
for (int n = 0; n < taskNum; n++) {
if (n == taskNum - 1) {
if (special) {
break;
}
cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, dataListSize));
} else {
cutList = new ArrayList<>(dataList.subList(this.dataMaxSize * n, this.dataMaxSize * (n + 1)));
}
final List<T> finalCutList = cutList;
try {
// Get the constructors declared by the task class
Constructor<?>[] constructors;
constructors = taskClass.getDeclaredConstructors();
for(int i=0; i < constructors.length; i++){
// Parameter types of the constructor
Class<?>[] parameterTypes = constructors[i].getParameterTypes();
// Arguments to pass to the constructor
Object[] parameters = new Object[parameterTypes.length];
for (int j = 0; j < parameterTypes.length; j++) {
// Resolve each constructor parameter to a Spring bean of that type
Object bean = SpringApplicationContextUtil.getBean(parameterTypes[j]);
parameters[j] = bean;
}
// Create the task instance through the constructor
Object newInstance = constructors[i].newInstance(parameters);
if(newInstance instanceof AbstractTaskCallable){
AbstractTaskCallable<T, U> taskCallable = (AbstractTaskCallable<T, U>) newInstance;
// Hand the slice and the extra parameters to the task
taskCallable.setDataList(finalCutList);
taskCallable.setParamsMap(paramsMap);
task = taskCallable;
}
}
} catch (Exception e) {
log.error("Failed to create the task for a slice", e);
throw new BasicException(new ResponseCodeImpl("500", "Failed to create the task for a slice"));
}
tasks.add(task);
}
return tasks;
}
}
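The core of doSyncThreadExecutor, submitting every task with invokeAll, collecting the results, and releasing the pool, can be isolated into a small JDK-only sketch. InvokeAllDemo is a hypothetical class; it uses Executors.newFixedThreadPool and IllegalStateException in place of the custom pool and BasicException from the utility above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class InvokeAllDemo {

    /** Runs all tasks on a temporary pool and returns their results in submission order. */
    static <U> List<U> runAll(List<Callable<U>> tasks, int threads) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<U> results = new ArrayList<>();
            // invokeAll blocks until every task has completed and preserves the input order
            for (Future<U> future : pool.invokeAll(tasks)) {
                // get() rethrows a task's exception wrapped in ExecutionException
                results.add(future.get());
            }
            return results;
        } catch (ExecutionException e) {
            throw new IllegalStateException("a task failed", e.getCause());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for tasks", e);
        } finally {
            // Release the worker threads whether or not a task failed
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        List<Callable<Integer>> tasks = new ArrayList<>();
        tasks.add(() -> 1);
        tasks.add(() -> 2);
        tasks.add(() -> 3);
        System.out.println(runAll(tasks, 2));
    }
}
```

Note that a task failure only surfaces when Future.get() is called; if the futures returned by invokeAll are ignored, exceptions thrown inside the tasks go unnoticed.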
Using the thread utility class
Create InsertAccountTaskCallable
package com.example.spring.boot.test;
import cn.com.wuhm.common.thread.AbstractTaskCallable;
import com.example.spring.boot.test.entity.Account;
import com.example.spring.boot.test.service.IAccountService;
import lombok.RequiredArgsConstructor;
/**
* @Description
* @Author wuhuaming
* @Date 2023/8/30
*/
@RequiredArgsConstructor
public class InsertAccountTaskCallable extends AbstractTaskCallable<Account, Integer> {
private final IAccountService accountService;
@Override
public Integer call() throws Exception {
System.out.println(this.getDataList().size());
accountService.saveBatch(this.getDataList());
return this.getDataList().size();
}
}
The data is inserted through the batch-insert method of IAccountService.
Write a Spring Boot test class
package com.example.spring.boot.test;
import cn.com.wuhm.common.util.ThreadUtil;
import com.example.spring.boot.test.entity.Account;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;
@SpringBootTest
class SpringbootTestApplicationTests {
@Test
void contextLoads() {
long startTime = System.currentTimeMillis();
List<Account> accountList = new ArrayList<>();
Random random = new Random();
for(int i = 1; i <= 1000000; i++){
Account account = new Account();
account.setAge(random.nextInt(100));
account.setBirthday(new Date());
account.setUserName("name" + i);
account.setId(i);
accountList.add(account);
}
ThreadUtil threadUtil = new ThreadUtil(1000);
threadUtil.syncThreadExecutor(accountList, InsertAccountTaskCallable.class);
System.out.println("Total time: " + (System.currentTimeMillis() - startTime));
}
}
Test results
2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert debug 135 : ==> Parameters: 999997(Integer), name999997(String), 88(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert debug 135 : ==> Parameters: 999998(Integer), name999998(String), 56(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert debug 135 : ==> Parameters: 999999(Integer), name999999(String), 15(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:32.256 DEBUG 45132 --- [ool-2-thread-10] c.e.s.b.t.mapper.AccountMapper.insert debug 135 : ==> Parameters: 1000000(Integer), name1000000(String), 8(Integer), 2023-08-30 21:18:32.074(Timestamp)
2023-08-30 21:39:39.407 INFO 45132 --- [ main] cn.com.wuhm.common.util.ThreadUtil doSyncThreadExecutor 112 : Thread main finished executing tasks
2023-08-30 21:39:39.417 INFO 45132 --- [ main] cn.com.wuhm.common.util.ThreadUtil doSyncThreadExecutor 113 : Task execution took: 1267342 ms
Total time: 1267643
2023-08-30 21:39:39.578 INFO 45132 --- [ionShutdownHook] com.zaxxer.hikari.HikariDataSource close 350 : HikariPool-1 - Shutdown initiated...
2023-08-30 21:39:39.605 INFO 45132 --- [ionShutdownHook] com.zaxxer.hikari.HikariDataSource close 352 : HikariPool-1 - Shutdown completed.
Process finished with exit code 0
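The database is a MySQL service started locally with Docker. Inserting 1 million rows took about 21 minutes in total; I suspect the limit is the performance of my own machine.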