第14章

🔀 Fork/Join框架

掌握Java并行计算框架,理解分治算法和工作窃取机制

学习目标

Fork/Join框架概述

Fork/Join框架是Java 7引入的一个并行计算框架,专门用于解决可以递归分解的计算密集型任务。它基于分治算法的思想,将大任务分解为小任务并行执行,最后合并结果。

核心思想

Fork/Join框架的核心是"分而治之":Fork负责将大任务分解为小任务,Join负责合并小任务的结果。

框架特点

分治算法
将复杂问题分解为相似的子问题,递归求解后合并结果。
工作窃取
空闲线程可以从其他线程的任务队列中窃取任务,提高CPU利用率。
高性能
充分利用多核CPU,适合计算密集型任务的并行处理。

适用场景

ForkJoinPool详解

ForkJoinPool是Fork/Join框架的核心,它是一个特殊的线程池,专门用于执行ForkJoinTask。与普通线程池不同,ForkJoinPool实现了工作窃取算法。

创建ForkJoinPool

创建ForkJoinPool示例
// 使用默认并行度(CPU核心数)
ForkJoinPool pool = new ForkJoinPool();

// 指定并行度
ForkJoinPool customPool = new ForkJoinPool(4);

// 使用公共池(推荐)
ForkJoinPool commonPool = ForkJoinPool.commonPool();

// 获取并行度
int parallelism = pool.getParallelism();
System.out.println("并行度: " + parallelism);

工作窃取算法

双端队列
每个工作线程都有自己的双端队列(Deque),新任务从队列头部添加,线程从头部取任务执行。
任务窃取
当线程的队列为空时,它会随机选择其他线程的队列,从尾部窃取任务执行。
负载均衡
通过工作窃取机制,自动实现负载均衡,避免某些线程空闲而其他线程过载。

RecursiveTask和RecursiveAction

Fork/Join框架提供了两个抽象类来定义任务:RecursiveTask(有返回值)和RecursiveAction(无返回值)。

RecursiveTask示例

计算数组和的RecursiveTask
public class SumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 1000; // 阈值
    private int[] array;
    private int start;
    private int end;
    
    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    
    @Override
    protected Long compute() {
        // 如果任务足够小,直接计算
        if (end - start <= THRESHOLD) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        
        // 任务太大,分解为子任务
        int mid = (start + end) / 2;
        SumTask leftTask = new SumTask(array, start, mid);
        SumTask rightTask = new SumTask(array, mid, end);
        
        // Fork:异步执行左任务
        leftTask.fork();
        
        // 执行右任务(当前线程)
        Long rightResult = rightTask.compute();
        
        // Join:等待左任务完成并获取结果
        Long leftResult = leftTask.join();
        
        // 合并结果
        return leftResult + rightResult;
    }
}

使用示例

使用SumTask计算数组和
public class ForkJoinExample {
    public static void main(String[] args) {
        // 创建大数组
        int[] array = new int[10_000_000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        
        // 使用Fork/Join计算
        ForkJoinPool pool = ForkJoinPool.commonPool();
        SumTask task = new SumTask(array, 0, array.length);
        
        long startTime = System.currentTimeMillis();
        Long result = pool.invoke(task);
        long endTime = System.currentTimeMillis();
        
        System.out.println("结果: " + result);
        System.out.println("耗时: " + (endTime - startTime) + "ms");
        
        // 对比串行计算
        startTime = System.currentTimeMillis();
        long serialSum = 0;
        for (int value : array) {
            serialSum += value;
        }
        endTime = System.currentTimeMillis();
        
        System.out.println("串行结果: " + serialSum);
        System.out.println("串行耗时: " + (endTime - startTime) + "ms");
    }
}
注意事项
  • 合理设置阈值,避免任务分解过细导致开销过大
  • 优先使用ForkJoinPool.commonPool(),避免创建过多线程池
  • 确保任务是CPU密集型的,I/O密集型任务不适合使用Fork/Join

实际应用案例

并行归并排序

MergeSortTask实现
public class MergeSortTask extends RecursiveAction {
    private static final int THRESHOLD = 100;
    private int[] array;
    private int start;
    private int end;
    
    public MergeSortTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    
    @Override
    protected void compute() {
        if (end - start <= THRESHOLD) {
            // 使用插入排序处理小数组
            insertionSort(array, start, end);
            return;
        }
        
        int mid = (start + end) / 2;
        MergeSortTask leftTask = new MergeSortTask(array, start, mid);
        MergeSortTask rightTask = new MergeSortTask(array, mid, end);
        
        // 并行执行两个子任务
        invokeAll(leftTask, rightTask);
        
        // 合并已排序的两部分
        merge(array, start, mid, end);
    }
    
    private void insertionSort(int[] arr, int start, int end) {
        for (int i = start + 1; i < end; i++) {
            int key = arr[i];
            int j = i - 1;
            while (j >= start && arr[j] > key) {
                arr[j + 1] = arr[j];
                j--;
            }
            arr[j + 1] = key;
        }
    }
    
