ConcurrentHashMap源码JDK7

为了提高并发度,一个HashMap被拆分成多个子HashMap,每个HashMap被称为Segment,多个线程操作多个Segment相互独立。每个Segment都继承自ReentrantLockSegment的数量等于锁的数量,这些锁彼此之间相互独立,即所谓的分段锁

JDK7中ConcurrentHashMap数据结构示意图

segmentShiftsegmentMask是为了方便计算Segment数组下标。segmentShift默认是28,segmentMask默认是15

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
public class ConcurrentHashMap<K, V> extends AbstractMap<K, V> implements ConcurrentMap<K, V>, Serializable {
static final int DEFAULT_INITIAL_CAPACITY = 16;
static final float DEFAULT_LOAD_FACTOR = 0.75f;
static final int DEFAULT_CONCURRENCY_LEVEL = 16; // Segment数组的默认大小,即默认并发级别
static final int MAXIMUM_CAPACITY = 1 << 30;
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
static final int MAX_SEGMENTS = 1 << 16;
static final int RETRIES_BEFORE_LOCK = 2;
final int segmentMask; // 用于方便计算Segment数组下标
final int segmentShift; // 用于方便计算Segment数组下标
final Segment<K,V>[] segments;
transient Set<K> keySet;
transient Set<Map.Entry<K,V>> entrySet;
transient Collection<V> values;
static final class Segment<K,V> extends ReentrantLock implements Serializable {
static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
transient volatile HashEntry<K,V>[] table;
transient int count;
transient int modCount;
transient int threshold;
final float loadFactor;
}
}

构造函数,concurrencyLevel表示并发度,也就是Segment数组的大小默认大小16一旦设定之后就不能再扩容了,且为了提升hash的计算性能,会保证数组初始大小始终是2的整数次方,若concurrencyLevel=9,则在构造函数中会找到比9大且最接近9的2的整数次方,也就是ssize=16,对应的segmentShiftsegmentMask也是为了方便计算hash使用的。

同样loadFactor为负载因子,传给了Segment内部,当每个Segment的元素个数达到一定阈值时进行rehash,虽然Segment个数不能扩容,但每个Segment内部可以扩容

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public ConcurrentHashMap(int initialCapacity) {
this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
public ConcurrentHashMap() {
this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
int sshift = 0;
int ssize = 1;
while (ssize < concurrencyLevel) { // 保证并发度是2的整数次方
++sshift;
ssize <<= 1;
}
this.segmentShift = 32 - sshift; // 默认算出来为28
this.segmentMask = ssize - 1; // 默认算出来为15
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize; // 除数容量或数组个数,是每个Segment的初始大小,默认为1
if (c * ssize < initialCapacity) // 该循环的作用其实是向上取整,c是每个HashEntry的数组长度
++c;
int cap = MIN_SEGMENT_TABLE_CAPACITY; // 默认为2,HashEntry的数组最小容量为2
while (cap < c) // c可能不是一个2的整数次幂的数,这里的作用其实就是获取大于等于c的最小的2的整数次幂的数
cap <<= 1;
// 构造第0个Segment
Segment<K,V> s0 = new Segment<K,V>(loadFactor, (int)(cap * loadFactor), (HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; // 数组大小为ssize即2的整数次方
UNSAFE.putOrderedObject(ss, SBASE, s0); // 数组的第0个元素赋值为s0
this.segments = ss;
}

hash值是一个32位的整数,segmentShift默认大小为28segmentMask默认为15,则(hash >>> segmentShift) & segmentMask的意思是hash值向右移28再和15进行与操作,即hash值的最高4位作为对应Segment数组下标。该处没有加锁,锁是加在s.put内部,也就是分段加锁。从put方法可以看出ConcurrentHashMap不允许空值和空健的,这也是和HashMap的另一个区别。

1
2
3
4
5
6
7
8
9
10
11
public V put(K key, V value) {
Segment<K,V> s;
if (value == null) // 不允许空值
throw new NullPointerException();
int hash = hash(key); // 把key映射到一个32位的整数
int j = (hash >>> segmentShift) & segmentMask; // segmentShift默认大小为28, segmentMask默认为15,segment数组下标
// 获取segment数组中第j个元素,若该元素为null,则对第j个Segment进行初始化
if ((s = (Segment<K,V>)UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)) == null)
s = ensureSegment(j); // 对第j个Segment进行初始化
return s.put(key, hash, value, false); // 找到对应的Segment[j],调用put
}

多个线程可能同时调用ensureSegmentSegment[j]进行初始化,在该函数中要避免重复初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // 下标K对应的内存地址的偏移量u
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // 检查下标为u的segment是否已经被初始化
Segment<K,V> proto = ss[0]; // 以Segment[0]的参数为原型
int cap = proto.table.length; // 直接使用Segment[0]的HashEntry的数组长度
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // 重新检查
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab); // 真正创建Segment
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // 自旋
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}

