CountDownLatch类源码剖析

CountDownLatch类源码剖析

一、CountDownLatch类的简介

CountDownLatch是一种并发包下的同步辅助工具,允许一个或多个线程等待其他线程正在执行的一组操作完成,CountDownLatch以给定的计数初始化。await方法会阻塞,直到当前计数因调用countDown方法而归零,之后所有等待的线程都会被释放,并且await的任何后续调用都会立即返回。这是一种一次性现象即计数无法重置,如果需要重置计数的版本,可以考虑使用CyclicBarrier。CountDownLatch是一种多功能同步工具,可用于多种用途。

  1. 初始化计数为1的CountDownLatch可用作简单的触发器:所有调用await的线程都在触发器上等待,直到调用countDown的线程打开触发器。
  2. 初始化计数为N的CountDownLatch可用作简单的屏障器,让一个或多个线程等待,直到N个线程完成某个操作,或某个操作完成N次。

下面是一个使用示例:这里有一对类DriverWorker,其中一组工作线程使用两个倒计时countDownLatch:
第一个倒计时countDownLatch是一个启动信号,用于阻止任何工作线程继续运行,直到驱动线程准备好让它们继续运行;
第二个倒计时countDownLatch是一个完成信号,允许驱动线程等待所有工作线程完成后才继续执行。

image-20231118123123455

另一种典型的用法是将问题分成N个部分,每个部分用一个Runnable来描述,Runnable执行该部分并在countDownLatch上倒计时,然后将所有Runnable排成队列交给一个Executor。当所有子部分都完成后,协调线程就可以从await方法返回。(当线程必须以这种方式反复倒计时时,请使用CyclicBarrier)。

image-20231118161900918

内存一致性语义的说明:在计数达到零之前,线程中调用countDown之前的操作happen-before另一个线程中相应的await成功返回之后的操作。

二、CountDownLatch类的结构

我们把CountDownLatch类中的代码和注释暂且折叠起来,看到其实整个类中的主要方法除了构造函数就是三个方法:awaitawait(timeout,unit)countDown,我们下面简单对这三个方法的语义做个介绍。

image-20231118163152155

当我们理解了CountDownLatch类的应用场景后,首先我们先来想一想自己如何实现一套CountDownLatch,有哪些合适的思路呢?当然,考虑JUC并发开发下的基础框架AQS,我们第一想法恐怕就是先初始化AQS的state值为任务线程数N,表明当前共享模式下的AQS锁有N个线程在占用,主线程只能等待这些任务线程全部执行结束锁完全释放后,才能【加锁】成功继续执行,否则需要进入同步队列等待。后面任务线程释放锁时需要查看锁是否完全释放了,如果是的话则表示所有任务线程运行结束,可以唤醒等待中的主线程了,这就是唤醒时机。注意,这里主线程【加锁】其实并不是真的去加锁state增加,毕竟任务线程都执行完了,这里的【加锁】语义是对应于CountDownLatch类的应用场景而说的。真实情况也跟我们想的基本相符,下面我们看一下几个主要的方法吧:

  • await:主线程调用await等待计数为0即所有任务线程都执行结束才会返回,如果发现当前计数大于零需要进入同步队列等待,直到计数变为0或者主线程等待期间被中断抛出InterruptedException。
  • await(timeout,unit)跟上面的await类似,不过等待一定时间还没有等到计数为0后会超时退出返回false,如果等到了计数变为0则会返回true继续执行。
  • countDown:递减latch计数,如果计数达到零,则唤醒所有等待线程即主线程。

说完了CountDownLatch的几个主要方法后,我们把关注点再次转移到AQS中,因为所有功能的实现最终还是要依托于我们的强大AQS框架的,跟ReentrantLock一样,本节的CountDownLatch类内部也有一个静态内部类Sync继承了AQS框架,其重写的方法也就是针对共享模式的tryAcquireSharedtryReleaseShared两个钩子函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

Sync(int count) {
// 设置state值表示需要执行的任务线程数
setState(count);
}

int getCount() {
// 返回state值
return getState();
}

// AQS对于tryAcquireShared方法语义规定:如果返回负数表示尝试加锁失败;如果返回0表示加锁成功但不会唤醒后面的所有等待线程;如果返回正数表示加锁成功并且会唤醒后面的所有等待线程
protected int tryAcquireShared(int acquires) {
// 直到计数变为0才能加锁成功,否则需要一直等待全部任务执行完毕计数变为0
return (getState() == 0) ? 1 : -1;
}

protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 自旋保证锁一定可以释放成功
for (;;) {
int c = getState();
// 计数为0时释放锁没有意义,返回false
if (c == 0)
return false;
int nextc = c-1;
// 完全释放锁后返回true,表明可以唤醒同步队列的等待线程了
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}

三、CountDownLatch类的方法