    private void merge(int[] arr, int start, int mid, int end) {
        int[] temp = new int[end - start];
        int i = start, j = mid, k = 0;
        
        while (i < mid && j < end) {
            if (arr[i] <= arr[j]) {
                temp[k++] = arr[i++];
            } else {
                temp[k++] = arr[j++];
            }
        }
        
        while (i < mid) temp[k++] = arr[i++];
        while (j < end) temp[k++] = arr[j++];
        
        System.arraycopy(temp, 0, arr, start, temp.length);
    }
}

矩阵乘法

并行矩阵乘法
public class MatrixMultiplyTask extends RecursiveTask<int[][]> {
    private static final int THRESHOLD = 64;
    private int[][] a, b;
    private int startRow, endRow, startCol, endCol;
    
    public MatrixMultiplyTask(int[][] a, int[][] b, 
                             int startRow, int endRow, 
                             int startCol, int endCol) {
        this.a = a;
        this.b = b;
        this.startRow = startRow;
        this.endRow = endRow;
        this.startCol = startCol;
        this.endCol = endCol;
    }
    
    @Override
    protected int[][] compute() {
        int rows = endRow - startRow;
        int cols = endCol - startCol;
        
        if (rows <= THRESHOLD && cols <= THRESHOLD) {
            return multiplyDirect();
        }
        
        // 分解任务
        if (rows >= cols) {
            int midRow = (startRow + endRow) / 2;
            MatrixMultiplyTask task1 = new MatrixMultiplyTask(
                a, b, startRow, midRow, startCol, endCol);
            MatrixMultiplyTask task2 = new MatrixMultiplyTask(
                a, b, midRow, endRow, startCol, endCol);
            
            task1.fork();
            int[][] result2 = task2.compute();
            int[][] result1 = task1.join();
            
            return combineResults(result1, result2, true);
        } else {
            int midCol = (startCol + endCol) / 2;
            MatrixMultiplyTask task1 = new MatrixMultiplyTask(
                a, b, startRow, endRow, startCol, midCol);
            MatrixMultiplyTask task2 = new MatrixMultiplyTask(
                a, b, startRow, endRow, midCol, endCol);
            
            task1.fork();
            int[][] result2 = task2.compute();
            int[][] result1 = task1.join();
            
            return combineResults(result1, result2, false);
        }
    }
    
    private int[][] multiplyDirect() {
        int rows = endRow - startRow;
        int cols = endCol - startCol;
        int[][] result = new int[rows][cols];
        
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                for (int k = 0; k < a[0].length; k++) {
                    result[i][j] += a[startRow + i][k] * b[k][startCol + j];
                }
            }
        }
        
        return result;
    }
    
    private int[][] combineResults(int[][] result1, int[][] result2, boolean byRow) {
        if (byRow) {
            int[][] combined = new int[result1.length + result2.length][result1[0].length];
            System.arraycopy(result1, 0, combined, 0, result1.length);
            System.arraycopy(result2, 0, combined, result1.length, result2.length);
            return combined;
        } else {
            int[][] combined = new int[result1.length][result1[0].length + result2[0].length];
            for (int i = 0; i < result1.length; i++) {
                System.arraycopy(result1[i], 0, combined[i], 0, result1[i].length);
                System.arraycopy(result2[i], 0, combined[i], result1[i].length, result2[i].length);
            }
            return combined;
        }
    }
}

性能优化和最佳实践

性能优化技巧

合理设置阈值
阈值太小会导致任务分解过细,增加调度开销;太大则无法充分利用并行性。
减少内存分配
尽量重用对象,避免在compute()方法中频繁创建新对象。
避免深度递归
控制递归深度,防止栈溢出,可以考虑混合使用递归和迭代。

最佳实践

实践建议
  • 使用公共池:优先使用ForkJoinPool.commonPool(),避免创建多个线程池
  • CPU密集型任务:Fork/Join适合CPU密集型任务,不适合I/O密集型任务
  • 任务粒度:确保任务足够大以抵消并行开销,但不要太大以至于无法并行
  • 避免阻塞:在compute()方法中避免阻塞操作,如I/O或同步
  • 异常处理:正确处理子任务中的异常,避免影响整个计算

性能对比示例

性能测试代码
public class PerformanceTest {
    public static void main(String[] args) {
        int[] array = new int[10_000_000];
        Random random = new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextInt(1000);
        }
        
        // 串行计算
        long startTime = System.currentTimeMillis();
        long serialSum = Arrays.stream(array).sum();
        long serialTime = System.currentTimeMillis() - startTime;
        
        // 并行流计算
        startTime = System.currentTimeMillis();
        long parallelStreamSum = Arrays.stream(array).parallel().sum();
        long parallelStreamTime = System.currentTimeMillis() - startTime;
        
        // Fork/Join计算
        ForkJoinPool pool = ForkJoinPool.commonPool();
        SumTask task = new SumTask(array, 0, array.length);
        startTime = System.currentTimeMillis();
        long forkJoinSum = pool.invoke(task);
        long forkJoinTime = System.currentTimeMillis() - startTime;
        
        System.out.println("串行计算: " + serialSum + ", 耗时: " + serialTime + "ms");
        System.out.println("并行流: " + parallelStreamSum + ", 耗时: " + parallelStreamTime + "ms");
        System.out.println("Fork/Join: " + forkJoinSum + ", 耗时: " + forkJoinTime + "ms");
        
        System.out.println("\n加速比:");
        System.out.println("并行流 vs 串行: " + (double)serialTime / parallelStreamTime);
        System.out.println("Fork/Join vs 串行: " + (double)serialTime / forkJoinTime);
    }
}
上一章:同步工具类 返回目录 下一章:ThreadLocal详解