Exploring ForkJoinPool

Introduction

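"Divide and conquer" is an important way to organize your thinking and solve problems. From how a system architecture is split into functional modules down to how merge sort is implemented, the idea of divide and conquer is everywhere. When we implement a divide-and-conquer algorithm, we usually use recursion: a large task is split into several smaller tasks, the large task waits for those subtasks to finish, and then it merges their results. In general, a parent task depends on the results of its subtasks, while the subtasks do not depend on one another, so the subtasks can run concurrently to improve performance. ForkJoinPool provides a framework for exactly this kind of concurrent divide-and-conquer processing, letting us obtain concurrent execution while writing code in a recursion-like style.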

Usage

Divide-and-conquer code typically takes the following form:

Result solve(Problem problem) {
    if (problem is small) {
        directly solve problem
    } else {
        split problem into independent parts
        fork new subtasks to solve each part
        join all subtasks
        compose result from subresults
    }
}
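As a concrete instance of this template, the sketch below sums a long[] with java.util.concurrent.RecursiveTask. The class name SumTask and the cutoff THRESHOLD = 1_000 are assumptions chosen for illustration; the cutoff plays the role of "problem is small".

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class SumTask extends RecursiveTask<Long> {
    // Hypothetical cutoff: ranges this small are summed directly
    private static final int THRESHOLD = 1_000;
    private final long[] array;
    private final int from, to; // half-open range [from, to)

    SumTask(long[] array, int from, int to) {
        this.array = array;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        if (to - from <= THRESHOLD) {
            // Small problem: solve directly
            long sum = 0;
            for (int i = from; i < to; i++) sum += array[i];
            return sum;
        }
        // Split into two independent halves
        int mid = (from + to) >>> 1;
        SumTask left = new SumTask(array, from, mid);
        SumTask right = new SumTask(array, mid, to);
        left.fork();                     // fork a subtask for the left half
        long rightSum = right.compute(); // solve the right half in this thread
        return left.join() + rightSum;   // join and compose the sub-results
    }

    public static long sum(long[] array) {
        return ForkJoinPool.commonPool().invoke(new SumTask(array, 0, array.length));
    }

    public static void main(String[] args) {
        long[] data = new long[100_000];
        for (int i = 0; i < data.length; i++) data[i] = i + 1;
        System.out.println(sum(data)); // 5000050000
    }
}
```

Forking one half and computing the other in the current thread keeps that thread busy instead of leaving it idle while it waits in join().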

Computing Fibonacci numbers:

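import java.util.concurrent.RecursiveTask;

class Fibonacci extends RecursiveTask<Integer> {
   final int n;
   Fibonacci(int n) { this.n = n; }

   @Override
   protected Integer compute() {
     if (n <= 1)
       return n;                         // small problem: solve directly
     Fibonacci f1 = new Fibonacci(n - 1);
     f1.fork();                          // push f1 onto this worker's queue
     Fibonacci f2 = new Fibonacci(n - 2);
     return f2.compute() + f1.join();    // compute f2 here, then wait for f1
   }
}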
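To run the task, the root task is handed to a pool. A minimal driver, assuming java.util.concurrent's common pool is acceptable (the wrapper class FibonacciDemo and the method fib are illustrative names, not part of the original):

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class FibonacciDemo {
    // Same task as above, repeated so this block compiles on its own
    static class Fibonacci extends RecursiveTask<Integer> {
        final int n;
        Fibonacci(int n) { this.n = n; }

        @Override
        protected Integer compute() {
            if (n <= 1) return n;
            Fibonacci f1 = new Fibonacci(n - 1);
            f1.fork();
            Fibonacci f2 = new Fibonacci(n - 2);
            return f2.compute() + f1.join();
        }
    }

    public static int fib(int n) {
        // invoke() submits the root task and blocks until its result is ready
        return ForkJoinPool.commonPool().invoke(new Fibonacci(n));
    }

    public static void main(String[] args) {
        System.out.println(fib(20)); // 6765
    }
}
```

Because every call down to n <= 1 becomes a separate task, this version creates a very large number of tiny tasks; a practical version would compute small n sequentially instead of forking.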

How it works

The core of ForkJoinPool is its lightweight scheduling mechanism, which adopts the basic work-stealing scheduling strategy of Cilk:

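  • Each worker thread maintains its own task queue.
  • The task queue is a double-ended queue (deque): it supports LIFO push and pop operations, and also a FIFO take operation.
  • Subtasks forked by a parent task are pushed onto the queue of the worker thread that is running the parent task.
  • A worker thread pops tasks from its own queue in LIFO order (the youngest task is processed first).
  • When its own queue is empty, a worker thread tries to steal a task from another, randomly chosen queue.
  • When a worker thread has no task to run and fails to steal one, it backs off (yield, sleep, priority adjustment) and tries again later. The exception is when all other threads are idle too: in that case they all block until a new task is submitted from outside the pool.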
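The deque discipline in the list above (the owner pushes and pops at one end, thieves take from the other) can be illustrated with java.util.concurrent.ConcurrentLinkedDeque. This is a single-threaded sketch of the ordering only, not ForkJoinPool's real lock-free array-based deque, and the class WorkDequeDemo is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;

public class WorkDequeDemo {
    // The owner works at the head; thieves work at the tail
    private final ConcurrentLinkedDeque<String> deque = new ConcurrentLinkedDeque<>();

    void push(String task) { deque.addFirst(task); }     // owner: fork a subtask
    String pop()           { return deque.pollFirst(); } // owner: LIFO, youngest first
    String take()          { return deque.pollLast(); }  // thief: FIFO, oldest first

    public static List<String> demo() {
        WorkDequeDemo q = new WorkDequeDemo();
        q.push("t1"); // forked first (oldest)
        q.push("t2");
        q.push("t3"); // forked last (youngest)
        List<String> order = new ArrayList<>();
        order.add("owner:" + q.pop());  // t3
        order.add("thief:" + q.take()); // t1
        order.add("owner:" + q.pop());  // t2
        return order;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [owner:t3, thief:t1, owner:t2]
    }
}
```

In divide-and-conquer code the oldest task is usually the largest remaining piece of work, so a single steal moves a lot of work, and the owner and the thief operate at opposite ends of the deque, which reduces contention.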

A simple implementation that follows these rules:

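import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class NaiveForkJoinPool {
    // Queues for tasks submitted from outside the pool via invoke/invokeAll
    private final TaskQueue[] submissionQueues;
    // One deque per worker; fork() pushes onto the current worker's deque
    private final TaskQueue[] workerQueues;
    private final WorkerThread[] workers;
    // Number of workers that are not parked in awaitForWork()
    private final AtomicInteger aliveCount;
    private final ReentrantLock lock = new ReentrantLock();
    private final Condition taskEmpty = lock.newCondition();
    private final int parallelism;

    public NaiveForkJoinPool(int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        this.parallelism = parallelism;
        submissionQueues = new TaskQueue[parallelism];
        workerQueues = new TaskQueue[parallelism];
        workers = new WorkerThread[parallelism];
        aliveCount = new AtomicInteger(parallelism);

        for (int i = 0; i < parallelism; i++) {
            submissionQueues[i] = new TaskQueue();
            workerQueues[i] = new TaskQueue();
            workers[i] = new WorkerThread(this, workerQueues[i]);
            // Workers never exit, so make them daemons: an idle pool
            // must not keep the JVM alive
            workers[i].setDaemon(true);
        }

        for (int i = 0; i < parallelism; i++) {
            workers[i].start();
        }
    }

    public <T> T invoke(Task<T> task) {
        submit(task);
        return task.join();
    }

    // Submit every task first so they can run in parallel, then join in order
    @SafeVarargs
    public final <T> List<T> invokeAll(Task<T>... tasks) {
        for (Task<T> task : tasks) {
            submit(task);
        }
        List<T> res = new ArrayList<>(tasks.length);
        for (Task<T> task : tasks) {
            res.add(task.join());
        }
        return res;
    }

    // Push onto a random submission queue (any length, not just powers
    // of two), then wake a parked worker if there is one
    private void submit(Task<?> task) {
        int index = ThreadLocalRandom.current().nextInt(submissionQueues.length);
        submissionQueues[index].push(task);
        tryCompensate();
    }

    // Must be called after the task has been pushed
    void tryCompensate() {
        if (aliveCount.get() < parallelism) {
            lock.lock();
            try {
                if (aliveCount.get() < parallelism) {
                    taskEmpty.signal();
                }
            } finally {
                lock.unlock();
            }
        }
    }

    void runWorker() {
        int startIndex = ThreadLocalRandom.current().nextInt(parallelism);
        for (Task<?> task = null; ; ) {
            if (task != null || (task = scan(startIndex)) != null) {
                task.runTask();
                task = null;
            } else {
                task = awaitForWork(startIndex);
            }
        }
    }

    // Check the submission queues first, then try to steal from workers
    Task<?> scan(int startIndex) {
        Task<?> task = scan(startIndex, submissionQueues);
        return task != null ? task : scan(startIndex, workerQueues);
    }

    Task<?> scan(int startIndex, TaskQueue[] queues) {
        int len = queues.length;
        for (int i = 0; i < len; i++) {
            // take() removes from the tail, i.e. the oldest task in that queue
            Task<?> task = queues[(startIndex + i) % len].take();
            if (task != null) {
                return task;
            }
        }
        return null;
    }

    Task<?> awaitForWork(int startIndex) {
        lock.lock();
        try {
            // Announce idleness BEFORE the final scan: a concurrent push is
            // then either found by this scan, or its tryCompensate() sees
            // aliveCount < parallelism and signals (no lost wakeup)
            aliveCount.decrementAndGet();
            Task<?> task = scan(startIndex);
            if (task == null) {
                try {
                    taskEmpty.await();
                } catch (InterruptedException e) {
                    // Workers are never interrupted here; treat as a
                    // spurious wakeup and rescan
                }
            }
            aliveCount.incrementAndGet();
            return task;
        } finally {
            lock.unlock();
        }
    }

    class WorkerThread extends Thread {
        final NaiveForkJoinPool pool;
        final TaskQueue workQueue;

        WorkerThread(NaiveForkJoinPool pool, TaskQueue workQueue) {
            this.pool = pool;
            this.workQueue = workQueue;
        }

        @Override
        public void run() {
            pool.runWorker();
        }
    }

    static abstract class Task<T> {
        static final int NORMAL = 1;
        static final int EXCEPTIONAL = 2;
        final AtomicInteger status = new AtomicInteger();
        final CountDownLatch isDone = new CountDownLatch(1);
        private T result;
        private Throwable exception;

        public abstract T compute();

        // Never lets an exception escape: the worker thread survives and
        // join() is always released; the exception is rethrown from join()
        void runTask() {
            try {
                result = compute();
                status.set(NORMAL);
            } catch (Throwable ex) {
                exception = ex;
                status.set(EXCEPTIONAL);
            } finally {
                isDone.countDown();
            }
        }

        // Must be called from a worker thread: push onto its own deque
        public Task<T> fork() {
            WorkerThread t = (WorkerThread) Thread.currentThread();
            t.workQueue.push(this);
            t.pool.tryCompensate();
            return this;
        }

        public T join() {
            Thread currentThread = Thread.currentThread();
            if (status.get() == 0 && currentThread instanceof WorkerThread) {
                // Pop and run tasks from our own deque (LIFO) until this task
                // turns up; if it never does, another worker stole it
                TaskQueue wk = ((WorkerThread) currentThread).workQueue;
                for (Task<?> task = wk.pop(); task != null; task = wk.pop()) {
                    task.runTask();
                    if (task == this) {
                        break;
                    }
                }
            }
            waitForComplete();
            if (exception != null) {
                throw new RuntimeException(exception);
            }
            return result;
        }

        void waitForComplete() {
            boolean interrupted = false;
            while (true) {
                try {
                    isDone.await();
                    break;
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt(); // restore the flag
            }
        }
    }

    // The owner pushes/pops at the head (LIFO); thieves take from the tail (FIFO)
    static class TaskQueue {
        private final Deque<Task<?>> deque = new ConcurrentLinkedDeque<>();

        public void push(Task<?> task) {
            deque.push(task);
        }

        public Task<?> pop() {
            return deque.pollFirst();
        }

        public Task<?> take() {
            return deque.pollLast();
        }
    }
}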


Original article by ItWorker. If you repost it, please credit the source: https://blog.ytso.com/tech/pnotes/60460.html