借助B站寒食君的视频图片,看一下CountDownLatch的调用时序:

image-20231118175503467

1.构造方法

介绍了Sync类后我们终于可以进入CountDownLatch类的方法一探究竟了,首先我们看一下构造方法:

1
2
3
4
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

构造方法内主要就是初始化CountDownLatch类的属性sync,用给定参数count设置初始计数值。

2.await方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

// AQS的acquireSharedInterruptibly流程方法
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 进入方法时发现当前线程中断后抛出InterruptedException
if (Thread.interrupted())
throw new InterruptedException();
// 尝试获取共享锁,在所有任务线程调用countDown使计数归零前都是返回-1
if (tryAcquireShared(arg) < 0)
// 尝试获取共享锁,
doAcquireSharedInterruptibly(arg);
}


// AQS中的doAcquireSharedInterruptibly可中断阻塞等待加锁方法
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 当前等待线程入同步队列
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
// 第一个有效节点可以尝试获取锁
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
// 成功获取到锁,这里返回的r就是1作为传播值,表明唤醒后继等待线程
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// 主要逻辑就是阻塞等待唤醒,可以参考之前ReentrantLock源码剖析的讲解
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
// 这里针对阻塞等待过程中如果发生了中断则抛出InterruptedException
throw new InterruptedException();
}
} finally {
if (failed)
// 发生异常,取消正在获取锁的线程节点
cancelAcquire(node);
}
}

我们重点看一下setHeadAndPropagate方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
private void setHeadAndPropagate(Node node, int propagate) {
Node h = head; // Record old head for check below
setHead(node);
/*
* Try to signal next queued node if:
* Propagation was indicated by caller,
* or was recorded (as h.waitStatus either before
* or after setHead) by a previous operation
* (note: this uses sign-check of waitStatus because
* PROPAGATE status may transition to SIGNAL.)
* and
* The next node is waiting in shared mode,
* or we don't know, because it appears null
*
* The conservatism in both of these checks may cause
* unnecessary wake-ups, but only when there are multiple
* racing acquires/releases, so most need signals now or soon
* anyway.
*/
// CountDownLatch中这里的传播值大于零,因此会唤醒后面的所有共享模式的Node节点(其实CountDownLatch就是共享模式AQS的应用,所以这里等价于唤醒所有等待线程)
if (propagate > 0 || h == null || h.waitStatus < 0 ||
(h = head) == null || h.waitStatus < 0) {
Node s = node.next;
if (s == null || s.isShared())
// 唤醒后继节点
doReleaseShared();
}
}

3.await超时版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
// 调用AQS的尝试获取锁的超时版本方法tryAcquireSharedNanos
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
// 进入方法时检查当前线程是否被中断,如果是的话抛出InterruptedException
if (Thread.interrupted())
throw new InterruptedException();
// 尝试获取锁,如果没有获取到就调用超时版本的doAcquireSharedNanos方法
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}

private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
// 等待时间小于等于零,直接返回false表示没有拿到锁
if (nanosTimeout <= 0L)
return false;
// 计算超时时刻
final long deadline = System.nanoTime() + nanosTimeout;
// 下面的代码逻辑与前面的await方法中的一样
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return true;
}
}
// 计算剩余等待时间
nanosTimeout = deadline - System.nanoTime();
// 超时时刻已到,返回false表示没有拿到锁
if (nanosTimeout <= 0L)
return false;
// 调用LockSupport的超时阻塞方法,注意如果剩余阻塞等待时间还不到spinForTimeoutThreshold,那么选择自旋可能程序效率更高,因为线程状态切换也需要一定开销(操作系统调度线程从用户态到内核态需要一定时钟周期)
if (shouldParkAfterFailedAcquire(p, node) &&
nanosTimeout > spinForTimeoutThreshold)
LockSupport.parkNanos(this, nanosTimeout);
if (Thread.interrupted())
throw new InterruptedException();
}
} finally {
if (failed)
// 发生异常,取消正在获取锁的线程节点
cancelAcquire(node);
}
}

可以看出,超时版本的await方法与一开始说的await方法区别也不大,仅仅多了一个超时时间的限制。

3.countDown方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
public void countDown() {
sync.releaseShared(1);
}

// AQS的releaseShared方法
public final boolean releaseShared(int arg) {
// 任务全都执行结束计数归零后tryReleaseShared返回true,可以唤醒等待的主线程了
if (tryReleaseShared(arg)) {
// 唤醒第一个有效线程,第一个有效线程会唤醒第二个有效线程...
doReleaseShared();
return true;
}
return false;
}

protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 自旋保证锁一定可以释放成功
for (;;) {
int c = getState();
// 计数为0时释放锁没有意义,返回false
if (c == 0)
return false;
int nextc = c-1;
// 完全释放锁后返回true,表明可以唤醒同步队列的等待线程了
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}