CountDownLatch类源码剖析 一、CountDownLatch类的简介 CountDownLatch是一种并发包下的同步辅助工具
,允许一个或多个线程等待其他线程正在执行的一组操作完成,CountDownLatch以给定的计数初始化。await方法会阻塞,直到当前计数因调用countDown方法而归零,之后所有等待的线程都会被释放,并且await的任何后续调用都会立即返回。这是一种一次性现象即计数无法重置
,如果需要重置计数的版本,可以考虑使用CyclicBarrier。CountDownLatch是一种多功能同步工具,可用于多种用途。
初始化计数为1
的CountDownLatch可用作简单的触发器
:所有调用await的线程都在触发器上等待,直到调用countDown的线程打开触发器。
初始化计数为N
的CountDownLatch可用作简单的屏障器
,让一个或多个线程等待,直到N个线程完成某个操作,或某个操作完成N次。
下面是一个使用示例:这里有一对类Driver
和Worker
,其中一组工作线程使用两个倒计时countDownLatch: 第一个倒计时countDownLatch是一个启动信号,用于阻止任何工作线程继续运行,直到驱动线程准备好让它们继续运行; 第二个倒计时countDownLatch是一个完成信号,允许驱动线程等待所有工作线程完成后才继续执行。
另一种典型的用法是将问题分成N个部分,每个部分用一个Runnable来描述,Runnable执行该部分并在countDownLatch上倒计时,然后将所有Runnable排成队列交给一个Executor。当所有子部分都完成后,协调线程就可以从await方法返回。(当线程必须以这种方式反复倒计时时,请使用CyclicBarrier)。
内存一致性语义的说明:在计数达到零之前,线程中调用countDown之前的操作happen-before另一个线程中相应的await成功返回之后的操作。
二、CountDownLatch类的结构 我们把CountDownLatch类中的代码和注释暂且折叠起来,看到其实整个类中的主要方法除了构造函数就是三个方法:await
、await(timeout,unit)
、countDown
,我们下面简单对这三个方法的语义做个介绍。
当我们理解了CountDownLatch类的应用场景后,首先我们先来想一想自己如何实现一套CountDownLatch,有哪些合适的思路呢?当然,考虑JUC并发开发下的基础框架AQS,我们第一想法恐怕就是先初始化AQS的state
值为任务线程数N,表明当前共享模式
下的AQS锁有N个线程在占用,主线程只能等待这些任务线程全部执行结束锁完全释放后,才能【加锁】成功继续执行,否则需要进入同步队列
等待。后面任务线程释放锁时需要查看锁是否完全释放
了,如果是的话则表示所有任务线程运行结束,可以唤醒等待中的主线程了,这就是唤醒时机
。注意,这里主线程【加锁】其实并不是真的去加锁state增加,毕竟任务线程都执行完了,这里的【加锁】语义是对应于CountDownLatch类的应用场景而说的。真实情况也跟我们想的基本相符,下面我们看一下几个主要的方法吧:
await:主线程调用await等待计数为0即所有任务线程都执行结束才会返回,如果发现当前计数大于零需要进入同步队列等待,直到计数变为0或者主线程等待期间被中断抛出InterruptedException。
await(timeout,unit)跟上面的await类似,不过等待一定时间还没有等到计数为0后会超时退出返回false,如果等到了计数变为0则会返回true继续执行。
countDown:递减latch计数,如果计数达到零,则唤醒所有等待线程即主线程。
说完了CountDownLatch的几个主要方法后,我们把关注点再次转移到AQS中,因为所有功能的实现最终还是要依托于我们的强大AQS框架的,跟ReentrantLock一样,本节的CountDownLatch类内部也有一个静态内部类Sync
继承了AQS框架,其重写的方法也就是针对共享模式的tryAcquireShared
和tryReleaseShared
两个钩子函数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 private static final class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 4982264981922014374L ; Sync(int count) { setState(count); } int getCount () { return getState(); } protected int tryAcquireShared (int acquires) { return (getState() == 0 ) ? 1 : -1 ; } protected boolean tryReleaseShared (int releases) { for (;;) { int c = getState(); if (c == 0 ) return false ; int nextc = c-1 ; if (compareAndSetState(c, nextc)) return nextc == 0 ; } } }
三、CountDownLatch类的方法 借助B站寒食君的视频图片,看一下CountDownLatch的调用时序:
1.构造方法 介绍了Sync
类后我们终于可以进入CountDownLatch类的方法一探究竟了,首先我们看一下构造方法:
1 2 3 4 public CountDownLatch (int count) { if (count < 0 ) throw new IllegalArgumentException ("count < 0" ); this .sync = new Sync (count); }
构造方法内主要就是初始化CountDownLatch类的属性sync
,用给定参数count
设置初始计数值。
2.await方法 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 public void await () throws InterruptedException { sync.acquireSharedInterruptibly(1 ); } public final void acquireSharedInterruptibly (int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException (); if (tryAcquireShared(arg) < 0 ) doAcquireSharedInterruptibly(arg); } private void doAcquireSharedInterruptibly (int arg) throws InterruptedException { final Node node = addWaiter(Node.SHARED); boolean failed = true ; try { for (;;) { final Node p = node.predecessor(); if (p == head) { int r = tryAcquireShared(arg); if (r >= 0 ) { setHeadAndPropagate(node, r); p.next = null ; failed = false ; return ; } } if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) throw new InterruptedException (); } } finally { if (failed) cancelAcquire(node); } }
我们重点看一下setHeadAndPropagate
方法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 private void setHeadAndPropagate (Node node, int propagate) { Node h = head; setHead(node); if (propagate > 0 || h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0 ) { Node s = node.next; if (s == null || s.isShared()) doReleaseShared(); } }
3.await超时版本 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 public boolean await (long timeout, TimeUnit unit) throws InterruptedException { return sync.tryAcquireSharedNanos(1 , unit.toNanos(timeout)); } public final boolean tryAcquireSharedNanos (int arg, long nanosTimeout) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException (); return tryAcquireShared(arg) >= 0 || doAcquireSharedNanos(arg, nanosTimeout); } private boolean doAcquireSharedNanos (int arg, long nanosTimeout) throws InterruptedException { if (nanosTimeout <= 0L ) return false ; final long deadline = System.nanoTime() + nanosTimeout; final Node node = addWaiter(Node.SHARED); boolean failed = true ; try { for (;;) { final Node p = node.predecessor(); if (p == head) { int r = tryAcquireShared(arg); if (r >= 0 ) { setHeadAndPropagate(node, r); p.next = null ; failed = false ; return true ; } } nanosTimeout = deadline - System.nanoTime(); if (nanosTimeout <= 0L ) return false ; if (shouldParkAfterFailedAcquire(p, node) && nanosTimeout > spinForTimeoutThreshold) LockSupport.parkNanos(this , nanosTimeout); if (Thread.interrupted()) throw new InterruptedException (); } } finally { if (failed) cancelAcquire(node); } }
可以看出,超时版本的await方法与一开始说的await方法区别也不大,仅仅多了一个超时时间的限制。
3.countDown方法 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 public void countDown () { sync.releaseShared(1 ); } public final boolean releaseShared (int arg) { if (tryReleaseShared(arg)) { doReleaseShared(); return true ; } return false ; } protected boolean tryReleaseShared (int releases) { for (;;) { int c = getState(); if (c == 0 ) return false ; int nextc = c-1 ; if (compareAndSetState(c, nextc)) return nextc == 0 ; } }