ConcurrentHashMap类源码剖析 对于ConcurrentHashMap我们分为JDK7和JDK8两个版本分别讲解,因为这两个版本的差异非常大,先说结论:
Java7中ConcurrentHashMap
使用的分段锁,也就是每一个Segment上同时只有一个线程可以操作,每一个Segment
都是一个类似HashMap
的散列数组的结构,它可以扩容,它的冲突会转化为链表。但是Segment
的个数一旦初始化就不能改变。
Java8 中的ConcurrentHashMap
使用的Synchronized
锁加CAS
的机制。结构也由Java7中的==Segment数组 + HashEntry数组 + 链表==进化成了==Node数组 + 链表/红黑树==,Node是类似于一个HashEntry的结构。它的冲突在达到一定大小时会转化成红黑树,在冲突小于一定数量时又退回链表。
1.JDK7中的ConcurrentHashMap 1.存储结构
JDK7中ConcurrentHashMap的存储结构如上图,ConcurrnetHashMap由很多个Segment组合,而每一个Segment是一个类似于HashMap的结构,所以每一个HashMap的内部可以进行扩容。但是Segment的个数一旦初始化就不能改变
,默认Segment的个数是16个,你也可以认为ConcurrentHashMap默认支持最多16个线程并发。
2.初始化流程 通过ConcurrentHashMap的无参构造探寻ConcurrentHashMap的初始化流程。
1 2 3 4 5 6 7 public ConcurrentHashMap () { this (DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL); }
无参构造中调用了有参构造,传入了三个参数的默认值:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 static final int DEFAULT_INITIAL_CAPACITY = 16 ;static final float DEFAULT_LOAD_FACTOR = 0.75f ;static final int DEFAULT_CONCURRENCY_LEVEL = 16 ;
接着看下这个有参构造函数的内部实现逻辑:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 @SuppressWarnings("unchecked") public ConcurrentHashMap (int initialCapacity,float loadFactor, int concurrencyLevel) { if (!(loadFactor > 0 ) || initialCapacity < 0 || concurrencyLevel <= 0 ) throw new IllegalArgumentException (); if (concurrencyLevel > MAX_SEGMENTS) concurrencyLevel = MAX_SEGMENTS; int sshift = 0 ; int ssize = 1 ; while (ssize < concurrencyLevel) { ++sshift; ssize <<= 1 ; } this .segmentShift = 32 - sshift; this .segmentMask = ssize - 1 ; if (initialCapacity > MAXIMUM_CAPACITY) initialCapacity = MAXIMUM_CAPACITY; int c = initialCapacity / ssize; if (c * ssize < initialCapacity) ++c; int cap = MIN_SEGMENT_TABLE_CAPACITY; while (cap < c) cap <<= 1 ; Segment<K,V> s0 = new Segment <K,V>(loadFactor, (int )(cap * loadFactor), (HashEntry<K,V>[])new HashEntry [cap]); Segment<K,V>[] ss = (Segment<K,V>[])new Segment [ssize]; UNSAFE.putOrderedObject(ss, SBASE, s0); this .segments = ss; }
总结一下在JDK7中ConcurrentHashMap的初始化逻辑:
必要参数校验loadFactor<=0||initialCapacity<0||concurrencyLevel<=0
。
校验并发级别concurrencyLevel
大小,如果大于最大值,重置为最大值。
寻找并发级别concurrencyLevel
之上最近的2的幂次方值,作为初始化容量大小,默认是16。
记录segmentShift
偏移量,这个值为【容量=2的N次方】中的N,在后面put时计算位置时会用到,默认是32-sshift=28。
记录segmentMask
,默认是ssize-1=16-1=15。
初始化segments[0]
,默认大小为2,负载因子0.75,扩容阀值是2*0.75=1.5,插入第二个值时才会进行扩容。
重点:Segment段数组的长度是2的幂次,由【concurrencyLevel】参数控制,一开始只会初始化segments[0],并且其HashEntry数组的长度由【initialCapacity和concurrencyLevel】共同控制,最小为2。
3.put操作 接着上面的初始化参数继续查看put方法源码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 public V put (K key, V value) { Segment<K,V> s; if (value == null ) throw new NullPointerException (); int hash = hash(key); int j = (hash >>> segmentShift) & segmentMask; if ((s = (Segment<K,V>)UNSAFE.getObject (segments, (j << SSHIFT) + SBASE)) == null ) s = ensureSegment(j); return s.put(key, hash, value, false ); } @SuppressWarnings("unchecked") private Segment<K,V> ensureSegment (int k) { final Segment<K,V>[] ss = this .segments; long u = (k << SSHIFT) + SBASE; Segment<K,V> seg; if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null ) { Segment<K,V> proto = ss[0 ]; int cap = proto.table.length; float lf = proto.loadFactor; int threshold = (int )(cap * lf); HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry [cap]; if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null ) { Segment<K,V> s = new Segment <K,V>(lf, threshold, tab); while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null ) { if (UNSAFE.compareAndSwapObject(ss, u, null , seg = s)) break ; } } } return seg; }
上面的源码分析了ConcurrentHashMap在put一个数据时的处理流程,下面梳理下具体流程。
计算要put的key的位置,获取指定位置的Segment。
如果指定位置的Segment为空,则初始化这个Segment。
检查计算得到的位置的Segment是否为null。
为null则继续初始化,使用Segment[0]的容量和负载因子(原型)创建一个HashEntry数组。
再次检查计算得到的指定位置的Segment是否为 null。
使用创建的HashEntry数组初始化这个Segment。
自旋判断计算得到的指定位置的Segment是否为null,使用CAS在这个位置赋值为Segment。
Segment.put插入key,value 值。
重点:插入元素时根据key的hash值的【高位】决定落在Segment数组的哪个位置上,并发初始化Segment段时需要【自旋检查+CAS更新】,保证只会初始化一次。
上面探究了获取Segment段和初始化Segment段的操作,最后一行的Segment的put方法还没有查看,继续分析。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 final V put (K key, int hash, V value, boolean onlyIfAbsent) { HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value); V oldValue; try { HashEntry<K,V>[] tab = table; int index = (tab.length - 1 ) & hash; HashEntry<K,V> first = entryAt(tab, index); for (HashEntry<K,V> e = first;;) { if (e != null ) { K k; if ((k = e.key) == key || (e.hash == hash && key.equals(k))) { oldValue = e.value; if (!onlyIfAbsent) { e.value = value; ++modCount; } break ; } e = e.next; } else { if (node != null ) node.setNext(first); else node = new HashEntry <K,V>(hash, key, value, first); int c = count + 1 ; if (c > threshold && tab.length < MAXIMUM_CAPACITY) rehash(node); else setEntryAt(tab, index, node); ++modCount; count = c; oldValue = null ; break ; } } } finally { unlock(); } return oldValue; }
重点:实际插入键值对时需要针对Segment加锁,如果已存在key则更新value,否则执行头插法。
这里面的第一步中的scanAndLockForPut
操作这里没有介绍,这个方法做的操作就是不断的自旋tryLock()
获取锁。当自旋次数大于指定次数时,使用lock()
阻塞获取锁。在自旋时顺便获取下hash位置的HashEntry
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 private HashEntry<K,V> scanAndLockForPut (K key, int hash, V value) { HashEntry<K,V> first = entryForHash(this , hash); HashEntry<K,V> e = first; HashEntry<K,V> node = null ; int retries = -1 ; while (!tryLock()) { HashEntry<K,V> f; if (retries < 0 ) { if (e == null ) { if (node == null ) node = new HashEntry <K,V>(hash, key, value, null ); retries = 0 ; } else if (key.equals(e.key)) retries = 0 ; else e = e.next; } else if (++retries > MAX_SCAN_RETRIES) { lock(); break ; } else if ((retries & 1 ) == 0 && (f = entryForHash(this , hash)) != first) { e = first = f; retries = -1 ; } } return node; }
4.get操作 到这里就很简单了,get方法只需要两步即可。
计算得到key所在的HashEntry在Segment数组的具体位置。
遍历HashEntry链表查找相同key的value值。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 public V get (Object key) { Segment<K,V> s; HashEntry<K,V>[] tab; int h = hash(key); long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE; if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null ) { for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile (tab, ((long )(((tab.length - 1 ) & h)) << TSHIFT) + TBASE); e != null ; e = e.next) { K k; if ((k = e.key) == key || (e.hash == h && key.equals(k))) return e.value; } } return null ; }
get操作的高效之处在于整个get过程不需要加锁,除非读到的值是空才会加锁重读。我们知道HashTable容器的get方法是需要加锁的,那么ConcurrentHashMap的get操作是如何做到不加锁的呢?原因是它的get方法里将要使用的共享变量都定义成volatile类型,如用于统计当前Segement大小的count字段和用于存储值的HashEntry的value
。定义成volatile的变量,能够在线程之间保持可见性,能够被多线程同时读,并且保证不会读到过期的值,但是只能被单线程写,在get操作里只需要读不需要写共享变量count和value,所以可以不用加锁。之所以不会读到过期的值,是因为根据Java内存模型的happen before原则,对volatile字段的写入操作先于读操作,即使两个线程同时修改和获取volatile变量,get操作也能拿到最新的值,这是用volatile替换锁的经典应用场景。
5.size操作 如果要统计整个ConcurrentHashMap里元素的大小,就必须统计所有Segment里元素的大小后求和。Segment里的全局变量count是一个volatile变量
,那么在多线程场景下,是不是直接把所有Segment的count相加就可以得到整个ConcurrentHashMap大小了呢?不是的,虽然相加时可以获取每个Segment的count的最新值,但是可能累加前使用的count发生了变化,那么统计结果就不准了。所以,最安全的做法是在统计size的时候把所有Segment的put、remove和clean方法全部锁住,但是这种做法显然非常低效
。
因为在累加count操作过程中,之前累加过的count发生变化的几率非常小,所以ConcurrentHashMap的做法是先尝试2次通过不锁住Segment的方式来统计各个Segment大小,如果统计的过程中,容器的count发生了变化,则再采用加锁的方式来统计所有Segment的大小
。
那么ConcurrentHashMap是如何判断在统计的时候容器是否发生了变化呢?使用modCount
变量,在put、remove和clean方法里操作元素前都会将变量modCount进行加1,那么在统计size前后比较modCount是否发生变化,从而得知容器的大小是否发生变化。