TransmittableThreadLocal: How It Works and How to Use It

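Compared with InheritableThreadLocal, the improvement TransmittableThreadLocal (TTL) makes is that it correctly passes the parent thread's thread-local values to the child thread even in thread-pool mode, where pooled threads are created once and reused. This solves the problem of propagating context during asynchronous execution.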
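To see the baseline that TTL improves on, here is a minimal JDK-only sketch with plain InheritableThreadLocal (the class name InheritableInPoolDemo is hypothetical). A pooled thread copies inheritable values only once, when it is created, so a value set later in the submitting thread never reaches tasks that run on that already-existing thread.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class InheritableInPoolDemo {
    static final InheritableThreadLocal<Integer> CONTEXT = new InheritableThreadLocal<>();

    // Runs two tasks on a one-thread pool and returns what each task saw.
    static Integer[] run() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<Integer> readContext = CONTEXT::get;
        try {
            // The pool thread is created during this first submit, while CONTEXT is still
            // unset in the submitting thread, so it inherits nothing.
            Integer first = pool.submit(readContext).get();

            // Set AFTER the pool thread exists: inheritance already happened at thread creation.
            CONTEXT.set(111);
            Integer second = pool.submit(readContext).get();
            return new Integer[] {first, second};
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            CONTEXT.remove();
        }
    }

    public static void main(String[] args) {
        Integer[] seen = run();
        System.out.println("first=" + seen[0] + ", second=" + seen[1]);  // first=null, second=null
    }
}
```

Both tasks see null here, whereas the TTL test in this article prints 111 for the second task.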

Source code: https://github.com/alibaba/transmittable-thread-local

Usage

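import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.threadpool.TtlExecutors;
import org.junit.Test; // JUnit 4 assumed; with JUnit 5 use org.junit.jupiter.api.Test

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TransmittableTest {

    private ThreadLocal<Integer> transmittableThreadLocalInteger;

    @Test
    public void test1() throws InterruptedException {
        transmittableThreadLocalInteger = new TransmittableThreadLocal<>();
        // Wrap a one-thread pool so that every submitted task is decorated by TTL
        ExecutorService executorService = TtlExecutors.getTtlExecutorService(Executors.newFixedThreadPool(1));
        // 1st task: nothing has been set yet, so the pool thread prints null
        executorService.execute(this::printMessage2);
        Thread.sleep(1000);
        // Set the value in main AFTER the pool thread already exists
        transmittableThreadLocalInteger.set(111);
        Thread.sleep(1000);
        // 2nd task: runs on the same pool thread, yet prints 111
        executorService.execute(this::printMessage2);
        Thread.sleep(1000);
        // The main thread reads its own value
        printMessage2();
        executorService.shutdown();
    }

    private void printMessage2() {
        String name = Thread.currentThread().getName();
        System.out.println(name + " " + transmittableThreadLocalInteger.get());
    }
}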

Output:

pool-1-thread-1 null
pool-1-thread-1 111
main 111

Analysis

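1.execute

Start from the call executorService.execute(this::printMessage2). Because the pool was wrapped with TtlExecutors.getTtlExecutorService, the call first reaches the wrapper, which decorates the task as a TtlRunnable before handing it to the real pool: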

com.alibaba.ttl.threadpool.ExecutorTtlWrapper#execute

@Override
public void execute(@NonNull Runnable command) {
    executor.execute(TtlRunnable.get(command, false, idempotent));
}
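The wrapper is a plain decorator. A minimal sketch of the same pattern, using a hypothetical WrappingExecutor class rather than TTL's own, shows the point that matters for TTL: the decoration (and therefore the capture) happens in the submitting thread, while the decorated task later runs in the pool thread.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.UnaryOperator;

public class WrappingExecutorDemo {
    // Hypothetical stand-in for ExecutorTtlWrapper: decorates every command, then delegates.
    static final class WrappingExecutor implements Executor {
        private final Executor delegate;
        private final UnaryOperator<Runnable> decorator;

        WrappingExecutor(Executor delegate, UnaryOperator<Runnable> decorator) {
            this.delegate = delegate;
            this.decorator = decorator;
        }

        @Override
        public void execute(Runnable command) {
            // decorator.apply runs here, in the submitting thread, before the pool sees the task.
            delegate.execute(decorator.apply(command));
        }
    }

    // Returns {thread that decorated the task, thread that ran it}.
    static String[] run() {
        String[] names = new String[2];
        CountDownLatch done = new CountDownLatch(1);
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Executor wrapped = new WrappingExecutor(pool, task -> {
            names[0] = Thread.currentThread().getName();
            return task;
        });
        wrapped.execute(() -> {
            names[1] = Thread.currentThread().getName();
            done.countDown();
        });
        try {
            done.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
        return names;
    }

    public static void main(String[] args) {
        String[] names = run();
        System.out.println("decorated in " + names[0] + ", ran in " + names[1]);
    }
}
```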

2.TtlRunnable.get

@Nullable
@Contract(value = "null, _, _ -> null; !null, _, _ -> !null", pure = true)
public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    if (runnable == null) return null;

    if (runnable instanceof TtlEnhanced) {
        // avoid redundant decoration, and ensure idempotency
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}
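The null / already-decorated / wrap decision in TtlRunnable.get can be sketched without TTL. Enhanced and WrappedRunnable below are hypothetical stand-ins for TtlEnhanced and TtlRunnable, and the sketch leaves out the releaseTtlValueReferenceAfterRun flag.

```java
public class IdempotentWrapDemo {
    // Hypothetical marker interface playing the role of TtlEnhanced.
    interface Enhanced {}

    static final class WrappedRunnable implements Runnable, Enhanced {
        final Runnable inner;

        WrappedRunnable(Runnable inner) { this.inner = inner; }

        @Override
        public void run() { inner.run(); }
    }

    // Same shape as TtlRunnable.get: null in, null out; never decorate twice.
    static Runnable wrap(Runnable runnable, boolean idempotent) {
        if (runnable == null) return null;
        if (runnable instanceof Enhanced) {
            if (idempotent) return runnable;  // already decorated: reuse as-is
            throw new IllegalStateException("Already wrapped!");
        }
        return new WrappedRunnable(runnable);
    }

    public static void main(String[] args) {
        Runnable once = wrap(() -> System.out.println("hi"), false);
        System.out.println(wrap(once, true) == once);  // true: not wrapped again
        try {
            wrap(once, false);
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());  // rejected: Already wrapped!
        }
    }
}
```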

3.TtlRunnable

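TtlRunnable implements the Runnable interface. Its constructor, TtlRunnable(runnable, releaseTtlValueReferenceAfterRun), takes the snapshot right away, in the submitting (parent) thread: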

private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    this.capturedRef = new AtomicReference<>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}
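Because capture() is called inside the constructor, the snapshot reflects the parent's values at the moment the task is wrapped, not when it later runs. A minimal single-threaded sketch with one plain ThreadLocal, where SnapshotTask is a hypothetical analogue of TtlRunnable:

```java
import java.util.concurrent.atomic.AtomicReference;

public class CaptureAtConstruction {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Hypothetical analogue of TtlRunnable: the snapshot is taken in the constructor.
    static final class SnapshotTask implements Runnable {
        private final AtomicReference<String> capturedRef;  // mirrors TtlRunnable.capturedRef
        final AtomicReference<String> seen = new AtomicReference<>();

        SnapshotTask() {
            // Runs in the creating (parent) thread.
            this.capturedRef = new AtomicReference<>(CONTEXT.get());
        }

        @Override
        public void run() {
            // Uses the value as it was at construction time, not the current one.
            seen.set(capturedRef.get());
        }
    }

    public static void main(String[] args) {
        CONTEXT.set("before");
        SnapshotTask task = new SnapshotTask();
        CONTEXT.set("after");  // changed after the task was built
        task.run();
        System.out.println(task.seen.get());  // before
    }
}
```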

4.capture()

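capture() (TransmittableThreadLocal.Transmitter#capture) captures the values of all TransmittableThreadLocal instances and all registered ThreadLocal instances in the current (parent) thread.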

@NonNull
public static Object capture() {
    final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = newHashMap(transmitteeSet.size());
    for (Transmittee<Object, Object> transmittee : transmitteeSet) {
        try {
            transmittee2Value.put(transmittee, transmittee.capture());
        } catch (Throwable t) {
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "exception when Transmitter.capture for transmittee " + transmittee +
                        "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
            }
        }
    }
    return new Snapshot(transmittee2Value);
}

The snapshot is a HashMap<Transmittee<Object, Object>, Object>: each key is a Transmittee, and each value is whatever that transmittee's capture() returned.
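The structure of capture() can be mirrored with a hypothetical, capture-only Transmittee interface: every registered transmittee contributes one map entry, and one that throws is skipped instead of aborting the whole capture. The names are illustrative, not TTL's API; TTL catches Throwable and logs through java.util.logging, whereas this sketch catches RuntimeException and prints to stderr.

```java
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

public class SnapshotDemo {
    // Hypothetical, capture-only version of the Transmittee idea.
    interface Transmittee {
        Object capture();
    }

    // Registered transmittees (the role of transmitteeSet).
    static final Set<Transmittee> TRANSMITTEES = new LinkedHashSet<>();

    // Snapshot: transmittee -> whatever that transmittee captured.
    static Map<Transmittee, Object> captureAll() {
        Map<Transmittee, Object> snapshot = new HashMap<>();
        for (Transmittee transmittee : TRANSMITTEES) {
            try {
                snapshot.put(transmittee, transmittee.capture());
            } catch (RuntimeException e) {
                // A failing transmittee is reported and skipped; the others are still captured.
                System.err.println("capture failed, ignored: " + e);
            }
        }
        return snapshot;
    }

    public static void main(String[] args) {
        Transmittee ttlValues = () -> "ttl values";
        Transmittee broken = () -> { throw new IllegalStateException("boom"); };
        TRANSMITTEES.add(ttlValues);
        TRANSMITTEES.add(broken);

        Map<Transmittee, Object> snapshot = captureAll();
        System.out.println(snapshot.size() + " " + snapshot.get(ttlValues));  // 1 ttl values
    }
}
```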

The Transmittee interface has two built-in implementations.

The first: ttlTransmittee

private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
                new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
                    @NonNull
                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
                        final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = newHashMap(holder.get().size());
                        for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                            ttl2Value.put(threadLocal, threadLocal.copyValue());
                        }
                        return ttl2Value;
                    }

                    @NonNull
                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
                        final HashMap<TransmittableThreadLocal<Object>, Object> backup = newHashMap(holder.get().size());

                        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                            TransmittableThreadLocal<Object> threadLocal = iterator.next();

                            // backup
                            backup.put(threadLocal, threadLocal.get());

                            // clear the TTL values that is not in captured
                            // avoid the extra TTL values after replay when run task
                            if (!captured.containsKey(threadLocal)) {
                                iterator.remove();
                                threadLocal.superRemove();
                            }
                        }

                        // set TTL values to captured
                        setTtlValuesTo(captured);

                        // call beforeExecute callback
                        doExecuteCallback(true);

                        return backup;
                    }

                    @NonNull
                    @Override
                    public HashMap<TransmittableThreadLocal<Object>, Object> clear() {
                        return replay(newHashMap(0));
                    }

                    @Override
                    public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
                        // call afterExecute callback
                        doExecuteCallback(false);

                        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                            TransmittableThreadLocal<Object> threadLocal = iterator.next();

                            // clear the TTL values that is not in backup
                            // avoid the extra TTL values after restore
                            if (!backup.containsKey(threadLocal)) {
                                iterator.remove();
                                threadLocal.superRemove();
                            }
                        }

                        // restore TTL values
                        setTtlValuesTo(backup);
                    }
                };

The second: threadLocalTransmittee

private static final Transmittee<HashMap<ThreadLocal<Object>, Object>, HashMap<ThreadLocal<Object>, Object>> threadLocalTransmittee =
                new Transmittee<HashMap<ThreadLocal<Object>, Object>, HashMap<ThreadLocal<Object>, Object>>() {
                    @NonNull
                    @Override
                    public HashMap<ThreadLocal<Object>, Object> capture() {
                        final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = newHashMap(threadLocalHolder.size());
                        for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
                            final ThreadLocal<Object> threadLocal = entry.getKey();
                            final TtlCopier<Object> copier = entry.getValue();

                            threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
                        }
                        return threadLocal2Value;
                    }

                    @NonNull
                    @Override
                    public HashMap<ThreadLocal<Object>, Object> replay(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {
                        final HashMap<ThreadLocal<Object>, Object> backup = newHashMap(captured.size());

                        for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
                            final ThreadLocal<Object> threadLocal = entry.getKey();
                            backup.put(threadLocal, threadLocal.get());

                            final Object value = entry.getValue();
                            if (value == threadLocalClearMark) threadLocal.remove();
                            else threadLocal.set(value);
                        }

                        return backup;
                    }

                    @NonNull
                    @Override
                    public HashMap<ThreadLocal<Object>, Object> clear() {
                        final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = newHashMap(threadLocalHolder.size());

                        for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
                            final ThreadLocal<Object> threadLocal = entry.getKey();
                            threadLocal2Value.put(threadLocal, threadLocalClearMark);
                        }

                        return replay(threadLocal2Value);
                    }

                    @Override
                    public void restore(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {
                        for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
                            final ThreadLocal<Object> threadLocal = entry.getKey();
                            threadLocal.set(entry.getValue());
                        }
                    }
                };

So transmitteeSet contains these two objects; with only the built-in transmittees registered, transmitteeSet.size() is 2.

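ttlTransmittee

  • Purpose: captures and transfers the state of TransmittableThreadLocal variables.
  • How it works: when a task is handed from one thread to another (for example, through a thread pool), ttlTransmittee captures the TransmittableThreadLocal values recorded in the current thread's holder, replays them in the target thread before the task runs, and afterwards restores the target thread's own values.

threadLocalTransmittee

  • Purpose: captures and transfers the state of plain ThreadLocal variables.
  • How it works: the same as ttlTransmittee, but for plain ThreadLocal variables. It only handles ThreadLocals that have been explicitly registered (the threadLocalHolder map), and copies each value with the TtlCopier registered for it.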
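The threadLocalTransmittee logic (copy on capture, set or remove on replay, put back on restore) can be sketched with plain JDK types. REGISTERED, CLEAR_MARK and the three static methods are hypothetical stand-ins for threadLocalHolder, threadLocalClearMark and the transmittee methods; in TTL itself a ThreadLocal is added to threadLocalHolder through Transmitter.registerThreadLocal.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.UnaryOperator;

public class RegisteredThreadLocalDemo {
    // Sentinel meaning "remove this ThreadLocal" during replay (the role of threadLocalClearMark).
    static final Object CLEAR_MARK = new Object();

    // Explicitly registered plain ThreadLocals with their copiers (the role of threadLocalHolder).
    static final Map<ThreadLocal<Object>, UnaryOperator<Object>> REGISTERED = new LinkedHashMap<>();

    static Map<ThreadLocal<Object>, Object> capture() {
        Map<ThreadLocal<Object>, Object> captured = new HashMap<>();
        REGISTERED.forEach((tl, copier) -> captured.put(tl, copier.apply(tl.get())));
        return captured;
    }

    // Installs captured values in the current thread and returns the previous ones.
    static Map<ThreadLocal<Object>, Object> replay(Map<ThreadLocal<Object>, Object> captured) {
        Map<ThreadLocal<Object>, Object> backup = new HashMap<>();
        captured.forEach((tl, value) -> {
            backup.put(tl, tl.get());
            if (value == CLEAR_MARK) tl.remove();
            else tl.set(value);
        });
        return backup;
    }

    static void restore(Map<ThreadLocal<Object>, Object> backup) {
        backup.forEach((tl, value) -> tl.set(value));
    }

    // Returns {value seen during the task, value after restore}, both read on a worker thread.
    static Object[] demo() {
        ThreadLocal<Object> requestId = new ThreadLocal<>();
        REGISTERED.put(requestId, value -> value);  // identity copier: shares the reference
        requestId.set("request-42");
        Map<ThreadLocal<Object>, Object> captured = capture();  // in the parent thread

        Object[] seen = new Object[2];
        Thread worker = new Thread(() -> {
            requestId.set("worker-own");
            Map<ThreadLocal<Object>, Object> backup = replay(captured);
            try {
                seen[0] = requestId.get();
            } finally {
                restore(backup);
            }
            seen[1] = requestId.get();
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            REGISTERED.remove(requestId);
        }
        return seen;
    }

    public static void main(String[] args) {
        Object[] seen = demo();
        System.out.println(seen[0] + " / " + seen[1]);  // request-42 / worker-own
    }
}
```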

/images/java/21-1.png

Set transmittee.capture() aside for now and continue with Snapshot(transmittee2Value).

5.Snapshot

private static class Snapshot {
    final HashMap<Transmittee<Object, Object>, Object> transmittee2Value;

    public Snapshot(HashMap<Transmittee<Object, Object>, Object> transmittee2Value) {
        this.transmittee2Value = transmittee2Value;
    }
}

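At this point we are back in step 3: the TtlRunnable has been constructed, and step 4 has stored a snapshot of the parent thread's TransmittableThreadLocal and registered ThreadLocal values in its capturedRef field.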

/images/java/21-2.png

6.The get method

Back in the test class, printMessage2 calls get():

/images/java/21-3.png

com.alibaba.ttl.TransmittableThreadLocal#get

@Override
public final T get() {
    T value = super.get();
    if (disableIgnoreNullValueSemantics || value != null) addThisToHolder();
    return value;
}

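It simply calls super.get() to read the value. No value has been set yet at this point, so it returns null, and because null values are ignored by default (disableIgnoreNullValueSemantics is false), addThisToHolder is not called.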

7.The set method

Execution now reaches the line transmittableThreadLocalInteger.set(111).

@Override
public final void set(T value) {
    if (!disableIgnoreNullValueSemantics && value == null) {
        // may set null to remove value
        remove();
    } else {
        super.set(value);
        addThisToHolder();
    }
}

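Unless null-value semantics are disabled, setting null simply removes the value. Otherwise the value is stored with super.set(value), which resolves to java.lang.ThreadLocal#set, and then addThisToHolder() is called: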

@SuppressWarnings("unchecked")
private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
    }
}

/images/java/21-4.png

Although holder is a static field, the map it returns is not shared globally: each thread gets its own map, because holder is itself an InheritableThreadLocal.

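So when a value is set, this TransmittableThreadLocal instance is added as a key to the current thread's WeakHashMap, with null as the value; the map effectively serves as a weak set of the TTL instances in use by that thread.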
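The get/set/holder bookkeeping can be sketched as a small hypothetical subclass, TrackedLocal. Two deliberate simplifications: it uses Collections.newSetFromMap over a WeakHashMap instead of a WeakHashMap with null values (the same weak-set idea), and it has no switch corresponding to disableIgnoreNullValueSemantics.

```java
import java.util.Collections;
import java.util.Set;
import java.util.WeakHashMap;

public class HolderDemo {
    // Hypothetical mini version of TransmittableThreadLocal's get/set/holder bookkeeping.
    static class TrackedLocal<T> extends InheritableThreadLocal<T> {
        // The static field is shared, but get() on it returns the current thread's own set.
        // Weak keys: an instance nobody references any more can drop out of the set.
        static final InheritableThreadLocal<Set<TrackedLocal<?>>> HOLDER =
                new InheritableThreadLocal<Set<TrackedLocal<?>>>() {
                    @Override
                    protected Set<TrackedLocal<?>> initialValue() {
                        return Collections.newSetFromMap(new WeakHashMap<>());
                    }

                    @Override
                    protected Set<TrackedLocal<?>> childValue(Set<TrackedLocal<?>> parentValue) {
                        // A new child thread starts with its own copy, not the parent's set.
                        Set<TrackedLocal<?>> copy = Collections.newSetFromMap(new WeakHashMap<>());
                        copy.addAll(parentValue);
                        return copy;
                    }
                };

        @Override
        public T get() {
            T value = super.get();
            if (value != null) HOLDER.get().add(this);  // only non-null values are tracked
            return value;
        }

        @Override
        public void set(T value) {
            if (value == null) {
                remove();  // set(null) behaves like remove()
            } else {
                super.set(value);
                HOLDER.get().add(this);
            }
        }

        @Override
        public void remove() {
            HOLDER.get().remove(this);
            super.remove();
        }
    }

    public static void main(String[] args) {
        TrackedLocal<Integer> local = new TrackedLocal<>();
        local.get();
        System.out.println(TrackedLocal.HOLDER.get().contains(local));  // false
        local.set(111);
        System.out.println(TrackedLocal.HOLDER.get().contains(local));  // true
        local.set(null);
        System.out.println(TrackedLocal.HOLDER.get().contains(local));  // false
    }
}
```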

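So far, nothing we have seen actually passes a value or reference from the parent thread to the child thread. Going back to step 3, the TtlRunnable constructor initialised the capturedRef field, so let us look at where that field is used.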

/images/java/21-5.png

Following that usage leads to the run method.

8.The run method

@Override
public void run() {
    final Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }

    final Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        restore(backup);
    }
}
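The core of run() is replay, run, then restore in a finally block, plus a one-shot guard when releaseTtlValueReferenceAfterRun is true. Below is a hypothetical single-ThreadLocal sketch (ReplayingTask and Snapshot are illustrative names). For determinism it runs on one thread, and the second CONTEXT.set stands in for the value the pool thread already had.

```java
import java.util.concurrent.atomic.AtomicReference;

public class ReplayRestoreDemo {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Wrapper so that a captured null is still a non-null snapshot object.
    static final class Snapshot {
        final String value;

        Snapshot(String value) { this.value = value; }
    }

    // Hypothetical analogue of TtlRunnable for a single ThreadLocal.
    static final class ReplayingTask implements Runnable {
        private final AtomicReference<Snapshot> capturedRef;
        private final Runnable runnable;
        private final boolean releaseAfterRun;

        ReplayingTask(Runnable runnable, boolean releaseAfterRun) {
            this.capturedRef = new AtomicReference<>(new Snapshot(CONTEXT.get()));
            this.runnable = runnable;
            this.releaseAfterRun = releaseAfterRun;
        }

        @Override
        public void run() {
            Snapshot captured = capturedRef.get();
            // With releaseAfterRun, the first run clears the reference, so a second run fails.
            if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
                throw new IllegalStateException("captured value already released");
            }
            String backup = CONTEXT.get();   // replay: remember this thread's own value...
            setOrRemove(captured.value);     // ...then install the parent's
            try {
                runnable.run();
            } finally {
                setOrRemove(backup);         // restore: put this thread's value back
            }
        }

        private static void setOrRemove(String value) {
            if (value == null) CONTEXT.remove(); else CONTEXT.set(value);
        }
    }

    public static void main(String[] args) {
        CONTEXT.set("parent");
        String[] seen = new String[1];
        ReplayingTask task = new ReplayingTask(() -> seen[0] = CONTEXT.get(), true);
        CONTEXT.set("worker-own");  // stands in for the pool thread's own value
        task.run();
        System.out.println(seen[0] + " / " + CONTEXT.get());  // parent / worker-own
        try {
            task.run();
        } catch (IllegalStateException e) {
            System.out.println("second run rejected");
        }
    }
}
```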

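replay(captured) hands each transmittee its part of the snapshot. For TransmittableThreadLocal values the work is done by ttlTransmittee's implementation of com.alibaba.ttl.TransmittableThreadLocal.Transmitter.Transmittee#replay: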

@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
    final HashMap<TransmittableThreadLocal<Object>, Object> backup = newHashMap(holder.get().size());

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // backup
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // set TTL values to captured
    setTtlValuesTo(captured);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}
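The loop at the top of replay is what keeps a pool thread's leftover TTL values from leaking into an unrelated task. Below is a hypothetical JDK-only sketch of that clearing step: TRACKED plays the role of holder, and switchTo is an illustrative helper for the step shared by replay and restore ("drop what is not in the target, then apply the target"). Callbacks such as doExecuteCallback are omitted.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

public class ReplayClearDemo {
    // Per-thread set of tracked ThreadLocals (the role of holder); plain, not inheritable.
    static final ThreadLocal<Set<ThreadLocal<Object>>> TRACKED = ThreadLocal.withInitial(HashSet::new);

    static void put(ThreadLocal<Object> tl, Object value) {
        tl.set(value);
        TRACKED.get().add(tl);
    }

    static Map<ThreadLocal<Object>, Object> capture() {
        Map<ThreadLocal<Object>, Object> captured = new HashMap<>();
        for (ThreadLocal<Object> tl : TRACKED.get()) captured.put(tl, tl.get());
        return captured;
    }

    // Shared step: forget tracked values whose key is not in target, then apply target.
    static void switchTo(Map<ThreadLocal<Object>, Object> target) {
        for (Iterator<ThreadLocal<Object>> it = TRACKED.get().iterator(); it.hasNext(); ) {
            ThreadLocal<Object> tl = it.next();
            if (!target.containsKey(tl)) {
                it.remove();
                tl.remove();
            }
        }
        target.forEach(ReplayClearDemo::put);
    }

    static Map<ThreadLocal<Object>, Object> replay(Map<ThreadLocal<Object>, Object> captured) {
        Map<ThreadLocal<Object>, Object> backup = capture();  // this thread's own values
        switchTo(captured);  // values the parent did not have are hidden from the task
        return backup;
    }

    static void restore(Map<ThreadLocal<Object>, Object> backup) {
        switchTo(backup);    // values the task picked up are dropped again
    }

    // Reads {A during task, B during task, A after restore, B after restore} on a worker thread.
    static Object[] demo() {
        ThreadLocal<Object> a = new ThreadLocal<>();
        ThreadLocal<Object> b = new ThreadLocal<>();
        put(a, "from-parent");
        Map<ThreadLocal<Object>, Object> captured = capture();  // the parent only tracks A

        Object[] seen = new Object[4];
        Thread worker = new Thread(() -> {
            put(b, "worker-stale");  // e.g. left behind by an earlier task on this thread
            Map<ThreadLocal<Object>, Object> backup = replay(captured);
            try {
                seen[0] = a.get();
                seen[1] = b.get();
            } finally {
                restore(backup);
            }
            seen[2] = a.get();
            seen[3] = b.get();
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            TRACKED.get().remove(a);  // clean up the parent thread
            a.remove();
        }
        return seen;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demo()));  // [from-parent, null, null, worker-stale]
    }
}
```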

To understand replay, read it together with the ttlTransmittee capture() method that was skipped in step 4:

@NonNull
@Override
public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
    final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = newHashMap(holder.get().size());
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}

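The test registers no plain ThreadLocal, so only ttlTransmittee matters here: capture() iterates over holder, using each TransmittableThreadLocal as the key and its current value (obtained through copyValue(), which by default returns the same reference) as the value.

This shows that the values are captured when the TtlRunnable is created, that is, at task submission time, not when the pool thread is created: the wrapper records the references and values of all of the parent thread's TransmittableThreadLocal variables at that moment. When run() executes on the pool (child) thread, setTtlValuesTo(captured) installs those values in that thread, and after the task finishes restore(backup) puts the pool thread's previous values back.

On the second submission a new TtlRunnable is created, capturing a snapshot of that moment. By then main has already set transmittableThreadLocalInteger to 111, so when run() executes on the pool thread the captured snapshot is replayed into it and the task sees the updated value, even though pool-1-thread-1 was created before the value was set. The first task printed null for the opposite reason: when it was submitted, main's holder was still empty, so there was nothing to capture.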
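Putting the steps together, the test's behaviour can be reproduced with a hypothetical, much-reduced re-implementation in plain Java. MiniTtl, capture, replay and wrap are illustrative stand-ins, not TTL classes. MiniTtl deliberately extends plain ThreadLocal, so the pool thread can only receive the value through capture and replay, and the tasks are awaited with Future.get() instead of Thread.sleep.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class MiniTtlDemo {
    // Hypothetical stand-in for TransmittableThreadLocal; extends plain ThreadLocal on purpose.
    static final class MiniTtl<T> extends ThreadLocal<T> {
        // Per-thread set of instances that hold a value (the role of holder).
        static final ThreadLocal<Set<MiniTtl<?>>> HOLDER = ThreadLocal.withInitial(HashSet::new);

        @Override
        public void set(T value) {
            super.set(value);
            HOLDER.get().add(this);
        }
    }

    static final MiniTtl<Integer> VALUE = new MiniTtl<>();

    // capture(): snapshot of every tracked instance in the calling thread.
    static Map<MiniTtl<?>, Object> capture() {
        Map<MiniTtl<?>, Object> snapshot = new HashMap<>();
        for (MiniTtl<?> ttl : MiniTtl.HOLDER.get()) snapshot.put(ttl, ttl.get());
        return snapshot;
    }

    // replay(): back up this thread's values, drop those absent from target, apply target.
    @SuppressWarnings("unchecked")
    static Map<MiniTtl<?>, Object> replay(Map<MiniTtl<?>, Object> target) {
        Map<MiniTtl<?>, Object> backup = capture();
        for (Iterator<MiniTtl<?>> it = MiniTtl.HOLDER.get().iterator(); it.hasNext(); ) {
            MiniTtl<?> ttl = it.next();
            if (!target.containsKey(ttl)) {
                it.remove();
                ttl.remove();
            }
        }
        target.forEach((ttl, value) -> ((MiniTtl<Object>) ttl).set(value));
        return backup;
    }

    // Stand-in for TtlRunnable: capture at wrap time, replay before run, restore after.
    static Runnable wrap(Runnable task) {
        Map<MiniTtl<?>, Object> captured = capture();  // in the submitting thread
        return () -> {
            Map<MiniTtl<?>, Object> backup = replay(captured);
            try {
                task.run();
            } finally {
                replay(backup);  // restore == replaying the backup
            }
        };
    }

    static Integer printAndGet() {
        Integer value = VALUE.get();
        System.out.println(Thread.currentThread().getName() + " " + value);
        return value;
    }

    // Mirrors test1; returns what {first task, second task, main} saw.
    static Integer[] runScenario() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Integer[] seen = new Integer[3];
        try {
            pool.submit(wrap(() -> seen[0] = printAndGet())).get();  // nothing captured yet
            VALUE.set(111);
            pool.submit(wrap(() -> seen[1] = printAndGet())).get();  // 111 captured at wrap time
            seen[2] = printAndGet();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            VALUE.remove();
            MiniTtl.HOLDER.get().remove(VALUE);
        }
        return seen;
    }

    public static void main(String[] args) {
        runScenario();
    }
}
```

In runScenario the first task sees null, the second sees 111 although it runs on the same pool thread, and main sees 111: the same values as in the test output above.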