count表示元素个数modCount表示修改次数,当待put的元素key或hash值和链表中的某个节点相等时,不会重复插入节点,若onlyIfAbsent为false时修改该节点的value。若遍历到链表尾部,并没有发现可以或hash相等的节点,则在链表头部插入一个新节点,并把table[index]赋值为该节点。值得注意的是这里的锁是加在Segment数组的每个槽上的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
static final class Segment<K,V> extends ReentrantLock implements Serializable {
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
// 执行到该处一定要拿到锁
V oldValue;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash; // tab.length为2的整数次方,该处等价于hash对tab.length取模
HashEntry<K,V> first = entryAt(tab, index); // 定位到第index个HashEntry
for (HashEntry<K,V> e = first;;) {
if (e != null) {
K k;
if ((k = e.key) == key || (e.hash == hash && key.equals(k))) {
oldValue = e.value; // 若定位到相同的key
if (!onlyIfAbsent) {
e.value = value;
++modCount; // 修改次数累加
}
break; // key相等或hash值相等,不会重复插入,直接返回
}
e = e.next; // 遍历链表
} else { // 已经遍历到链表尾部,没有发现重复元素
if (node != null) // 在上面的scanAndLockForPut已经建好了节点
node.setNext(first); // 把node插入链表头部
else // 新建的node插入链表头部
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node); // 超出阈值,扩容
else
setEntryAt(tab, index, node); // 把node赋值给tab[index]
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
}
static final <K,V> HashEntry<K,V> entryAt(HashEntry<K,V>[] tab, int i) {
return (tab == null) ? null : (HashEntry<K,V>) UNSAFE.getObjectVolatile(tab, ((long)i << TSHIFT) + TBASE);
}
static final <K,V> void setEntryAt(HashEntry<K,V>[] tab, int i, HashEntry<K,V> e) {
UNSAFE.putOrderedObject(tab, ((long)i << TSHIFT) + TBASE, e);
}

在函数开始加锁时,进行了优化,若tryLock成功拿到锁,则进入下面代码,否则进入scanAndLockForPut拿不到锁不立即阻塞先自旋,若自旋到一定次数后任未拿到锁,再调用lock阻塞,且在自旋过程中遍历链表,若发现没有重复节点,则提前新建一个节点,为后面再插入节省时间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node
while (!tryLock()) { // 自旋获取锁
HashEntry<K,V> f; // to recheck first below
if (retries < 0) {
if (e == null) { // 遍历到链表最后一个元素都没有key相同的
if (node == null) // 创建一个新节点
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
} else if (key.equals(e.key))
retries = 0; // 若遍历到key相同的,则不需要创建新的HashEntry,则退出遍历
else // 若first不为空,找到链表尾部
e = e.next; // 遍历链表
} else if (++retries > MAX_SCAN_RETRIES) { // 自旋,达到一定次数后,通过锁阻塞,多核为64次
lock(); // 阻塞获取锁
break;
} else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) {
// 由于是头插法,则只需要判断链表头节点是否发生变化,若发生变化则重新遍历,且只有偶数次才回去检查头结点是否发生变化
e = first = f; // 若该处值变化了,重新赋值e和first
retries = -1;
}
}
return node;
}

和HashMap一样,超过一定阈值后,Segment内部也会进行扩容,传入的节点,在扩容完成后会被插入到新的hash表中。扩容时进行了一次优化,并没有对元素依次拷贝,而是先找到lastRun位置,也就是for循环。lastRun到链表末尾的所有元素hash值没有改变,故不需要依次拷贝,只需要把这部分链表链接到新链表所对应的位置即可,也就是newTable[lastIdx] = lastRunlastRun之前的元素则需要依次拷贝。由于前面已经加了分段锁,所以不存在并发问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
int newCapacity = oldCapacity << 1; // 扩容一倍
threshold = (int)(newCapacity * loadFactor);
HashEntry<K,V>[] newTable = (HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1;
for (int i = 0; i < oldCapacity ; i++) {
HashEntry<K,V> e = oldTable[i];
if (e != null) { // 若链表不存在,则不需要移动操作
HashEntry<K,V> next = e.next;
int idx = e.hash & sizeMask; // 节点之前在第i个位置,则新hash表中一定处于i或i+oldCapacity位置
if (next == null) // 若链表只有一个节点,则直接挪到新数组中
newTable[idx] = e;
else { // Reuse consecutive sequence at same slot
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next; last != null; last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) { // 找到最后连续的且在新数组的下标都为lastIdx的头节点
lastIdx = k; // 寻找链表中最后一个hash值不等于lastIdx的元素
lastRun = last;
}
}
// 把lastRun之后的链表元素直接链接到新hash表中的lastIdx位置,在lastRun之前的所有链表元素,需要在新的位置逐个拷贝
newTable[lastIdx] = lastRun; // 将lastRun链表直接赋值到新的数组中
// Clone remaining nodes
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) { // 遍历旧的链表,将开头到lastRun的元素依次转移到新的数组中
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n); // 依然使用的头插法
}
}
}
}
int nodeIndex = node.hash & sizeMask; // 把新节点加入到新的hash表中
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}

整个get过程也就是两次hash,第一次hash计算出所在的Segment,第二次hash找到Segment中对应的HashEntry数组下标,然后遍历该位置的链表。整个读的过程没有加锁,而是使用了UNSAFE.getObjectVolatile

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE; // 第一次hash
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) {
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) { // 第二次hash
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}