0%

Java多线程之CountDownLatch

这里就JUC包中的CountDownLatch类做相关介绍

abstract.jpeg

概述

JUC包中的CountDownLatch类是一个同步工具类,可实现线程间的通信。其典型方法如下所示

1
2
3
4
5
6
7
8
9
10
11
// 创建一个指定计数器值的CountDownLatch实例
public CountDownLatch(int count);

// 当前线程阻塞等待CountDownLatch实例的计数器值为0
public void await() throws InterruptedException;

// 支持超时的阻塞等待; 返回true: CountDownLatch实例的计数器值为0; 返回false: 超时
public boolean await(long timeout, TimeUnit unit);

// CountDownLatch实例的计数器值减1
public void countDown();

基本使用方法也很简单。首先创建一个指定计数器值的CountDownLatch实例,每当其他线程完成任务时就通过countDown方法将计数器值减1。这样当计数器的值为0时,之前由于调用await方法而被阻塞的线程就会结束等待,恢复执行

实践

CountDownLatch的典型应用场景,大体可分为两类:结束信号、开始信号

结束信号

主线程创建、启动N个异步任务,我们期望当这N个任务全部执行完毕结束后,主线程才可以继续往下执行。即将CountDownLatch作为任务的结束信号来使用。示例代码如下所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
public class CountDownLatchTest1 {

@Test
public void test1() throws InterruptedException {
ExecutorService threadPool = Executors.newFixedThreadPool(5);
CountDownLatch doneSignal = new CountDownLatch(3);

Arrays.asList("Task 1","Task 2","Task 3")
.stream()
.map( name -> new Task(name, doneSignal) )
.forEach( task -> threadPool.execute(task) );

// 阻塞等待, 直到计数器变为0。即所有任务均完成
doneSignal.await();
System.out.println("所有任务均完成");
}

@AllArgsConstructor
private static class Task implements Runnable{

private String taskName;

private CountDownLatch doneSignal;

@Override
public void run() {
System.out.println(taskName + " 开始");
// 模拟业务耗时
try{
Thread.sleep( RandomUtils.nextInt(5,9) * 1000 );
}catch (Exception e) {
System.out.println( "Happen Exception: " + e.getMessage());
}
System.out.println(taskName + " 完成");
// 当前任务完成, 则计数器减一
doneSignal.countDown();
}
}
}

测试结果如下所示,符合预期

figure 1.jpeg

开始信号

主线程创建N个异步任务,但这N个任务不能立即开始执行。而需要等待某个共同的前置任务(比如初始化任务)完成后,才允许这N个任务开始执行。即将CountDownLatch作为任务的开始信号来使用。示例代码如下所示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
public class CountDownLatchTest2 {

@Test
public void test1() throws InterruptedException {
ExecutorService threadPool = Executors.newFixedThreadPool(10);
CountDownLatch startSignal = new CountDownLatch(1);

Arrays.asList("Task 1","Task 2","Task 3")
.stream()
.map( name -> new Task(name, startSignal) )
.forEach( task -> threadPool.execute(task) );

// 执行初始化准备工作
System.out.println("初始化准备工作开始");
// 模拟业务耗时
try{
Thread.sleep( RandomUtils.nextInt(5,9) * 1000 );
}catch (Exception e) {
System.out.println( "Happen Exception: " + e.getMessage());
}
System.out.println("初始化准备工作结束");

// 初始化准备工作完成, 则计数器减一
startSignal.countDown();

// 主线程等待所有任务执行完毕
try{ Thread.sleep( 20*1000 ); } catch (Exception e) {}
System.out.println("Game Over");
}

@AllArgsConstructor
private static class Task implements Runnable{

private String taskName;

private CountDownLatch startSignal;

@Override
public void run() {
try{
// 阻塞等待, 直到计数器变为0。 即前置任务完成
startSignal.await();
}catch (InterruptedException e) {
System.out.println( "Happen Exception: " + e.getMessage());
}

System.out.println(taskName + " 开始");
// 模拟业务耗时
try{
Thread.sleep( RandomUtils.nextInt(5,9) * 1000 );
}catch (Exception e) {
System.out.println( "Happen Exception: " + e.getMessage());
}
System.out.println(taskName + " 完成");
}
}
}

测试结果如下所示,符合预期

figure 2.jpeg

基本原理

构造器

CountDownLatch类实现过程同样依赖于AQS。在构建CountDownLatch实例过程时,一方面,通过sync变量持有AQS的实现类Sync;另一方面,通过AQS的state字段来存储计数器值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public class CountDownLatch {
private final Sync sync;

public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

private static final class Sync extends AbstractQueuedSynchronizer {
Sync(int count) {
setState(count);
}
}
}

...

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
private volatile int state;

protected final void setState(int newState) {
state = newState;
}
}

await方法

首先来看CountDownLatch的await方法。其委托sync调用AQS的acquireSharedInterruptibly方法,从方法名也可以看到其是对AQS中共享锁的使用。并根据当前计数器的值是否为0,来判断该线程是继续执行还是应该被阻塞。可以看到事实上AQS只是定义了是否需要阻塞线程的tryAcquireShared方法,具体的规则需要CountDownLatch类来进行实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public class CountDownLatch {
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

private static final class Sync extends AbstractQueuedSynchronizer {
// 判断当前计数器值是否为0, 是则返回1; 否则返回-1
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
}
}

