TransmittableThreadLocal: How It Works and How to Use It
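TransmittableThreadLocal (TTL) improves on InheritableThreadLocal. InheritableThreadLocal copies values only when a thread is created, and pool threads are created once and then reused. TTL, by contrast, correctly passes the parent thread's thread-local values to the child thread even when tasks run on a thread pool. The parent is the thread that submits the task; the child is the pool thread that runs it. This solves the problem of propagating context during asynchronous execution.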
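The problem TTL solves can be reproduced with the JDK alone. The sketch below assumes no TTL dependency, and the class name InheritableStaleDemo is made up for illustration. It uses a one-thread pool, so the second task runs on the reused worker. That worker still sees the value it copied when it was created.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class InheritableStaleDemo {
    static final InheritableThreadLocal<Integer> CTX = new InheritableThreadLocal<>();

    /** Returns "first second": the values seen by two tasks submitted to the same one-thread pool. */
    static String run() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<Integer> readCtx = CTX::get;
        try {
            CTX.set(1);
            // The pool's only worker thread is created during this first submit and copies CTX = 1
            Integer first = pool.submit(readCtx).get();
            CTX.set(2);
            // The same worker is reused, so it still sees the value copied when it was created
            Integer second = pool.submit(readCtx).get();
            return first + " " + second;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            CTX.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "1 1"
    }
}
```

The second task reads 1 even though the submitting thread already holds 2. This is the gap that TTL fills.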
Source code: https://github.com/alibaba/transmittable-thread-local
Usage
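import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import org.junit.Test; // JUnit 4 assumed; use org.junit.jupiter.api.Test for JUnit 5

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TransmittableTest {
    private ThreadLocal<Integer> transmittableThreadLocalInteger;

    @Test
    public void test1() throws InterruptedException {
        transmittableThreadLocalInteger = new TransmittableThreadLocal<>();
        // Wrap a single-thread pool so every submitted task is decorated as a TtlRunnable
        ExecutorService executorService = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(1));
        // Task 1: submitted before any value is set
        executorService.execute(this::printMessage2);
        Thread.sleep(1000);
        transmittableThreadLocalInteger.set(111);
        Thread.sleep(1000);
        // Task 2: submitted after set(111); runs on the same reused pool thread
        executorService.execute(this::printMessage2);
        Thread.sleep(1000);
        // Read the value directly in the main thread
        printMessage2();
        executorService.shutdown();
    }

    private void printMessage2() {
        String name = Thread.currentThread().getName();
        System.out.println(name + " " + transmittableThreadLocalInteger.get());
    }
}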
Output:
pool-1-thread-1 null
pool-1-thread-1 111
main 111
Analysis
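1.Starting from executorService.execute(this::printMessage2);
The executor returned by TtlExecutors.getTtlExecutorService is a wrapper. Its execute method, com.alibaba.ttl.threadpool.ExecutorTtlWrapper#execute, wraps the command in a TtlRunnable before passing it to the real pool: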
@Override
public void execute(@NonNull Runnable command) {
    executor.execute(TtlRunnable.get(command, false, idempotent));
}
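The wrapper is a plain decorator. Below is a minimal JDK-only sketch of the same idea. The class DecoratingExecutorSketch and its single ThreadLocal CTX are hypothetical. Capture and restore are collapsed into one lambda instead of TTL's Transmittee machinery. CTX is a plain ThreadLocal, so without the wrapper the pool thread would see null.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical decorator in the spirit of ExecutorTtlWrapper: wrap every command, then delegate. */
public class DecoratingExecutorSketch implements Executor {
    static final ThreadLocal<String> CTX = new ThreadLocal<>();
    private final Executor delegate;

    DecoratingExecutorSketch(Executor delegate) {
        this.delegate = delegate;
    }

    @Override
    public void execute(Runnable command) {
        final String captured = CTX.get(); // runs in the submitting thread
        delegate.execute(() -> {
            String backup = CTX.get();     // runs in the pool thread
            CTX.set(captured);             // expose the submitter's value to the task
            try {
                command.run();
            } finally {
                if (backup == null) CTX.remove(); else CTX.set(backup); // put the pool thread's value back
            }
        });
    }

    /** Returns what a task sees when CTX is set only after the pool's worker thread already exists. */
    static String run() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Executor executor = new DecoratingExecutorSketch(pool);
        AtomicReference<String> seen = new AtomicReference<>();
        CountDownLatch done = new CountDownLatch(1);
        try {
            executor.execute(() -> { });   // creates the single worker thread
            CTX.set("set-after-worker-created");
            executor.execute(() -> {
                seen.set(CTX.get());
                done.countDown();
            });
            done.await();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            CTX.remove();
        }
        return seen.get();
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "set-after-worker-created"
    }
}
```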
2.TtlRunnable.get
@Nullable
@Contract(value = "null, _, _ -> null; !null, _, _ -> !null", pure = true)
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    if (runnable == null) return null;
    if (runnable instanceof TtlEnhanced) {
        // avoid redundant decoration, and ensure idempotency
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}
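The null and already-wrapped checks can be isolated in a small sketch. The marker interface Enhanced plays the role of TtlEnhanced, and Wrapped stands in for TtlRunnable. Both names are hypothetical.

```java
/** Hypothetical illustration of the null / already-wrapped handling in TtlRunnable.get. */
public class IdempotentWrapSketch {
    /** Marker interface, playing the role of TtlEnhanced. */
    interface Enhanced {}

    static final class Wrapped implements Runnable, Enhanced {
        final Runnable inner;

        Wrapped(Runnable inner) {
            this.inner = inner;
        }

        @Override
        public void run() {
            inner.run();
        }
    }

    static Runnable wrap(Runnable runnable, boolean idempotent) {
        if (runnable == null) return null;
        if (runnable instanceof Enhanced) {
            // Already decorated: reuse it, or fail loudly if double wrapping indicates a bug
            if (idempotent) return runnable;
            throw new IllegalStateException("Already wrapped!");
        }
        return new Wrapped(runnable);
    }

    public static void main(String[] args) {
        Runnable once = wrap(() -> { }, false);
        System.out.println(wrap(once, true) == once); // prints "true"
        try {
            wrap(once, false);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // prints "Already wrapped!"
        }
    }
}
```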
3.TtlRunnable
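TtlRunnable implements the Runnable interface. TtlRunnable.get calls the constructor TtlRunnable(runnable, releaseTtlValueReferenceAfterRun), which calls capture() immediately. The snapshot is therefore taken in the submitting thread at submission time: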
private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    this.capturedRef = new AtomicReference<>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}
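Because capture() runs in the constructor, each wrapper freezes the value present when it was created, not when it runs. The hypothetical sketch below reduces this to one plain ThreadLocal. It runs both wrappers on the same thread only to keep the output deterministic.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical sketch: the value is captured when the wrapper is constructed, not when it runs. */
public class CaptureAtConstructionSketch {
    static final ThreadLocal<Integer> CTX = new ThreadLocal<>();

    static final class SnapshotRunnable implements Runnable {
        private final Integer captured = CTX.get(); // field initializer runs as part of the constructor
        private final Runnable task;

        SnapshotRunnable(Runnable task) {
            this.task = task;
        }

        @Override
        public void run() {
            Integer backup = CTX.get();
            CTX.set(captured);
            try {
                task.run();
            } finally {
                CTX.set(backup);
            }
        }
    }

    static String run() {
        AtomicReference<Integer> seen = new AtomicReference<>();
        Runnable read = () -> seen.set(CTX.get());
        CTX.set(1);
        Runnable early = new SnapshotRunnable(read); // snapshot holds 1
        CTX.set(2);
        Runnable late = new SnapshotRunnable(read);  // snapshot holds 2
        early.run();
        Integer fromEarly = seen.get();
        late.run();
        Integer fromLate = seen.get();
        CTX.remove();
        return fromEarly + " " + fromLate;
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "1 2"
    }
}
```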
4.capture()
capture() captures the values of all TransmittableThreadLocals and all registered ThreadLocals in the current (parent) thread.
@NonNull
public static Object capture() {
    final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = newHashMap(transmitteeSet.size());
    for (Transmittee<Object, Object> transmittee : transmitteeSet) {
        try {
            transmittee2Value.put(transmittee, transmittee.capture());
        } catch (Throwable t) {
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when Transmitter.capture for transmittee " + transmittee +
                        "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }
    return new Snapshot(transmittee2Value);
}
The snapshot captured here is a HashMap<Transmittee<Object, Object>, Object>. Each key is a Transmittee, and each value is what that transmittee's capture() returned. The map is then wrapped in a Snapshot object.
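The outer loop can be modelled on its own. In the sketch below, Transmittee is reduced to a hypothetical single-method interface, and System.err stands in for the library's logger. It shows one snapshot entry per transmittee and shows that a failing transmittee is skipped rather than aborting the capture.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of Transmitter.capture(): one snapshot entry per transmittee, failures skipped. */
public class SnapshotSketch {
    /** Simplified stand-in for Transmittee: only the capture side is modelled here. */
    interface Transmittee {
        Object capture();
    }

    /** Mirrors the Snapshot class: just a holder for transmittee -> captured value. */
    static final class Snapshot {
        final Map<Transmittee, Object> transmittee2Value;

        Snapshot(Map<Transmittee, Object> transmittee2Value) {
            this.transmittee2Value = transmittee2Value;
        }
    }

    static Snapshot capture(List<Transmittee> transmittees) {
        Map<Transmittee, Object> transmittee2Value = new HashMap<>();
        for (Transmittee transmittee : transmittees) {
            try {
                transmittee2Value.put(transmittee, transmittee.capture());
            } catch (Throwable t) {
                // As in the library, a failing transmittee is reported and skipped, not fatal
                System.err.println("capture failed, ignored: " + t);
            }
        }
        return new Snapshot(transmittee2Value);
    }

    public static void main(String[] args) {
        Transmittee ttlLike = () -> Map.of("transmittableThreadLocalInteger", 111);
        Transmittee threadLocalLike = () -> Map.of(); // no plain ThreadLocal registered
        Transmittee broken = () -> { throw new IllegalStateException("boom"); };
        Snapshot snapshot = capture(List.of(ttlLike, threadLocalLike, broken));
        System.out.println(snapshot.transmittee2Value.size()); // prints "2"
    }
}
```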
The Transmittee interface has two built-in implementations:
The first: ttlTransmittee
private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
    new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
        @NonNull
        @Override
        public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
            final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = newHashMap(holder.get().size());
            for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                ttl2Value.put(threadLocal, threadLocal.copyValue());
            }
            return ttl2Value;
        }

        @NonNull
        @Override
        public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
            final HashMap<TransmittableThreadLocal<Object>, Object> backup = newHashMap(holder.get().size());
            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();
                // backup
                backup.put(threadLocal, threadLocal.get());
                // clear the TTL values that is not in captured
                // avoid the extra TTL values after replay when run task
                if (!captured.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }
            // set TTL values to captured
            setTtlValuesTo(captured);
            // call beforeExecute callback
            doExecuteCallback(true);
            return backup;
        }

        @NonNull
        @Override
        public HashMap<TransmittableThreadLocal<Object>, Object> clear() {
            return replay(newHashMap(0));
        }

        @Override
        public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
            // call afterExecute callback
            doExecuteCallback(false);
            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();
                // clear the TTL values that is not in backup
                // avoid the extra TTL values after restore
                if (!backup.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }
            // restore TTL values
            setTtlValuesTo(backup);
        }
    };
The second: threadLocalTransmittee
private static final Transmittee<HashMap<ThreadLocal<Object>, Object>, HashMap<ThreadLocal<Object>, Object>> threadLocalTransmittee =
    new Transmittee<HashMap<ThreadLocal<Object>, Object>, HashMap<ThreadLocal<Object>, Object>>() {
        @NonNull
        @Override
        public HashMap<ThreadLocal<Object>, Object> capture() {
            final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = newHashMap(threadLocalHolder.size());
            for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                final TtlCopier<Object> copier = entry.getValue();
                threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
            }
            return threadLocal2Value;
        }

        @NonNull
        @Override
        public HashMap<ThreadLocal<Object>, Object> replay(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {
            final HashMap<ThreadLocal<Object>, Object> backup = newHashMap(captured.size());
            for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                backup.put(threadLocal, threadLocal.get());
                final Object value = entry.getValue();
                if (value == threadLocalClearMark) threadLocal.remove();
                else threadLocal.set(value);
            }
            return backup;
        }

        @NonNull
        @Override
        public HashMap<ThreadLocal<Object>, Object> clear() {
            final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = newHashMap(threadLocalHolder.size());
            for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                threadLocal2Value.put(threadLocal, threadLocalClearMark);
            }
            return replay(threadLocal2Value);
        }

        @Override
        public void restore(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {
            for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                threadLocal.set(entry.getValue());
            }
        }
    };
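So, by default, transmitteeSet contains exactly these two objects, and its size is 2.
ttlTransmittee
- Purpose: stores and transfers the state of TransmittableThreadLocal variables.
- Behavior: when a task moves from one thread to another (for example, through a thread pool), ttlTransmittee captures the state of the TransmittableThreadLocal variables tracked in the current thread's holder. It then restores that state in the target thread.
threadLocalTransmittee
- Purpose: stores and transfers the state of plain ThreadLocal variables.
- Behavior: similar to ttlTransmittee, but for plain ThreadLocal variables. It does not capture every ThreadLocal in the thread. It only captures the ThreadLocals registered in threadLocalHolder, each with its TtlCopier, and it restores their state in the target thread.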
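The plain-ThreadLocal side can be sketched with the JDK alone. REGISTRY stands in for threadLocalHolder, UnaryOperator stands in for TtlCopier, and CLEAR_MARK stands in for threadLocalClearMark. All three names are hypothetical, and the library's own registration methods are not modelled. The demo captures in the calling thread, replays in a fresh thread, and restores afterwards.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.UnaryOperator;

/** Hypothetical sketch of threadLocalTransmittee: only explicitly registered ThreadLocals are transmitted. */
public class PlainThreadLocalTransmitteeSketch {
    /** Sentinel meaning "remove this ThreadLocal during replay", like threadLocalClearMark. */
    static final Object CLEAR_MARK = new Object();
    /** Registered ThreadLocals and the copier applied at capture time, like threadLocalHolder. */
    static final Map<ThreadLocal<Object>, UnaryOperator<Object>> REGISTRY = new ConcurrentHashMap<>();

    @SuppressWarnings("unchecked")
    static void register(ThreadLocal<?> threadLocal, UnaryOperator<Object> copier) {
        REGISTRY.put((ThreadLocal<Object>) threadLocal, copier);
    }

    static Map<ThreadLocal<Object>, Object> capture() {
        Map<ThreadLocal<Object>, Object> captured = new HashMap<>();
        REGISTRY.forEach((tl, copier) -> captured.put(tl, copier.apply(tl.get())));
        return captured;
    }

    static Map<ThreadLocal<Object>, Object> replay(Map<ThreadLocal<Object>, Object> captured) {
        Map<ThreadLocal<Object>, Object> backup = new HashMap<>();
        captured.forEach((tl, value) -> {
            backup.put(tl, tl.get());
            if (value == CLEAR_MARK) tl.remove(); else tl.set(value);
        });
        return backup;
    }

    /** Like clear(): replay a map that marks every registered ThreadLocal for removal. */
    static Map<ThreadLocal<Object>, Object> clear() {
        Map<ThreadLocal<Object>, Object> marks = new HashMap<>();
        REGISTRY.keySet().forEach(tl -> marks.put(tl, CLEAR_MARK));
        return replay(marks);
    }

    static void restore(Map<ThreadLocal<Object>, Object> backup) {
        backup.forEach((tl, value) -> tl.set(value));
    }

    /** Captures in the calling thread, replays in a fresh thread, and reports what that thread saw. */
    static String run() {
        ThreadLocal<String> user = new ThreadLocal<>();
        register(user, value -> value); // identity copier: the same reference is passed on
        user.set("alice");
        Map<ThreadLocal<Object>, Object> snapshot = capture();
        String[] seen = new String[2];
        Thread worker = new Thread(() -> {
            Map<ThreadLocal<Object>, Object> backup = replay(snapshot);
            try {
                seen[0] = user.get(); // value from the snapshot
            } finally {
                restore(backup);
            }
            seen[1] = user.get();     // back to the worker's own (empty) value
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        user.remove();
        REGISTRY.remove(user);
        return seen[0] + " " + seen[1];
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "alice null"
    }
}
```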
Skip transmittee.capture() for now and continue with Snapshot(transmittee2Value).
5.Snapshot
private static class Snapshot {
    final HashMap<Transmittee<Object, Object>, Object> transmittee2Value;

    public Snapshot(HashMap<Transmittee<Object, Object>, Object> transmittee2Value) {
        this.transmittee2Value = transmittee2Value;
    }
}
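Back to step 3: the TtlRunnable has now been constructed. Through step 4, the snapshot of the parent thread's ThreadLocal and TransmittableThreadLocal values has been saved in its capturedRef field.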
6.The get method
Back in the test class, look at the get method:
com.alibaba.ttl.TransmittableThreadLocal#get
@Override
public final T get() {
    T value = super.get();
    if (disableIgnoreNullValueSemantics || value != null) addThisToHolder();
    return value;
}
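It simply calls super.get() to read the value. No value has been set yet, so it returns null. As a result, addThisToHolder is not called, unless disableIgnoreNullValueSemantics is enabled.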
7.The set method
Execution now reaches the line transmittableThreadLocalInteger.set(111);
@Override
public final void set(T value) {
    if (!disableIgnoreNullValueSemantics && value == null) {
        // may set null to remove value
        remove();
    } else {
        super.set(value);
        addThisToHolder();
    }
}
It calls super.set(value), i.e. java.lang.ThreadLocal#set, and then calls addThisToHolder().
Here is the addThisToHolder method:
@SuppressWarnings("unchecked")
private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
    }
}
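The holder field is declared static, but its value is thread-local. holder is an InheritableThreadLocal, so each thread has its own WeakHashMap.
When a value is set, addThisToHolder adds the TransmittableThreadLocal itself as a key in that WeakHashMap, with null as the value. The map is effectively used as a weak set.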
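The bookkeeping in steps 6 and 7 can be imitated in a small class. HolderSketch is hypothetical and is not the library class. It copies the parent's map in childValue, which is an assumption made here for illustration. The demo counts registrations in a fresh thread and then in the calling thread. It shows that a get() returning null registers nothing, and that the "static" holder is really per-thread.

```java
import java.util.WeakHashMap;

/** Hypothetical mini version of the holder bookkeeping in TransmittableThreadLocal (not the library class). */
public class HolderSketch<T> extends InheritableThreadLocal<T> {
    /** Static field, per-thread value: every thread has its own WeakHashMap, used as a weak set. */
    private static final InheritableThreadLocal<WeakHashMap<HolderSketch<?>, Object>> holder =
        new InheritableThreadLocal<WeakHashMap<HolderSketch<?>, Object>>() {
            @Override
            protected WeakHashMap<HolderSketch<?>, Object> initialValue() {
                return new WeakHashMap<>();
            }

            @Override
            protected WeakHashMap<HolderSketch<?>, Object> childValue(WeakHashMap<HolderSketch<?>, Object> parentValue) {
                return new WeakHashMap<>(parentValue); // a child thread starts with a copy of the parent's set
            }
        };

    @Override
    public T get() {
        T value = super.get();
        if (value != null) holder.get().put(this, null); // a null value does not register
        return value;
    }

    @Override
    public void set(T value) {
        if (value == null) {
            remove(); // setting null is treated as removal
        } else {
            super.set(value);
            holder.get().put(this, null); // WeakHashMap accepts null values
        }
    }

    @Override
    public void remove() {
        holder.get().remove(this);
        super.remove();
    }

    static int registeredInCurrentThread() {
        return holder.get().size();
    }

    /** Counts: fresh thread, after get() on an unset variable, after two sets, then in the calling thread. */
    static String run() {
        HolderSketch<Integer> a = new HolderSketch<>();
        HolderSketch<Integer> b = new HolderSketch<>();
        int[] counts = new int[3];
        Thread t = new Thread(() -> {
            counts[0] = registeredInCurrentThread();
            a.get();
            counts[1] = registeredInCurrentThread();
            a.set(111);
            b.set(222);
            counts[2] = registeredInCurrentThread();
        });
        t.start();
        try {
            t.join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        return counts[0] + " " + counts[1] + " " + counts[2] + " " + registeredInCurrentThread();
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "0 0 2 0"
    }
}
```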
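So far there is no sign of where values or references are passed from the parent thread to the child thread. Go back to step 3: the TtlRunnable constructor fills the capturedRef field. Following the usages of capturedRef leads to the run method.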
8.The run method
@Override
public void run() {
    final Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }
    final Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        restore(backup);
    }
}
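The guard at the top of run() only matters when releaseTtlValueReferenceAfterRun is true. The test above passes false through ExecutorTtlWrapper. In that mode the first run swaps capturedRef to null with a compare-and-set, so any later run fails fast. As the flag's name suggests, this also drops the wrapper's reference to the snapshot. The sketch below is hypothetical and isolates just that check. replay and restore are omitted, as noted in the code comment.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical sketch of the releaseTtlValueReferenceAfterRun check in TtlRunnable.run(). */
public class ReleaseAfterRunSketch implements Runnable {
    private final AtomicReference<Object> capturedRef;
    private final boolean releaseAfterRun;
    private final Runnable task;

    ReleaseAfterRunSketch(Object snapshot, Runnable task, boolean releaseAfterRun) {
        this.capturedRef = new AtomicReference<>(snapshot);
        this.task = task;
        this.releaseAfterRun = releaseAfterRun;
    }

    @Override
    public void run() {
        Object captured = capturedRef.get();
        // With releaseAfterRun, the first run atomically swaps the snapshot to null,
        // so a second run sees null (or a concurrent run loses the CAS) and throws.
        if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // In the real class, replay(captured) before and restore(backup) after wrap this call
        task.run();
    }

    static boolean secondRunFails(boolean releaseAfterRun) {
        Runnable r = new ReleaseAfterRunSketch(new Object(), () -> { }, releaseAfterRun);
        r.run();
        try {
            r.run();
            return false;
        } catch (IllegalStateException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(secondRunFails(true));  // prints "true"
        System.out.println(secondRunFails(false)); // prints "false"
    }
}
```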
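run() calls replay(captured), which dispatches to each Transmittee in the snapshot. For TransmittableThreadLocal values, the work is done by ttlTransmittee's implementation of com.alibaba.ttl.TransmittableThreadLocal.Transmitter.Transmittee#replay: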
@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
    final HashMap<TransmittableThreadLocal<Object>, Object> backup = newHashMap(holder.get().size());
    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();
        // backup
        backup.put(threadLocal, threadLocal.get());
        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }
    // set TTL values to captured
    setTtlValuesTo(captured);
    // call beforeExecute callback
    doExecuteCallback(true);
    return backup;
}
This needs to be read together with transmittee.capture(), which was skipped in step 4:
@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
    final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = newHashMap(holder.get().size());
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}
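This test registers no plain ThreadLocal, so threadLocalTransmittee captures an empty map and only the ttlTransmittee implementation matters. It iterates over holder and records each TransmittableThreadLocal as a key, with its copyValue() as the value. By default copyValue() returns the same reference as the current value.
This shows when the snapshot is taken: when the TtlRunnable is created, i.e. when the task is submitted, not when the pool thread is created. At that moment it records the references and values of all TransmittableThreadLocal variables in the parent thread. Later, when the pool thread calls run(), setTtlValuesTo(captured) installs those values in the current (child) thread. After the task finishes, restore(backup) puts the child thread's original values back.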
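The backup/replay/restore cycle can be modelled without TTL. In the sketch below, a plain per-thread map of name to value (the hypothetical VALUES) stands in for holder plus the TTL values. The doExecuteCallback hooks are left out. The worker starts with a leftover value of its own. During the task it sees only the captured value, and afterwards it gets its original value back, without the value the task added.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical model of ttlTransmittee.replay/restore; a per-thread map stands in for holder + values. */
public class ReplayRestoreSketch {
    /** Current thread's "TTL" values: key = variable name, value = its value. */
    static final ThreadLocal<Map<String, Object>> VALUES = ThreadLocal.withInitial(HashMap::new);

    static Map<String, Object> capture() {
        return new HashMap<>(VALUES.get()); // shallow copy, like the default copyValue()
    }

    static Map<String, Object> replay(Map<String, Object> captured) {
        Map<String, Object> backup = new HashMap<>(VALUES.get()); // back up the worker's own values
        VALUES.get().keySet().retainAll(captured.keySet());        // drop values the submitter did not have
        VALUES.get().putAll(captured);                             // install the captured values
        return backup;
    }

    static void restore(Map<String, Object> backup) {
        VALUES.get().keySet().retainAll(backup.keySet());          // drop values added during the task
        VALUES.get().putAll(backup);                               // put the worker's original values back
    }

    static String run() {
        VALUES.get().put("traceId", "t-1"); // set in the submitting thread
        Map<String, Object> snapshot = capture();
        StringBuilder log = new StringBuilder();
        Thread worker = new Thread(() -> {
            VALUES.get().put("workerOnly", "w");   // leftover from an earlier task on this thread
            Map<String, Object> backup = replay(snapshot);
            try {
                log.append(VALUES.get()).append(" | ");
                VALUES.get().put("tmp", 1);         // the task adds a value of its own
            } finally {
                restore(backup);
            }
            log.append(VALUES.get());
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        VALUES.remove();
        return log.toString();
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints "{traceId=t-1} | {workerOnly=w}"
    }
}
```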
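On the second submission, a new TtlRunnable is created and takes a new snapshot at that moment. By then the main thread has already set transmittableThreadLocalInteger to 111. When the child thread runs this TtlRunnable, the captured snapshot is installed in it, so the task reads the updated value 111.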