Java集合之ConcurrentHashMap类源码剖析

ConcurrentHashMap类源码剖析

对于ConcurrentHashMap我们分为JDK7和JDK8两个版本分别讲解,因为这两个版本的差异非常大,先说结论:

  • Java7中ConcurrentHashMap使用的分段锁,也就是每一个Segment上同时只有一个线程可以操作,每一个Segment都是一个类似HashMap的散列数组的结构,它可以扩容,它的冲突会转化为链表。但是Segment的个数一旦初始化就不能改变。
  • Java8 中的ConcurrentHashMap使用的Synchronized锁加CAS的机制。结构也由Java7中的==Segment数组 + HashEntry数组 + 链表==进化成了==Node数组 + 链表/红黑树==,Node是类似于一个HashEntry的结构。它的冲突在达到一定大小时会转化成红黑树,在冲突小于一定数量时又退回链表。

1.JDK7中的ConcurrentHashMap

1.存储结构

图片

JDK7中ConcurrentHashMap的存储结构如上图,ConcurrnetHashMap由很多个Segment组合,而每一个Segment是一个类似于HashMap的结构,所以每一个HashMap的内部可以进行扩容。但是Segment的个数一旦初始化就不能改变,默认Segment的个数是16个,你也可以认为ConcurrentHashMap默认支持最多16个线程并发。

2.初始化流程

通过ConcurrentHashMap的无参构造探寻ConcurrentHashMap的初始化流程。

1
2
3
4
5
6
7
/**
* Creates a new, empty map with a default initial capacity (16),
* load factor (0.75) and concurrencyLevel (16).
*/
public ConcurrentHashMap() {
this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}

无参构造中调用了有参构造,传入了三个参数的默认值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/**
* The default initial capacity for this table,
* used when not otherwise specified in a constructor.
*/
static final int DEFAULT_INITIAL_CAPACITY = 16;

/**
* The default load factor for this table, used when not
* otherwise specified in a constructor.
*/
static final float DEFAULT_LOAD_FACTOR = 0.75f;

/**
* The default concurrency level for this table, used when not
* otherwise specified in a constructor.
*/
static final int DEFAULT_CONCURRENCY_LEVEL = 16;

接着看下这个有参构造函数的内部实现逻辑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@SuppressWarnings("unchecked")
public ConcurrentHashMap(int initialCapacity,float loadFactor, int concurrencyLevel) {
// 参数校验
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
// 校验并发级别大小,若大于 1<<16 则重置为 65536
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// Find power-of-two sizes best matching arguments
// 2的幂次,2^sshift^ == ssize
int sshift = 0;
// segment数组大小
int ssize = 1;
// 找到不小于concurrencyLevel的最小的2的幂次值
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
// 记录段偏移量
this.segmentShift = 32 - sshift;
// 记录段掩码
this.segmentMask = ssize - 1;
// 设置容量,不能超过1<<30
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
// c=容量/ssize,默认16/16=1,这里是计算每个Segment中的HashEntry数组的容量
int c = initialCapacity / ssize;
// 如果无法整除,则c向上取整
if (c * ssize < initialCapacity)
++c;
// Segment的表容量最小是2,且必须是2的幂次
int cap = MIN_SEGMENT_TABLE_CAPACITY;
while (cap < c)
cap <<= 1;
// 创建Segment数组,设置segments[0]
Segment<K,V> s0 = new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
this.segments = ss;
}

总结一下在JDK7中ConcurrentHashMap的初始化逻辑:

  1. 必要参数校验loadFactor<=0||initialCapacity<0||concurrencyLevel<=0
  2. 校验并发级别concurrencyLevel大小,如果大于最大值,重置为最大值。
  3. 寻找并发级别concurrencyLevel之上最近的2的幂次方值,作为初始化容量大小,默认是16。
  4. 记录segmentShift偏移量,这个值为【容量=2的N次方】中的N,在后面put时计算位置时会用到,默认是32-sshift=28。
  5. 记录segmentMask,默认是ssize-1=16-1=15。
  6. 初始化segments[0],默认大小为2,负载因子0.75,扩容阀值是2*0.75=1.5,插入第二个值时才会进行扩容。

重点:Segment段数组的长度是2的幂次,由【concurrencyLevel】参数控制,一开始只会初始化segments[0],并且其HashEntry数组的长度由【initialCapacity和concurrencyLevel】共同控制,最小为2。

3.put操作

接着上面的初始化参数继续查看put方法源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
public V put(K key, V value) {
Segment<K,V> s;
// ConcurrentHashMap不允许null值
if (value == null)
throw new NullPointerException();
int hash = hash(key);
// hash值无符号右移28位,然后与segmentMask=15做与运算
// 其实就是根据hash值二进制的高位取模得到Segment数组的索引位置
int j = (hash >>> segmentShift) & segmentMask;
// 检查对应位置的Segement是否已经完成初始化
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck in ensureSegment
(segments, (j << SSHIFT) + SBASE)) == null)
// 如果查找到的Segment为空,先确保初始化
s = ensureSegment(j);
// Segment对象执行put插入逻辑
return s.put(key, hash, value, false);
}

@SuppressWarnings("unchecked")
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
// 1.如果k下标的Sgement不为null说明已经有线程提前完成了Segment的初始化,可以直接退出并返回seg
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
// 使用初始化逻辑中预先设置的segment[0]作为后续Segment初始化原型,避免了Segment中字段的重新计算
Segment<K,V> proto = ss[0];
// 获取segment[0]里的HashEntry<K,V>数组初始化长度
int cap = proto.table.length;
// 获取segment[0]里的HashEntry<K,V>数组的负载因子
float lf = proto.loadFactor;
// 计算扩容阈值
int threshold = (int)(cap * lf);
// 创建一个cap容量的HashEntry数组
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
// 2.再次检查,如果k下标的Sgement不为null说明已经有线程提前完成了Segment的初始化,可以直接退出并返回seg
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // recheck
// 创建新的Segment对象
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
// 自旋配合CAS保证Segment的正确初始化
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
// 使用CAS赋值,只会成功一次
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}