...

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
// 线程被中断则直接抛出异常
if (Thread.interrupted())
throw new InterruptedException();

if (tryAcquireShared(arg) < 0)
// 当前计数器不为0, 需进入AQS的队列准备阻塞
doAcquireSharedInterruptibly(arg);
}

// 需要子类去实现
protected int tryAcquireShared(int arg) {
throw new UnsupportedOperationException();
}
}

当tryAcquireShared方法结果小于0时,即当前计数器不为0时,AQS如何通过doAcquireSharedInterruptibly方法实现阻塞呢?结合相关源码可以看到,首先通过addWaiter方法将当前线程包装为一个node实例,并将其加入AQS队列。在入队过程中需要注意,如果队列为空则其并不是直接将该node实例加入队列。而是先构造一个哨兵节点来入队,然后在enq方法下一轮for循环才将该node实例加入队列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
// 将当前线程包装为node,加入AQS队列,并返回该node实例
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
// 获取 node 的前驱节点
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

private Node addWaiter(Node mode) {
// 将当前线程包装为一个node实例
Node node = new Node(Thread.currentThread(), mode);
Node pred = tail;
// 队列的尾指针不为空, 说明队列不为空, 则利用尾插法将node入队
if (pred != null) {
node.prev = pred;
if (compareAndSetTail(pred, node)) {
pred.next = node;
// 入队完毕, 直接返回该node
return node;
}
}
// 队列为空, 则先构建一个哨兵节点、入队,再将该node入队
enq(node);
return node;
}

private Node enq(final Node node) {
for (;;) {
Node t = tail;
// 队尾指针为空, 则先进行队列的初始化
if (t == null) {
// 构建一个哨兵节点并入队
if (compareAndSetHead(new Node()))
tail = head;
} else {
// 将node入队
node.prev = t;
if (compareAndSetTail(t, node)) {
t.next = node;
return t;
}
}
}
}
}

然后通过shouldParkAfterFailedAcquire方法修改前驱节点的waitStatus。如果前驱节点的waitStatus字段是初始值0的话,需在第一轮for循环中进入shouldParkAfterFailedAcquire方法时,通过compareAndSetWaitStatus(pred, ws, Node.SIGNAL)方法将前驱节点的waitStatus字段修改为Node.SIGNAL(即-1)。这样在开始下一轮for循环时,shouldParkAfterFailedAcquire方法即会返回true。进而执行parkAndCheckInterrupt方法,利用LockSupport.park完成线程阻塞

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
// 获取前驱节点的waitStatus字段值
int ws = pred.waitStatus;
if (ws == Node.SIGNAL)
return true;
if (ws > 0) {
do {
node.prev = pred = pred.prev;
} while (pred.waitStatus > 0);
pred.next = node;
} else {
compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
}
return false;
}

private final boolean parkAndCheckInterrupt() {
LockSupport.park(this);
return Thread.interrupted();
}

countDown方法

CountDownLatch的countDown方法类似。其同样是委托sync调用AQS的releaseShared方法。然后AQS执行tryReleaseShared方法,CountDownLatch类负责实现具体的规则逻辑。如果自减后当前计数器为0,则说明需要唤醒之前通过await方法而被阻塞的线程。然后通过AQS的doReleaseShared方法实现唤醒。具体地,其是从头节点的后继节点开始唤醒。因为前面已经说过,AQS队列的第一个节点(即头节点)只是一个哨兵节点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
public class CountDownLatch {
public void countDown() {
sync.releaseShared(1);
}

private static final class Sync extends AbstractQueuedSynchronizer {
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
}

...

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}

// 需要子类去实现
protected boolean tryReleaseShared(int arg) {
throw new UnsupportedOperationException();
}

private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
// 唤醒头节点的后继节点
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}

private void unparkSuccessor(Node node) {
int ws = node.waitStatus;
if (ws < 0)
compareAndSetWaitStatus(node, ws, 0);
Node s = node.next;
if (s == null || s.waitStatus > 0) {
s = null;
for (Node t = tail; t != null && t != node; t = t.prev)
if (t.waitStatus <= 0)
s = t;
}
if (s != null)
LockSupport.unpark(s.thread);
}
}

这里补充说明下,当上文由于调用await方法而被阻塞的线程唤醒后,其会在doAcquireSharedInterruptibly方法的for循环中恢复执行。此时由于tryAcquireShared方法的返回值r大于0满足条件,故其进入setHeadAndPropagate方法。在该方法中,其将自身重新设置为AQS的头节点。并通过doReleaseShared方法继续唤醒它的后继节点。从而实现将AQS队列被阻塞的线程全部唤醒

1
2
3
4
5
6
7
8
9
10
11
private void setHeadAndPropagate(Node node, int propagate) {
Node h = head; // Record old head for check below
setHead(node);

if (propagate > 0 || h == null || h.waitStatus < 0 ||
(h = head) == null || h.waitStatus < 0) {
Node s = node.next;
if (s == null || s.isShared())
doReleaseShared();
}
}

Note

CountDownLatch的计数器值只能在创建实例时进行设置,之后不可以对其进行重新设置。换言之,CountDownLatch是一次性的,当其使用完毕后将无法再次利用

参考文献

  1. Java并发编程之美 翟陆续、薛宾田著
请我喝杯咖啡捏~

欢迎关注我的微信公众号:青灯抽丝