在并发编程领域,任务模型可以分为简单并行任务、聚合任务和批量并行任务。然而,还有一种广泛应用的任务模型——分治(Divide and Conquer)。分治是一种解决复杂问题的思维方法,它通过将复杂问题分解为多个相似的子问题,再将这些子问题进一步分解,直到每个子问题变得足够简单从而可以直接求解。
1 什么是分治任务模型?
分治任务模型主要分为两个阶段:
- 任务分解:迭代地将任务分解为子任务,直到子任务可以直接计算出结果。
- 结果合并:逐层合并子任务的执行结果,直到获得最终结果。
在这个模型中,任务和分解后的子任务具有相似性,这种相似性通常体现在任务和子任务的算法是相同的,但计算的数据规模不同。因此,分治任务模型常常采用递归算法来实现。
2 Fork/Join 并行计算框架
Fork/Join 是一个并行计算框架,主要用于支持分治任务模型。在这个框架中,“Fork”代表任务的分解,而“Join”代表结果的合并。Fork/Join 框架主要由两部分组成:分治任务的线程池 ForkJoinPool
和分治任务 ForkJoinTask
。
2.1 ForkJoinPool 和 ForkJoinTask
ForkJoinPool
和 ForkJoinTask
的关系类似于 ThreadPoolExecutor
和 Runnable
之间的关系,都是用于提交任务到线程池的。不过,分治任务有自己独特的类型 ForkJoinTask
。
ForkJoinTask
是一个抽象类,其中有许多方法,最核心的是 fork()
方法和 join()
方法。fork()
方法用于异步执行一个子任务,而 join()
方法通过阻塞当前线程来等待子任务的执行结果。
ForkJoinTask
有两个子类:
- RecursiveAction:用于没有返回值的任务。
- RecursiveTask:用于有返回值的任务。
这两个子类都定义了一个抽象方法 compute()
,不同之处在于 RecursiveAction
的 compute
方法没有返回值,而 RecursiveTask
的 compute
方法有返回值。这两个子类也都是抽象类,在使用时需要创建自定义的子类来扩展功能。
2.2 示例:计算斐波那契数列
接下来,让我们使用 Fork/Join 并行计算框架来计算斐波那契数列。首先,我们需要创建一个 ForkJoinPool
线程池以及一个用于计算斐波那契数列的 Fibonacci
分治任务。然后,通过调用 ForkJoinPool
线程池的 invoke()
方法来启动分治任务。
由于计算斐波那契数列需要返回结果,所以我们的 Fibonacci
类继承自 RecursiveTask
。Fibonacci
分治任务需要实现 compute
方法,在这个方法中,逻辑与普通计算斐波那契数列的方法非常相似,只是在计算 Fibonacci(n - 1)
时使用了异步子任务,这是通过 f1.fork()
语句来实现的。
@Slf4j
public class ForkJoinDemo {
// 1. 运行入口
public static void main(String[] args) {
int n = 20;
// 为了追踪子线程名称,需要重写 ForkJoinWorkerThreadFactory 的方法
final ForkJoinPool.ForkJoinWorkerThreadFactory factory = pool -> {
final ForkJoinWorkerThread worker = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
worker.setName("my-thread" + worker.getPoolIndex());
return worker;
};
// 创建分治任务线程池,可以追踪到线程名称
ForkJoinPool forkJoinPool = new ForkJoinPool(4, factory, null, false);
// 快速创建 ForkJoinPool 方法
// ForkJoinPool forkJoinPool = new ForkJoinPool(4);
// 创建分治任务
Fibonacci fibonacci = new Fibonacci(n);
// 调用 invoke 方法启动分治任务
Integer result = forkJoinPool.invoke(fibonacci);
log.info("Fibonacci {} 的结果是 {}", n, result);
}
}
// 2. 定义拆分任务,写好拆分逻辑
@Slf4j
class Fibonacci extends RecursiveTask<Integer> {
final int n;
Fibonacci(int n) {
this.n = n;
}
@Override
public Integer compute() {
// 和递归类似,定义可计算的最小单元
if (n <= 1) {
return n;
}
// 想查看子线程名称输出的可以打开下面注释
// log.info(Thread.currentThread().getName());
Fibonacci f1 = new Fibonacci(n - 1);
// 拆分成子任务
f1.fork();
Fibonacci f2 = new Fibonacci(n - 2);
// f1.join 等待子任务执行结果
return f2.compute() + f1.join();
}
}
运行上述程序,我们会得到如下结果:
17:29:10.336 [main] INFO tech.shuyi.javacodechip.forkjoinpool.ForkJoinDemo - Fibonacci 20 的结果是 6765
4 ForkJoinPool 的工作原理
Fork/Join 并行计算框架的核心组件是 ForkJoinPool
。下面我们将详细介绍 ForkJoinPool
的工作原理。
4.1 任务提交与分配
当我们通过 ForkJoinPool
的 invoke
或 submit
方法提交任务时,ForkJoinPool
会根据一定的路由规则将任务分配到一个任务队列中。如果任务在执行过程中创建了子任务,那么这些子任务会被提交到对应工作线程的任务队列中。
ForkJoinPool
中有一个数组形式的成员变量 workQueue[]
,它对应一个队列数组,每个队列对应一个消费线程。丢入线程池的任务会根据特定规则进行转发。
4.2 任务窃取机制
当工作线程的任务队列为空时,它是否无事可做呢?答案是否定的。ForkJoinPool
引入了一种称为“任务窃取”的机制。当工作线程空闲时,它可以从其他工作线程的任务队列中“窃取”任务。
例如,下图中线程 T2 的任务队列已经为空,它可以窃取线程 T1 任务队列中的任务。这样,所有的工作线程都能保持忙碌的状态。
线程 T1: [任务1, 任务2, 任务3]
线程 T2: []
线程 T2 窃取线程 T1 的任务3
线程 T1: [任务1, 任务2]
线程 T2: [任务3]
4.3 双端队列
ForkJoinPool
中的任务队列采用双端队列(Deque)的形式。工作线程从任务队列的一端获取任务,而“窃取任务”从另一端进行消费。这种设计能够避免许多不必要的数据竞争。
具体来说,工作线程从队列的头部获取任务,而窃取任务的线程从队列的尾部获取任务。这种设计确保了任务的分配和窃取过程不会相互干扰,从而提高了并发执行的效率。
4.4 示例代码
以下是一个简单的示例,展示了如何使用 ForkJoinPool
来处理任务:
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
public class ForkJoinExample {
public static void main(String[] args) {
ForkJoinPool pool = new ForkJoinPool();
int[] array = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
SumTask task = new SumTask(array, 0, array.length);
int result = pool.invoke(task);
System.out.println("Sum: " + result);
}
}
class SumTask extends RecursiveTask<Integer> {
private final int[] array;
private final int start;
private final int end;
SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
if (end - start <= 2) {
int sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
} else {
int mid = (start + end) / 2;
SumTask leftTask = new SumTask(array, start, mid);
SumTask rightTask = new SumTask(array, mid, end);
leftTask.fork();
int rightResult = rightTask.compute();
int leftResult = leftTask.join();
return leftResult + rightResult;
}
}
}
在这个示例中,我们创建了一个 ForkJoinPool
和一个 SumTask
任务,用于计算数组中元素的总和。SumTask
任务通过递归地将数组分成两部分,分别计算每部分的总和,最后将结果合并。
5 ForkJoinPool 与 ThreadPoolExecutor 的比较
虽然 ForkJoinPool 和 ThreadPoolExecutor 都是线程池,用于执行任务,但它们之间有很多不同之处:
- 任务窃取 vs 工作复用:ForkJoinPool 采用任务窃取机制,而 ThreadPoolExecutor 采用工作复用机制。
- 分治任务模型 vs 简单并行任务模型:ForkJoinPool 适用于分治任务模型,而 ThreadPoolExecutor 适用于简单并行任务模型。
- LIFO 任务队列 vs FIFO 任务队列:ForkJoinPool 采用 LIFO 任务队列,而 ThreadPoolExecutor 采用 FIFO 任务队列。
6 示例:计算 1 到 1 亿的和
在并发编程中,计算 1 到 1 亿的和是一个典型的并行计算任务。为了加快计算速度,我们可以利用分治原理,将任务分解为多个子任务,并利用 CPU 的并发计算性能来缩短计算时间。下面将分别使用 ThreadPoolExecutor
和 ForkJoinPool
来实现这一任务,并对比两者的实现方式。
6.1 使用 ThreadPoolExecutor 实现
首先,我们定义一个 Calculator
接口,表示计算数字总和的动作:
public interface Calculator {
/**
* 把传进来的所有 numbers 做求和处理
*
* @param numbers
* @return 总和
*/
long sumUp(long[] numbers);
}
接着,我们定义一个使用 ThreadPoolExecutor
线程池实现的类 ExecutorServiceCalculator
:
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
public class ExecutorServiceCalculator implements Calculator {
private int parallism;
private ExecutorService pool;
public ExecutorServiceCalculator() {
// CPU的核心数 默认就用cpu核心数了
parallism = Runtime.getRuntime().availableProcessors();
pool = Executors.newFixedThreadPool(parallism);
}
// 1. 处理计算任务的线程
private static class SumTask implements Callable<Long> {
private long[] numbers;
private int from;
private int to;
public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
@Override
public Long call() {
long total = 0;
for (int i = from; i <= to; i++) {
total += numbers[i];
}
return total;
}
}
// 2. 核心业务逻辑实现
@Override
public long sumUp(long[] numbers) {
List<Future<Long>> results = new ArrayList<>();
// 2.1 数字拆分
// 把任务分解为 n 份,交给 n 个线程处理 4核心 就等分成4份呗
// 然后把每一份都扔个一个SumTask线程 进行处理
int part = numbers.length / parallism;
for (int i = 0; i < parallism; i++) {
int from = i * part; //开始位置
int to = (i == parallism - 1) ? numbers.length - 1 : (i + 1) * part - 1; //结束位置
//扔给线程池计算
results.add(pool.submit(new SumTask(numbers, from, to)));
}
// 2.2 阻塞等待结果
// 把每个线程的结果相加,得到最终结果 get()方法 是阻塞的
// 优化方案:可以采用CompletableFuture来优化 JDK1.8的新特性
long total = 0L;
for (Future<Long> f : results) {
try {
total += f.get();
} catch (Exception ignore) {
}
}
return total;
}
public static void main(String[] args) {
// 创建一个包含 1 亿个数字的数组
long[] numbers = new long[100000000];
for (int i = 0; i < numbers.length; i++) {
numbers[i] = i + 1;
}
// 创建 ExecutorServiceCalculator 实例
ExecutorServiceCalculator calculator = new ExecutorServiceCalculator();
// 记录开始时间
long startTime = System.currentTimeMillis();
// 计算总和
long result = calculator.sumUp(numbers);
// 记录结束时间
long endTime = System.currentTimeMillis();
// 输出结果和耗时
System.out.println("结果为:" + result);
System.out.println("耗时:" + (endTime - startTime) + "ms");
}
}
在这个实现中,我们首先将 1 亿个数字拆分为多个子任务,每个子任务计算一部分数字的总和。然后,我们将这些子任务提交给 ThreadPoolExecutor
进行并行计算,最后通过 Future
接口获取每个子任务的结果并累加。
6.2 使用 ForkJoinPool 实现
接下来,我们使用 ForkJoinPool
来实现相同的任务。首先,我们定义一个 SumTask
类,继承自 RecursiveTask
抽象类,并在 compute
方法中定义拆分逻辑及计算:
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
public class ForkJoinCalculator implements Calculator {
private ForkJoinPool pool;
// 1. 定义计算逻辑
private static class SumTask extends RecursiveTask<Long> {
private long[] numbers;
private int from;
private int to;
public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
//此方法为ForkJoin的核心方法:对任务进行拆分 拆分的好坏决定了效率的高低
@Override
protected Long compute() {
// 当需要计算的数字个数小于6时,直接采用for loop方式计算结果
if (to - from < 6) {
long total = 0;
for (int i = from; i <= to; i++) {
total += numbers[i];
}
return total;
} else {
// 否则,把任务一分为二,递归拆分(注意此处有递归)到底拆分成多少分 需要根据具体情况而定
int middle = (from + to) / 2;
SumTask taskLeft = new SumTask(numbers, from, middle);
SumTask taskRight = new SumTask(numbers, middle + 1, to);
taskLeft.fork();
taskRight.fork();
return taskLeft.join() + taskRight.join();
}
}
}
public ForkJoinCalculator() {
// 也可以使用公用的线程池 ForkJoinPool.commonPool():
// pool = ForkJoinPool.commonPool()
pool = new ForkJoinPool();
}
@Override
public long sumUp(long[] numbers) {
Long result = pool.invoke(new SumTask(numbers, 0, numbers.length - 1));
pool.shutdown();
return result;
}
public static void main(String[] args) {
// 创建一个包含 1 亿个数字的数组
long[] numbers = new long[100000000];
for (int i = 0; i < numbers.length; i++) {
numbers[i] = i + 1;
}
// 创建 ForkJoinCalculator 实例
ForkJoinCalculator calculator = new ForkJoinCalculator();
// 记录开始时间
long startTime = System.currentTimeMillis();
// 计算总和
long result = calculator.sumUp(numbers);
// 记录结束时间
long endTime = System.currentTimeMillis();
// 输出结果和耗时
System.out.println("结果为:" + result);
System.out.println("耗时:" + (endTime - startTime) + "ms");
}
}
在这个实现中,我们同样将 1 亿个数字拆分为多个子任务,每个子任务计算一部分数字的总和。不同的是,我们使用 ForkJoinPool
和 RecursiveTask
来实现任务的拆分和合并。ForkJoinPool
会自动处理任务的分配和结果的合并,无需手动获取子任务的结果。
6.3 对比两种实现方式
通过上述两种实现方式,我们可以看到它们都有任务拆分的逻辑,以及最终合并数值的逻辑。但 ForkJoinPool
相比 ThreadPoolExecutor
来说,做了一些实现上的封装,例如:
- 自动获取子任务结果:
ForkJoinPool
使用join
方法直接获取子任务的结果,而ThreadPoolExecutor
需要手动通过Future
接口获取结果。 - 任务拆分逻辑的封装:
ForkJoinPool
将任务拆分的逻辑封装在RecursiveTask
实现类中,而不是裸露在外。
因此,对于没有父子任务依赖,但是希望获取到子任务执行结果的并行计算任务,使用 ForkJoinPool
实现更加方便,封装做得更好。
6.4 运行结果
使用 ThreadPoolExecutor
实现的运行结果:
结果为:5000000050000000
耗时:66ms
使用 ForkJoinPool
实现的运行结果:
结果为:5000000050000000
耗时:764ms
从运行结果来看,ThreadPoolExecutor
的实现速度更快,这可能是因为 ForkJoinPool
在处理大量小任务时,任务窃取机制的开销较大。但在实际应用中,ForkJoinPool
更适合处理那些可以递归分解的任务,如计算斐波那契数列、归并排序等。
7 模拟 MapReduce 统计单词数量
MapReduce 是一个编程模型,同时也是一个处理和生成大数据集的处理框架。它源于 Google,用于支持在大型数据集上的分布式计算。这个框架主要由两个步骤组成:Map 步骤和 Reduce 步骤,这也是它名字的由来。
Fork/Join 并行计算框架通常被用来实现学习 MapReduce 的入门程序,该程序用于统计文件中每个单词的数量。本文将介绍如何使用 Fork/Join 框架来模拟 MapReduce 的单词统计任务。
7.1 实现代码
下面的代码使用了字符串数组 String[] fc
来模拟文件内容,其中每个元素与文件中的行数据一一对应。关键代码位于 compute()
方法中,这是一个递归方法。它将前半部分数据 fork 一个递归任务进行处理(关键代码:mr1.fork()
),而后半部分数据在当前任务中递归处理(mr2.compute()
)。
import java.util.concurrent.RecursiveTask;
public class WordCountTask extends RecursiveTask<Integer> {
private final String[] fc;
private final int start, end;
public WordCountTask(String[] fc, int start, int end) {
this.fc = fc;
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
if (end - start <= 1) {
// 对单行数据进行统计
return countWords(fc[start]);
} else {
int mid = (start + end) / 2;
WordCountTask mr1 = new WordCountTask(fc, start, mid);
mr1.fork();
WordCountTask mr2 = new WordCountTask(fc, mid, end);
int result2 = mr2.compute();
int result1 = mr1.join();
// 汇总结果
return result1 + result2;
}
}
private int countWords(String line) {
String[] words = line.split(" ");
return words.length;
}
}
7.2 代码解析
-
构造函数:
WordCountTask
类的构造函数接收一个字符串数组fc
和两个整数start
和end
,表示当前任务需要处理的文件内容的范围。 -
compute()
方法:这是RecursiveTask
的核心方法,用于定义任务的计算逻辑。- 如果
end - start <= 1
,表示当前任务只需要处理一行数据,直接调用countWords()
方法统计该行中的单词数量。 - 否则,将任务一分为二,分别创建两个子任务
mr1
和mr2
,并递归调用compute()
方法处理后半部分数据,同时使用fork()
方法异步处理前半部分数据。 - 最后,使用
join()
方法等待前半部分任务的结果,并将两个子任务的结果相加,得到最终的单词数量。
- 如果
-
countWords()
方法:用于统计一行数据中的单词数量。通过split(" ")
方法将字符串按空格分割成单词数组,并返回数组的长度。
7.3 运行示例
假设我们有一个包含多行文本的字符串数组 fc
,我们可以使用以下代码来统计整个文件中的单词数量:
import java.util.concurrent.ForkJoinPool;
public class WordCountDemo {
public static void main(String[] args) {
String[] fc = {
"Hello world",
"This is a test",
"ForkJoinPool is powerful",
"MapReduce is a distributed computing framework"
};
ForkJoinPool pool = new ForkJoinPool();
WordCountTask task = new WordCountTask(fc, 0, fc.length);
int result = pool.invoke(task);
System.out.println("Total words: " + result);
}
}
运行上述代码,输出结果为:
Total words: 15
8 总结
Fork/Join 并行计算框架主要解决的是分治任务。它通过任务窃取机制来提高线程的利用率,适用于处理需要递归分解和合并结果的任务。Java 1.8 提供的 Stream API 中的并行流也是基于 ForkJoinPool 实现的,但需要注意不同类型的计算任务可能会影响系统的性能,建议使用不同的 ForkJoinPool 执行不同类型的计算任务。
通过本文的介绍,希望读者能够更好地理解 Fork/Join 并行计算框架的工作原理及其在实际应用中的优势。
9 思维导图
10 参考链接
深入理解Java并发编程之Fork/Join框架