上面的源码分析了ConcurrentHashMap在put一个数据时的处理流程,下面梳理下具体流程。

  1. 计算要put的key的位置,获取指定位置的Segment。
  2. 如果指定位置的Segment为空,则初始化这个Segment。
    1. 检查计算得到的位置的Segment是否为null。
    2. 为null则继续初始化,使用Segment[0]的容量和负载因子(原型)创建一个HashEntry数组。
    3. 再次检查计算得到的指定位置的Segment是否为 null。
    4. 使用创建的HashEntry数组初始化这个Segment。
    5. 自旋判断计算得到的指定位置的Segment是否为null,使用CAS在这个位置赋值为Segment。
  3. Segment.put插入key,value 值。

重点:插入元素时根据key的hash值的【高位】决定落在Segment数组的哪个位置上,并发初始化Segment段时需要【自旋检查+CAS更新】,保证只会初始化一次。

上面探究了获取Segment段和初始化Segment段的操作,最后一行的Segment的put方法还没有查看,继续分析。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
// 获取ReentrantLock独占锁,若获取不到,scanAndLockForPut一直等待加锁成功。
HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
// 计算HashEntry数组中要put的数据位置
int index = (tab.length - 1) & hash;
// CAS获取index下标的值
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
// 检查是否key已经存在,如果存在则遍历链表寻找位置,找到后替换value
if (e != null) {
K k;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
}
else {
// 遍历链表没有找到符合的Entry,node不为null表示加锁时提前创建好的待插入节点,执行链表的头插法
if (node != null)
node.setNext(first);
// 遍历链表没有找到符合的Entry,node为null表示需要手动创建待插入节点,执行链表的头插法
else
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
// 容量大于扩容阈值,小于最大容量,进行扩容
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node);
else
// index位置赋值node,node可能是一个元素,也可能是一个链表的表头
setEntryAt(tab, index, node);
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}

重点:实际插入键值对时需要针对Segment加锁,如果已存在key则更新value,否则执行头插法。

这里面的第一步中的scanAndLockForPut操作这里没有介绍,这个方法做的操作就是不断的自旋tryLock()获取锁。当自旋次数大于指定次数时,使用lock()阻塞获取锁。在自旋时顺便获取下hash位置的HashEntry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node
// 自旋获取锁
while (!tryLock()) {
HashEntry<K,V> f; // to recheck first below
if (retries < 0) {
// 自旋过程中顺便遍历链表
if (e == null) {
// 加锁过程中就会顺便提前创建HashEntry对象
if (node == null) // speculatively create node
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
else if (key.equals(e.key))
retries = 0;
else
e = e.next;
}
else if (++retries > MAX_SCAN_RETRIES) {
// 自旋达到指定次数后,阻塞等待加锁成功
lock();
break;
}
// 链表头节点改变,说明此时有新的节点通过头插法加入,重新遍历
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {
e = first = f; // re-traverse if entry changed
retries = -1;
}
}
return node;
}

4.get操作

到这里就很简单了,get方法只需要两步即可。

  1. 计算得到key所在的HashEntry在Segment数组的具体位置。
  2. 遍历HashEntry链表查找相同key的value值。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
// 计算得到 key 所在的 HashEntry 在 Segment 数组的具体位置
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
// 遍历 HashEntry 链表查找相同 key 的 value 值
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}

get操作的高效之处在于整个get过程不需要加锁,除非读到的值是空才会加锁重读。我们知道HashTable容器的get方法是需要加锁的,那么ConcurrentHashMap的get操作是如何做到不加锁的呢?原因是它的get方法里将要使用的共享变量都定义成volatile类型,如用于统计当前Segement大小的count字段和用于存储值的HashEntry的value。定义成volatile的变量,能够在线程之间保持可见性,能够被多线程同时读,并且保证不会读到过期的值,但是只能被单线程写,在get操作里只需要读不需要写共享变量count和value,所以可以不用加锁。之所以不会读到过期的值,是因为根据Java内存模型的happen before原则,对volatile字段的写入操作先于读操作,即使两个线程同时修改和获取volatile变量,get操作也能拿到最新的值,这是用volatile替换锁的经典应用场景。

5.size操作

如果要统计整个ConcurrentHashMap里元素的大小,就必须统计所有Segment里元素的大小后求和。Segment里的全局变量count是一个volatile变量,那么在多线程场景下,是不是直接把所有Segment的count相加就可以得到整个ConcurrentHashMap大小了呢?不是的,虽然相加时可以获取每个Segment的count的最新值,但是可能累加前使用的count发生了变化,那么统计结果就不准了。所以,最安全的做法是在统计size的时候把所有Segment的put、remove和clean方法全部锁住,但是这种做法显然非常低效

因为在累加count操作过程中,之前累加过的count发生变化的几率非常小,所以ConcurrentHashMap的做法是先尝试2次通过不锁住Segment的方式来统计各个Segment大小,如果统计的过程中,容器的count发生了变化,则再采用加锁的方式来统计所有Segment的大小

那么ConcurrentHashMap是如何判断在统计的时候容器是否发生了变化呢?使用modCount变量,在put、remove和clean方法里操作元素前都会将变量modCount进行加1,那么在统计size前后比较modCount是否发生变化,从而得知容器的大小是否发生变化。

image-20240905233053963