Introduction
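"Divide and conquer" is an important method for organizing one's thinking and solving problems. From splitting a system architecture into functional modules down to implementing merge sort, the idea shows up everywhere. Divide-and-conquer algorithms are usually implemented recursively: a large task is split into several smaller subtasks, the parent task waits for them to finish, and then it merges their results. In general, a parent task depends on the results of its subtasks, while the subtasks do not depend on one another, so the subtasks can run concurrently to improve performance. ForkJoinPool therefore provides a framework for running divide-and-conquer work concurrently, letting us gain concurrent execution while programming in a style close to ordinary recursion.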
Usage
Divide-and-conquer code typically takes the following form:
Result solve(Problem problem) {
    if (problem is small) {
        directly solve problem
    } else {
        split problem into independent parts
        fork new subtasks to solve each part
        join all subtasks
        compose result from subresults
    }
}
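As a concrete instance of the template above, here is a minimal sketch that sums an array with java.util.concurrent.RecursiveTask. The class name SumTask, the helper sum, and the THRESHOLD value are assumptions made for illustration and are not part of the original article; a sensible cutoff depends on the workload.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical example: sums a slice of a long[] with the fork/join pattern.
class SumTask extends RecursiveTask<Long> {
    // Assumed cutoff below which the slice is summed sequentially.
    static final int THRESHOLD = 1_000;

    private final long[] data;
    private final int from; // inclusive
    private final int to;   // exclusive

    SumTask(long[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        if (to - from <= THRESHOLD) {
            // The problem is small: solve it directly.
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += data[i];
            }
            return sum;
        }
        // Split the problem into two independent halves.
        int mid = (from + to) >>> 1;
        SumTask left = new SumTask(data, from, mid);
        SumTask right = new SumTask(data, mid, to);
        left.fork();                     // schedule the left half asynchronously
        long rightSum = right.compute(); // solve the right half in this thread
        long leftSum = left.join();      // wait for the left half
        return leftSum + rightSum;       // compose the result
    }

    // Convenience entry point that runs the task on the common pool.
    static long sum(long[] data) {
        return ForkJoinPool.commonPool().invoke(new SumTask(data, 0, data.length));
    }
}

class SumTaskDemo {
    public static void main(String[] args) {
        long[] data = new long[10_000];
        for (int i = 0; i < data.length; i++) {
            data[i] = i + 1;
        }
        System.out.println(SumTask.sum(data)); // 1 + 2 + ... + 10000 = 50005000
    }
}
```

Forking one half and computing the other in the current thread keeps the calling worker busy instead of letting it sit idle in join.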
Computing Fibonacci numbers:
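import java.util.concurrent.RecursiveTask;

class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) { this.n = n; }

    // compute() is protected in RecursiveTask, so the override must not narrow its access.
    @Override
    protected Integer compute() {
        if (n <= 1)
            return n;
        Fibonacci f1 = new Fibonacci(n - 1);
        f1.fork();                       // schedule fib(n - 1) asynchronously
        Fibonacci f2 = new Fibonacci(n - 2);
        return f2.compute() + f1.join(); // compute fib(n - 2) here, then wait for f1
    }
}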
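The task above still has to be handed to a pool, and forking all the way down to n <= 1 creates very many tiny tasks. The following sketch shows one possible way to run such a task and to stop forking below a cutoff. The names CutoffFibonacci, sequentialFib, and fib, as well as the cutoff value 10, are assumptions chosen for illustration, not something the original article prescribes.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical variant of the Fibonacci task with a sequential cutoff.
class CutoffFibonacci extends RecursiveTask<Integer> {
    // Assumed cutoff; below it, forking likely costs more than it saves.
    static final int SEQUENTIAL_CUTOFF = 10;

    final int n;

    CutoffFibonacci(int n) { this.n = n; }

    // Plain recursion used for small inputs.
    static int sequentialFib(int n) {
        return n <= 1 ? n : sequentialFib(n - 1) + sequentialFib(n - 2);
    }

    @Override
    protected Integer compute() {
        if (n <= SEQUENTIAL_CUTOFF) {
            return sequentialFib(n);
        }
        CutoffFibonacci f1 = new CutoffFibonacci(n - 1);
        f1.fork();
        CutoffFibonacci f2 = new CutoffFibonacci(n - 2);
        return f2.compute() + f1.join();
    }

    // Runs the task on a dedicated pool and shuts the pool down afterwards.
    static int fib(int n) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new CutoffFibonacci(n));
        } finally {
            pool.shutdown();
        }
    }
}

class CutoffFibonacciDemo {
    public static void main(String[] args) {
        System.out.println(CutoffFibonacci.fib(20)); // fib(20) = 6765
    }
}
```

ForkJoinPool.invoke blocks the caller until the root task completes and returns its result, which matches the "join all subtasks" step of the template.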
How It Works
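The core of ForkJoinPool is its lightweight scheduling mechanism, which adopts the basic work-stealing scheduling strategy of Cilk:
- Each worker thread maintains its own task queue.
- The task queue is a double-ended queue (deque): besides the last-in, first-out (LIFO) push and pop operations, it also supports a first-in, first-out (FIFO) take operation.
- Subtasks forked by a parent task are pushed onto the task queue of the worker thread that is running the parent task.
- A worker thread processes its own queue in LIFO order by popping tasks, so the youngest task runs first.
- When its own queue is empty, a worker thread tries to steal a task from another, randomly chosen, task queue.
- When a worker thread has no task to run and fails to steal one, it "backs off" (yield, sleep, priority adjustment) and tries again after a while. The exception is when all other threads are idle as well; in that case they all block until new tasks are added from the upper level.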
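The queue discipline in the list above can be illustrated with a small single-threaded sketch. It only shows which end each operation works on. The class name WorkStealingOrderDemo and the use of String as a stand-in for tasks are assumptions for illustration, and a real pool needs a thread-safe deque.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical single-threaded illustration of the deque discipline.
class WorkStealingOrderDemo {
    // Owner side: push a newly forked task onto the head of the deque.
    static void push(Deque<String> q, String task) {
        q.addFirst(task);
    }

    // Owner side: pop the most recently pushed task (LIFO).
    static String pop(Deque<String> q) {
        return q.pollFirst();
    }

    // Thief side: take the oldest task from the opposite end (FIFO).
    static String take(Deque<String> q) {
        return q.pollLast();
    }

    public static void main(String[] args) {
        Deque<String> queue = new ArrayDeque<>();
        push(queue, "A");
        push(queue, "B");
        push(queue, "C");
        System.out.println(pop(queue));  // C: the owner gets the youngest task
        System.out.println(take(queue)); // A: a thief gets the oldest task
        System.out.println(pop(queue));  // B
        System.out.println(take(queue)); // null: the queue is empty
    }
}
```

Because the owner and the thieves work on opposite ends, they rarely contend for the same element. In divide-and-conquer workloads the oldest task also tends to be the largest remaining subproblem, so a single steal usually moves a sizeable chunk of work.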
A simple implementation:
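import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class NaiveForkJoinPool {
    // Queues for tasks submitted from outside the pool.
    private final TaskQueue[] submissionQueues;
    // One queue per worker, holding the tasks that worker forks.
    private final TaskQueue[] workerQueues;
    private final WorkerThread[] workers;
    // Number of workers that are not parked waiting for work.
    private final AtomicInteger aliveCount;
    private final ReentrantLock lock = new ReentrantLock();
    // Signalled when a task is added while some worker is parked.
    private final Condition taskEmpty = lock.newCondition();
    private final int parallelism;

    public NaiveForkJoinPool(int parallelism) {
        this.parallelism = parallelism;
        submissionQueues = new TaskQueue[parallelism];
        workerQueues = new TaskQueue[parallelism];
        workers = new WorkerThread[parallelism];
        aliveCount = new AtomicInteger(parallelism);
        for (int i = 0; i < parallelism; i++) {
            submissionQueues[i] = new TaskQueue();
            workerQueues[i] = new TaskQueue();
            workers[i] = new WorkerThread(this, workerQueues[i]);
        }
        // Start the workers only after every queue has been created.
        for (int i = 0; i < parallelism; i++) {
            workers[i].start();
        }
    }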
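
    public <T> T invoke(Task<T> task) {
        // Pick a random submission queue; nextInt(bound) works for any queue count,
        // whereas masking with (length - 1) is only uniform for powers of two.
        TaskQueue sd = submissionQueues[
                ThreadLocalRandom.current().nextInt(submissionQueues.length)];
        sd.push(task);
        tryCompensate();
        return task.join();
    }

    public <T> List<T> invokeAll(Task<T>... tasks) {
        // Submit every task first so that the tasks can run concurrently ...
        for (Task<T> task : tasks) {
            TaskQueue sd = submissionQueues[
                    ThreadLocalRandom.current().nextInt(submissionQueues.length)];
            sd.push(task);
            tryCompensate();
        }
        // ... and only then wait for the results, in submission order.
        List<T> res = new LinkedList<>();
        for (Task<T> task : tasks) {
            res.add(task.join());
        }
        return res;
    }

    // Wakes one parked worker, if any, after a task has been queued.
    void tryCompensate() {
        // Always check under the lock. An unlocked pre-check can race with a worker
        // that has scanned the queues but not yet decremented aliveCount, and the
        // wake-up would then be lost.
        lock.lock();
        try {
            if (aliveCount.get() < parallelism) {
                taskEmpty.signal();
            }
        } finally {
            lock.unlock();
        }
    }

    // Main loop of a worker: run a task if one can be found, otherwise park.
    void runWorker() {
        int len = submissionQueues.length;
        int startIndex = ThreadLocalRandom.current().nextInt(len);
        for (Task<?> task = null; ; ) {
            if (task != null || (task = scan(startIndex)) != null) {
                task.runTask();
                task = null;
            } else {
                task = awaitForWork(startIndex);
            }
        }
    }

    // Looks for work in the submission queues first, then in the worker queues.
    Task<?> scan(int startIndex) {
        Task<?> task;
        if ((task = scan(startIndex, submissionQueues)) != null) {
            return task;
        }
        if ((task = scan(startIndex, workerQueues)) != null) {
            return task;
        }
        return null;
    }

    // Visits every queue once, starting at startIndex, and takes from the FIFO end.
    Task<?> scan(int startIndex, TaskQueue[] queues) {
        for (int i = 0, len = queues.length; i < len; i++) {
            TaskQueue td = queues[(startIndex + i) % len];
            Task<?> task = td.take();
            if (task != null) {
                return task;
            }
        }
        return null;
    }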
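
    // Parks the calling worker until tryCompensate() signals that work was added.
    Task<?> awaitForWork(int startIndex) {
        lock.lock();
        try {
            // Scan once more while holding the lock so that a task queued
            // just before parking is not missed.
            Task<?> task = scan(startIndex);
            if (task != null) {
                return task;
            }
            aliveCount.decrementAndGet();
            try {
                taskEmpty.await(); // releases the lock while waiting
            } catch (InterruptedException e) {
                // This naive pool has no shutdown protocol: log and keep running.
                e.printStackTrace();
            }
            aliveCount.incrementAndGet();
            // Return null so that the caller rescans the queues.
            return null;
        } finally {
            lock.unlock();
        }
    }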

    class WorkerThread extends Thread {
        final NaiveForkJoinPool pool;
        final TaskQueue workQueue;

        public WorkerThread(NaiveForkJoinPool pool, TaskQueue workQueue) {
            this.pool = pool;
            this.workQueue = workQueue;
            // Workers loop forever; daemon threads let the JVM exit anyway.
            setDaemon(true);
        }

        @Override
        public void run() {
            runWorker();
        }
    }
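
    static abstract class Task<T> {
        static final int NORMAL = 1;
        final AtomicInteger status = new AtomicInteger();
        // Released once compute() has finished and the result is stored.
        final CountDownLatch isDone = new CountDownLatch(1);
        private T result;

        public abstract T compute();

        public void runTask() {
            result = compute();
            status.set(NORMAL);
            isDone.countDown();
        }

        // Must be called from a worker thread: queues this task on the caller's own deque.
        public Task<T> fork() {
            WorkerThread t = (WorkerThread) Thread.currentThread();
            t.workQueue.push(this);
            t.pool.tryCompensate();
            return this;
        }

        public T join() {
            Thread currentThread = Thread.currentThread();
            if (currentThread instanceof WorkerThread) {
                // Instead of blocking right away, help by running tasks from our
                // own queue (LIFO) until this task has been run or the queue is empty.
                TaskQueue wk = ((WorkerThread) currentThread).workQueue;
                for (Task<?> task = wk.pop(); task != null; task = wk.pop()) {
                    task.runTask();
                    if (task == this) {
                        return result;
                    }
                }
            }
            // Either we are an external thread or the task was stolen: wait for it.
            waitForComplete();
            return result;
        }

        void waitForComplete() {
            // Keep waiting across interrupts so that join() never returns an
            // unfinished result; restore the interrupt flag afterwards.
            boolean interrupted = false;
            while (true) {
                try {
                    isDone.await();
                    break;
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }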
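
    static class TaskQueue {
        private final Deque<Task<?>> deque = new ConcurrentLinkedDeque<>();

        // Owner side: add a task at the head.
        public void push(Task<?> task) {
            deque.push(task);
        }

        // Owner side: remove the youngest task from the head (LIFO).
        public Task<?> pop() {
            return deque.pollFirst();
        }

        // Thief side: remove the oldest task from the tail (FIFO).
        public Task<?> take() {
            return deque.pollLast();
        }
    }
}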
Original article by ItWorker. If you repost it, please credit the source: https://blog.ytso.com/tech/pnotes/60460.html