第14章
🔀 Fork/Join框架
掌握Java并行计算框架,理解分治算法和工作窃取机制
学习目标
- 理解Fork/Join框架的设计思想
- 掌握ForkJoinPool的使用
- 学会编写RecursiveTask
- 了解工作窃取算法
- 应用Fork/Join解决实际问题
Fork/Join框架概述
Fork/Join框架是Java 7引入的一个并行计算框架,专门用于解决可以递归分解的计算密集型任务。它基于分治算法的思想,将大任务分解为小任务并行执行,最后合并结果。
核心思想
Fork/Join框架的核心是"分而治之":Fork负责将大任务分解为小任务,Join负责合并小任务的结果。
框架特点
分治算法
将复杂问题分解为相似的子问题,递归求解后合并结果。
工作窃取
空闲线程可以从其他线程的任务队列中窃取任务,提高CPU利用率。
高性能
充分利用多核CPU,适合计算密集型任务的并行处理。
适用场景
- 递归算法:如归并排序、快速排序等
- 数学计算:如矩阵运算、数值积分等
- 数据处理:如大数据集的并行处理
- 图像处理:如图像滤波、变换等
ForkJoinPool详解
ForkJoinPool是Fork/Join框架的核心,它是一个特殊的线程池,专门用于执行ForkJoinTask。与普通线程池不同,ForkJoinPool实现了工作窃取算法。
创建ForkJoinPool
创建ForkJoinPool示例
// 使用默认并行度(CPU核心数)
ForkJoinPool pool = new ForkJoinPool();
// 指定并行度
ForkJoinPool customPool = new ForkJoinPool(4);
// 使用公共池(推荐)
ForkJoinPool commonPool = ForkJoinPool.commonPool();
// 获取并行度
int parallelism = pool.getParallelism();
System.out.println("并行度: " + parallelism);
工作窃取算法
双端队列
每个工作线程都有自己的双端队列(Deque),新任务从队列头部添加,线程从头部取任务执行。
任务窃取
当线程的队列为空时,它会随机选择其他线程的队列,从尾部窃取任务执行。
负载均衡
通过工作窃取机制,自动实现负载均衡,避免某些线程空闲而其他线程过载。
RecursiveTask和RecursiveAction
Fork/Join框架提供了两个抽象类来定义任务:RecursiveTask(有返回值)和RecursiveAction(无返回值)。
RecursiveTask示例
计算数组和的RecursiveTask
public class SumTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 1000; // 阈值
private int[] array;
private int start;
private int end;
public SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
// 如果任务足够小,直接计算
if (end - start <= THRESHOLD) {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
}
// 任务太大,分解为子任务
int mid = (start + end) / 2;
SumTask leftTask = new SumTask(array, start, mid);
SumTask rightTask = new SumTask(array, mid, end);
// Fork:异步执行左任务
leftTask.fork();
// 执行右任务(当前线程)
Long rightResult = rightTask.compute();
// Join:等待左任务完成并获取结果
Long leftResult = leftTask.join();
// 合并结果
return leftResult + rightResult;
}
}
使用示例
使用SumTask计算数组和
public class ForkJoinExample {
public static void main(String[] args) {
// 创建大数组
int[] array = new int[10_000_000];
for (int i = 0; i < array.length; i++) {
array[i] = i + 1;
}
// 使用Fork/Join计算
ForkJoinPool pool = ForkJoinPool.commonPool();
SumTask task = new SumTask(array, 0, array.length);
long startTime = System.currentTimeMillis();
Long result = pool.invoke(task);
long endTime = System.currentTimeMillis();
System.out.println("结果: " + result);
System.out.println("耗时: " + (endTime - startTime) + "ms");
// 对比串行计算
startTime = System.currentTimeMillis();
long serialSum = 0;
for (int value : array) {
serialSum += value;
}
endTime = System.currentTimeMillis();
System.out.println("串行结果: " + serialSum);
System.out.println("串行耗时: " + (endTime - startTime) + "ms");
}
}
注意事项
- 合理设置阈值,避免任务分解过细导致开销过大
- 优先使用ForkJoinPool.commonPool(),避免创建过多线程池
- 确保任务是CPU密集型的,I/O密集型任务不适合使用Fork/Join
实际应用案例
并行归并排序
MergeSortTask实现
public class MergeSortTask extends RecursiveAction {
private static final int THRESHOLD = 100;
private int[] array;
private int start;
private int end;
public MergeSortTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected void compute() {
if (end - start <= THRESHOLD) {
// 使用插入排序处理小数组
insertionSort(array, start, end);
return;
}
int mid = (start + end) / 2;
MergeSortTask leftTask = new MergeSortTask(array, start, mid);
MergeSortTask rightTask = new MergeSortTask(array, mid, end);
// 并行执行两个子任务
invokeAll(leftTask, rightTask);
// 合并已排序的两部分
merge(array, start, mid, end);
}
private void insertionSort(int[] arr, int start, int end) {
for (int i = start + 1; i < end; i++) {
int key = arr[i];
int j = i - 1;
while (j >= start && arr[j] > key) {
arr[j + 1] = arr[j];
j--;
}
arr[j + 1] = key;
}
}
private void merge(int[] arr, int start, int mid, int end) {
int[] temp = new int[end - start];
int i = start, j = mid, k = 0;
while (i < mid && j < end) {
if (arr[i] <= arr[j]) {
temp[k++] = arr[i++];
} else {
temp[k++] = arr[j++];
}
}
while (i < mid) temp[k++] = arr[i++];
while (j < end) temp[k++] = arr[j++];
System.arraycopy(temp, 0, arr, start, temp.length);
}
}
矩阵乘法
并行矩阵乘法
public class MatrixMultiplyTask extends RecursiveTask<int[][]> {
private static final int THRESHOLD = 64;
private int[][] a, b;
private int startRow, endRow, startCol, endCol;
public MatrixMultiplyTask(int[][] a, int[][] b,
int startRow, int endRow,
int startCol, int endCol) {
this.a = a;
this.b = b;
this.startRow = startRow;
this.endRow = endRow;
this.startCol = startCol;
this.endCol = endCol;
}
@Override
protected int[][] compute() {
int rows = endRow - startRow;
int cols = endCol - startCol;
if (rows <= THRESHOLD && cols <= THRESHOLD) {
return multiplyDirect();
}
// 分解任务
if (rows >= cols) {
int midRow = (startRow + endRow) / 2;
MatrixMultiplyTask task1 = new MatrixMultiplyTask(
a, b, startRow, midRow, startCol, endCol);
MatrixMultiplyTask task2 = new MatrixMultiplyTask(
a, b, midRow, endRow, startCol, endCol);
task1.fork();
int[][] result2 = task2.compute();
int[][] result1 = task1.join();
return combineResults(result1, result2, true);
} else {
int midCol = (startCol + endCol) / 2;
MatrixMultiplyTask task1 = new MatrixMultiplyTask(
a, b, startRow, endRow, startCol, midCol);
MatrixMultiplyTask task2 = new MatrixMultiplyTask(
a, b, startRow, endRow, midCol, endCol);
task1.fork();
int[][] result2 = task2.compute();
int[][] result1 = task1.join();
return combineResults(result1, result2, false);
}
}
private int[][] multiplyDirect() {
int rows = endRow - startRow;
int cols = endCol - startCol;
int[][] result = new int[rows][cols];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
for (int k = 0; k < a[0].length; k++) {
result[i][j] += a[startRow + i][k] * b[k][startCol + j];
}
}
}
return result;
}
private int[][] combineResults(int[][] result1, int[][] result2, boolean byRow) {
if (byRow) {
int[][] combined = new int[result1.length + result2.length][result1[0].length];
System.arraycopy(result1, 0, combined, 0, result1.length);
System.arraycopy(result2, 0, combined, result1.length, result2.length);
return combined;
} else {
int[][] combined = new int[result1.length][result1[0].length + result2[0].length];
for (int i = 0; i < result1.length; i++) {
System.arraycopy(result1[i], 0, combined[i], 0, result1[i].length);
System.arraycopy(result2[i], 0, combined[i], result1[i].length, result2[i].length);
}
return combined;
}
}
}
性能优化和最佳实践
性能优化技巧
合理设置阈值
阈值太小会导致任务分解过细,增加调度开销;太大则无法充分利用并行性。
减少内存分配
尽量重用对象,避免在compute()方法中频繁创建新对象。
避免深度递归
控制递归深度,防止栈溢出,可以考虑混合使用递归和迭代。
最佳实践
实践建议
- 使用公共池:优先使用ForkJoinPool.commonPool(),避免创建多个线程池
- CPU密集型任务:Fork/Join适合CPU密集型任务,不适合I/O密集型任务
- 任务粒度:确保任务足够大以抵消并行开销,但不要太大以至于无法并行
- 避免阻塞:在compute()方法中避免阻塞操作,如I/O或同步
- 异常处理:正确处理子任务中的异常,避免影响整个计算
性能对比示例
性能测试代码
public class PerformanceTest {
public static void main(String[] args) {
int[] array = new int[10_000_000];
Random random = new Random();
for (int i = 0; i < array.length; i++) {
array[i] = random.nextInt(1000);
}
// 串行计算
long startTime = System.currentTimeMillis();
long serialSum = Arrays.stream(array).sum();
long serialTime = System.currentTimeMillis() - startTime;
// 并行流计算
startTime = System.currentTimeMillis();
long parallelStreamSum = Arrays.stream(array).parallel().sum();
long parallelStreamTime = System.currentTimeMillis() - startTime;
// Fork/Join计算
ForkJoinPool pool = ForkJoinPool.commonPool();
SumTask task = new SumTask(array, 0, array.length);
startTime = System.currentTimeMillis();
long forkJoinSum = pool.invoke(task);
long forkJoinTime = System.currentTimeMillis() - startTime;
System.out.println("串行计算: " + serialSum + ", 耗时: " + serialTime + "ms");
System.out.println("并行流: " + parallelStreamSum + ", 耗时: " + parallelStreamTime + "ms");
System.out.println("Fork/Join: " + forkJoinSum + ", 耗时: " + forkJoinTime + "ms");
System.out.println("\n加速比:");
System.out.println("并行流 vs 串行: " + (double)serialTime / parallelStreamTime);
System.out.println("Fork/Join vs 串行: " + (double)serialTime / forkJoinTime);
}
}