这里就JUC包中的CountDownLatch类做相关介绍
概述 JUC包中的CountDownLatch类是一个同步工具类,可实现线程间的通信。其典型方法如下所示
1 2 3 4 5 6 7 8 9 10 11 public CountDownLatch (int count) ;public void await () throws InterruptedException;public boolean await (long timeout, TimeUnit unit) ;public void countDown () ;
基本使用方法也很简单。首先创建一个指定计数器值的CountDownLatch实例,每当其他线程完成任务时就通过countDown方法将计数器值减1。这样当计数器的值为0时,之前由于调用await方法而被阻塞的线程就会结束等待,恢复执行
实践 CountDownLatch的典型应用场景,大体可分为两类:结束信号、开始信号
结束信号 主线程创建、启动N个异步任务,我们期望当这N个任务全部执行完毕结束后,主线程才可以继续往下执行。即将CountDownLatch作为任务的结束信号来使用。示例代码如下所示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 public class CountDownLatchTest1 { @Test public void test1 () throws InterruptedException { ExecutorService threadPool = Executors.newFixedThreadPool(5 ); CountDownLatch doneSignal = new CountDownLatch (3 ); Arrays.asList("Task 1" ,"Task 2" ,"Task 3" ) .stream() .map( name -> new Task (name, doneSignal) ) .forEach( task -> threadPool.execute(task) ); doneSignal.await(); System.out.println("所有任务均完成" ); } @AllArgsConstructor private static class Task implements Runnable { private String taskName; private CountDownLatch doneSignal; @Override public void run () { System.out.println(taskName + " 开始" ); try { Thread.sleep( RandomUtils.nextInt(5 ,9 ) * 1000 ); }catch (Exception e) { System.out.println( "Happen Exception: " + e.getMessage()); } System.out.println(taskName + " 完成" ); doneSignal.countDown(); } } }
测试结果如下所示,符合预期
开始信号 主线程创建N个异步任务,但这N个任务不能立即开始执行。而需要等待某个共同的前置任务(比如初始化任务)完成后,才允许这N个任务开始执行。即将CountDownLatch作为任务的开始信号来使用。示例代码如下所示
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 public class CountDownLatchTest2 { @Test public void test1 () throws InterruptedException { ExecutorService threadPool = Executors.newFixedThreadPool(10 ); CountDownLatch startSignal = new CountDownLatch (1 ); Arrays.asList("Task 1" ,"Task 2" ,"Task 3" ) .stream() .map( name -> new Task (name, startSignal) ) .forEach( task -> threadPool.execute(task) ); System.out.println("初始化准备工作开始" ); try { Thread.sleep( RandomUtils.nextInt(5 ,9 ) * 1000 ); }catch (Exception e) { System.out.println( "Happen Exception: " + e.getMessage()); } System.out.println("初始化准备工作结束" ); startSignal.countDown(); try { Thread.sleep( 20 *1000 ); } catch (Exception e) {} System.out.println("Game Over" ); } @AllArgsConstructor private static class Task implements Runnable { private String taskName; private CountDownLatch startSignal; @Override public void run () { try { startSignal.await(); }catch (InterruptedException e) { System.out.println( "Happen Exception: " + e.getMessage()); } System.out.println(taskName + " 开始" ); try { Thread.sleep( RandomUtils.nextInt(5 ,9 ) * 1000 ); }catch (Exception e) { System.out.println( "Happen Exception: " + e.getMessage()); } System.out.println(taskName + " 完成" ); } } }
测试结果如下所示,符合预期
基本原理 构造器 CountDownLatch类实现过程同样依赖于AQS。在构建CountDownLatch实例过程时,一方面,通过sync变量持有AQS的实现类Sync;另一方面,通过AQS的state字段来存储计数器值
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 public class CountDownLatch { private final Sync sync; public CountDownLatch (int count) { if (count < 0 ) throw new IllegalArgumentException ("count < 0" ); this .sync = new Sync (count); } private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } } } ... public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java .io.Serializable { private volatile int state; protected final void setState (int newState) { state = newState; } }
await方法 首先来看CountDownLatch的await方法。其委托sync调用AQS的acquireSharedInterruptibly方法,从方法名也可以看到其是对AQS中共享锁的使用。并根据当前计数器的值是否为0,来判断该线程是继续执行还是应该被阻塞。可以看到事实上AQS只是定义了是否需要阻塞线程的tryAcquireShared方法,具体的规则需要CountDownLatch类来进行实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 public class CountDownLatch { public void await () throws InterruptedException { sync.acquireSharedInterruptibly(1 ); } private static final class Sync extends AbstractQueuedSynchronizer { protected int tryAcquireShared (int acquires) { return (getState() == 0 ) ? 1 : -1 ; } } } ... public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java .io.Serializable { public final void acquireSharedInterruptibly (int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException (); if (tryAcquireShared(arg) < 0 ) doAcquireSharedInterruptibly(arg); } protected int tryAcquireShared (int arg) { throw new UnsupportedOperationException (); } }
当tryAcquireShared方法结果小于0时,即当前计数器不为0时,AQS如何通过doAcquireSharedInterruptibly方法实现阻塞呢?结合相关源码可以看到,首先通过addWaiter方法将当前线程包装为一个node实例,并将其加入AQS队列。在入队过程中需要注意,如果队列为空则其并不是直接将该node实例加入队列。而是先构造一个哨兵节点来入队,然后在enq方法下一轮for循环才将该node实例加入队列
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java .io.Serializable { private void doAcquireSharedInterruptibly (int arg) throws InterruptedException { final Node node = addWaiter(Node.SHARED); boolean failed = true ; try { for (;;) { final Node p = node.predecessor(); if (p == head) { int r = tryAcquireShared(arg); if (r >= 0 ) { setHeadAndPropagate(node, r); p.next = null ; failed = false ; return ; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException (); } } finally { if (failed) cancelAcquire(node); } } private Node addWaiter (Node mode) { Node node = new Node (Thread.currentThread(), mode); Node pred = tail; if (pred != null ) { node.prev = pred; if (compareAndSetTail(pred, node)) { pred.next = node; return node; } } enq(node); return node; } private Node enq (final Node node) { for (;;) { Node t = tail; if (t == null ) { if (compareAndSetHead(new Node ())) tail = head; } else { node.prev = t; if (compareAndSetTail(t, node)) { t.next = node; return t; } } } } }
然后通过shouldParkAfterFailedAcquire方法修改前驱节点的waitStatus。如果前驱节点的waitStatus字段是初始值0的话,需在第一轮for循环中进入shouldParkAfterFailedAcquire方法时,通过compareAndSetWaitStatus(pred, ws, Node.SIGNAL)方法将前驱节点的waitStatus字段修改为Node.SIGNAL(即-1)。这样在开始下一轮for循环时,shouldParkAfterFailedAcquire方法即会返回true。进而执行parkAndCheckInterrupt方法,利用LockSupport.park完成线程阻塞
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 private static boolean shouldParkAfterFailedAcquire (Node pred, Node node) { int ws = pred.waitStatus; if (ws == Node.SIGNAL) return true ; if (ws > 0 ) { do { node.prev = pred = pred.prev; } while (pred.waitStatus > 0 ); pred.next = node; } else { compareAndSetWaitStatus(pred, ws, Node.SIGNAL); } return false ; } private final boolean parkAndCheckInterrupt () { LockSupport.park(this ); return Thread.interrupted(); }
countDown方法 CountDownLatch的countDown方法类似。其同样是委托sync调用AQS的releaseShared方法。然后AQS执行tryReleaseShared方法,CountDownLatch类负责实现具体的规则逻辑。如果自减后当前计数器为0,则说明需要唤醒之前通过await方法而被阻塞的线程。然后通过AQS的doReleaseShared方法实现唤醒。具体地,其是从头节点的后继节点开始唤醒。因为前面已经说过,AQS队列的第一个节点(即头节点)只是一个哨兵节点
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 public class CountDownLatch { public void countDown () { sync.releaseShared(1 ); } private static final class Sync extends AbstractQueuedSynchronizer { protected boolean tryReleaseShared (int releases) { for (;;) { int c = getState(); if (c == 0 ) return false ; int nextc = c-1 ; if (compareAndSetState(c, nextc)) return nextc == 0 ; } } } } ... public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java .io.Serializable { public final boolean releaseShared (int arg) { if (tryReleaseShared(arg)) { doReleaseShared(); return true ; } return false ; } protected boolean tryReleaseShared (int arg) { throw new UnsupportedOperationException (); } private void doReleaseShared () { for (;;) { Node h = head; if (h != null && h != tail) { int ws = h.waitStatus; if (ws == Node.SIGNAL) { if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0 )) continue ; unparkSuccessor(h); } else if (ws == 0 && !compareAndSetWaitStatus(h, 0 , Node.PROPAGATE)) continue ; } if (h == head) break ; } } private void unparkSuccessor (Node node) { int ws = node.waitStatus; if (ws < 0 ) compareAndSetWaitStatus(node, ws, 0 ); Node s = node.next; if (s == null || s.waitStatus > 0 ) { s = null ; for (Node t = tail; t != null && t != node; t = t.prev) if (t.waitStatus <= 0 ) s = t; } if (s != null ) LockSupport.unpark(s.thread); } }
这里补充说明下,当上文由于调用await方法而被阻塞的线程唤醒后,其会在doAcquireSharedInterruptibly方法的for循环中恢复执行。此时由于tryAcquireShared方法的返回值r大于0满足条件,故其进入setHeadAndPropagate方法。在该方法中,其将自身重新设置为AQS的头节点。并通过doReleaseShared方法继续唤醒它的后继节点。从而实现将AQS队列被阻塞的线程全部唤醒
1 2 3 4 5 6 7 8 9 10 11 private void setHeadAndPropagate (Node node, int propagate) { Node h = head; setHead(node); if (propagate > 0 || h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0 ) { Node s = node.next; if (s == null || s.isShared()) doReleaseShared(); } }
Note CountDownLatch的计数器值只能在创建实例时进行设置,之后不可以对其进行重新设置。换言之,CountDownLatch是一次性的,当其使用完毕后将无法再次利用
参考文献
Java并发编程之美 翟陆续、薛宾田著