由于并发的需要原因,使用CompletableFuture异步执行Dubbo接口,RpcContext中存储了tenantId等信息。上线一段时间后,发现有些时候拿到的上下文并不是自己线程的上下文。
- 原因分析
CompletableFuture.supplyAsync内部使用ForkJoinPool执行。
要知道原因,需要了解forkjoin的原理,forkjoin其核心思想就是分而治之。使用递归的思想将一个大人物拆分为多个小任务,直到达到停止拆分的条件。
并且他每个线程(线程数默认为cpu数)都有一个无限的执行队列。线程会从执行队列里面取任务执行。并且执行过程中,如果某些线程执行的快,为了利用cpu,空闲的线程会偷取其他队列里面的线程,拿到自己队列并执行。当然,为了避免竞争,队列使用的是双向队列,自己线程从队列头获取任务,偷取任务从队列尾部获取。这里我找了一张图很好的描述了一下这个场景:
因此,部分情况下执行任务的线程不是保存了ThreadLocal信息的线程,而是窃取任务的线程。 - 解决思路
知道了原因,我们只需要修改ForkJoinPool,让它获取到正确的ThreadLocal信息。
private static class SafeForkJoinPool extends ForkJoinPool {
private static Field inheritableThreadLocalsField;
private static Method createInheritedMapMethod;
static {
try {
inheritableThreadLocalsField = Thread.class.getDeclaredField("inheritableThreadLocals");
inheritableThreadLocalsField.setAccessible(true);
createInheritedMapMethod = ThreadLocal.class.getDeclaredMethod("createInheritedMap", new Class[]{inheritableThreadLocalsField.getType()});
createInheritedMapMethod.setAccessible(true);
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}
public SafeForkJoinPool() {
}
public SafeForkJoinPool(int parallelism) {
super(parallelism);
}
public SafeForkJoinPool(int parallelism,
ForkJoinPool.ForkJoinWorkerThreadFactory factory,
Thread.UncaughtExceptionHandler handler,
boolean asyncMode) {
super(parallelism, factory, handler, asyncMode);
}
@Override
public void execute(Runnable task) {
//获取当前线程中的所有ThreadLocal数据
Object parentLocals = getField(inheritableThreadLocalsField, Thread.currentThread());
super.execute(() -> {
Object locals = null == parentLocals ? null : invokeMethod(createInheritedMapMethod, null, new Object[]{parentLocals});
//替换当前线程(包含窃取线程)中的所有ThreadLocal数据
setField(inheritableThreadLocalsField, Thread.currentThread(), locals);
task.run();
});
}
private static Object invokeMethod(Method method, Object target, Object... args) {
try {
return method.invoke(target, args);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static Object getField(Field field, Object target) {
try {
return field.get(target);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static void setField(Field field, Object target, Object value) {
try {
field.set(target, value);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}