ConcurrentHashMap源码

2023/03/07

本篇基于Java11

由于ConcurrentHashMapHashMap复杂了不止一点,个人建议自己跟着源码追,有看不懂的方法再来看我的博客。

省流版:

  • loadFactor(负载因子)被固定为了0.75,且通过构造器设置只会影响初始容量
  • ConcurrentHashMap支持并发扩容
  • HashMapthreshold被替换为了sizeCtl,高16位代表当前哈希表容量的一个"版本号",低16位 - 1表示当前正在进行扩容的线程数

1. 构造器

public ConcurrentHashMap() { } public ConcurrentHashMap(int initialCapacity) { this(initialCapacity, LOAD_FACTOR, 1); } public ConcurrentHashMap(Map<? extends K, ? extends V> m) { this.sizeCtl = DEFAULT_CAPACITY; putAll(m); } public ConcurrentHashMap(int initialCapacity, float loadFactor) { this(initialCapacity, loadFactor, 1); } public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) { if (!(loadFactor > 0.0f) || initialCapacity < 0 || concurrencyLevel <= 0) throw new IllegalArgumentException(); if (initialCapacity < concurrencyLevel) // Use at least as many bins initialCapacity = concurrencyLevel; // as estimated threads long size = (long)(1.0 + (long)initialCapacity / loadFactor); int cap = (size >= (long)MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : tableSizeFor((int)size); this.sizeCtl = cap; }
java

和我们的老朋友HashMap有如下不同:

  • HashMapthreshold扩容阈值没有了

  • 这里新出来了一个sizeCtl,根据其大小有不同的含义:

    • 当该值为0时,表示正在等待哪个线程去初始化
    • 当该值为负数时,表示正在初始化或者调整大小
      • -1表示正在初始化
      • 若为其他负数,低16位 - 1表示当前正在扩容的线程数,是的,你没看错,ConcurrentHashMap是支持并发扩容的!
    • 当大于0时,这个值就和HashMapthreshold一样的意思了,都是代表扩容阈值
  • 构造器多了一个concurrencyLevel,表示预估会有多少个写线程,实际上也没什么用,就是一个局部变量。

2. put

在看put之前还需要了解其它一些方法。

2.1 initTable

/** * Initializes table, using the size recorded in sizeCtl. */ private final Node<K,V>[] initTable() { Node<K,V>[] tab; int sc; while ((tab = table) == null || tab.length == 0) { // 前面说过了,sizeCtl小于0表示正在初始化,这里说明有别的线程正在初始化 if ((sc = sizeCtl) < 0) // 让当前线程主动放弃CPU Thread.yield(); // lost initialization race; just spin // 在这里进行CAS修改sizeCtl为-1 else if (U.compareAndSetInt(this, SIZECTL, sc, -1)) { try { if ((tab = table) == null || tab.length == 0) { // n表示新哈希表容量,默认容量还是16 int n = (sc > 0) ? sc : DEFAULT_CAPACITY; @SuppressWarnings("unchecked") Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n]; table = tab = nt; // 这里可以理解为 sc = n * 0.75. 向右位移2位等于除以4 sc = n - (n >>> 2); } } finally { sizeCtl = sc; } break; } } return tab; }
java

2.2 tabAt

static final <K,V> Node<K,V> tabAt(Node<K,V>[] tab, int i) { return (Node<K,V>)U.getObjectAcquire(tab, ((long)i << ASHIFT) + ABASE); } public final Object getObjectAcquire(Object o, long offset) { return getObjectVolatile(o, offset); }
java

看名字就很容易能看懂是干嘛的了,因为我们不能给数组的某个元素加volatile,所以只能用这种方式保证可见性。

这里的ASHIFT和ABASE可以在源码的最后的静态代码块处看到,这里我将其提了出来:

static final Unsafe U = createUnsafe(); public static Unsafe createUnsafe() { try { Class<?> unsafeClass = Class.forName("sun.misc.Unsafe"); Field field = unsafeClass.getDeclaredField("theUnsafe"); field.setAccessible(true); return (Unsafe) field.get(null); } catch (Exception e) { e.printStackTrace(); } return null; } static class Node {} public static void main(String[] args) { int[] arr = new int[10]; int scale = U.arrayIndexScale(Node[].class); int ABASE = U.arrayBaseOffset(Node[].class); int ASHIFT = 31 - Integer.numberOfLeadingZeros(scale); System.out.println("ABASE = " + ABASE + ", ASHIFT = " + ASHIFT + ", scale = " + scale); scale = U.arrayIndexScale(long[].class); ABASE = U.arrayBaseOffset(long[].class); ASHIFT = 31 - Integer.numberOfLeadingZeros(scale); System.out.println("ABASE = " + ABASE + ", ASHIFT = " + ASHIFT + ", scale = " + scale); }
java

输出:

ABASE = 16, ASHIFT = 2, scale = 4 ABASE = 16, ASHIFT = 3, scale = 8
text

这里的scale很明显,代表数组每个元素的占用大小。

ABASE则是元素在数组里的偏移值,一般大小为:对象头(8字节) + 类型指针(默认4字节,关闭指针压缩后为8字节) + 数组长度(4字节) = 16,如果你不清楚我在说什么,可以去看一下我的这篇博客:对象在内存中的存储布局 (notion.so)

ASHIFT则是表示当前元素占用大小二进制的1右边有多少个0,在ConcurrentHashMap初始化时,如果scale不是2的幂则会报错。

那么在取值的时候是什么意思呢?

其实这里很像C的指针了,ABASE代表基础偏移值,而i << ASHIFT则每个元素的位置:

  • 比如i为0,i << ASHIFT = 0,代表这个元素在对象在堆中的地址 + ABASE

  • 比如i为1,i << ASHIFT = 4,代表这个元素在对象在堆中的地址 + ABASE + 4

2.2 putVal

put内部其实就是调用了putVal

final V putVal(K key, V value, boolean onlyIfAbsent) { if (key == null || value == null) throw new NullPointerException(); // 在这里将key的哈希高16位和低16位进行异或后得到新的哈希 int hash = spread(key.hashCode()); int binCount = 0; // 看到这层死循环就应该感觉到会有CAS出现 for (Node<K,V>[] tab = table;;) { Node<K,V> f; int n, i, fh; K fk; V fv; if (tab == null || (n = tab.length) == 0) // 初始化表 tab = initTable(); // 这里取哈希索引和HashMap一样 else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) { // 这个位置没有元素,则尝试CAS放进去 if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value))) break; // no lock when adding to empty bin } // 这里后面再讲 else if ((fh = f.hash) == MOVED) tab = helpTransfer(tab, f); // 在这里判断是否需要替换掉原节点 else if (onlyIfAbsent // check first node without acquiring lock && fh == hash && ((fk = f.key) == key || (fk != null && key.equals(fk))) && (fv = f.val) != null) return fv; else { V oldVal = null; // 锁住当前节点 synchronized (f) { // 确保这个没有被修改 if (tabAt(tab, i) == f) { // 这里要知道,如果节点是红黑树,哈希值为-2 if (fh >= 0) { // 这个值代表链表的大小 binCount = 1; for (Node<K,V> e = f;; ++binCount) { K ek; // 这里同样也是在尝试进行替换:判断hash和key是否相等 if (e.hash == hash && ((ek = e.key) == key || (ek != null && key.equals(ek)))) { oldVal = e.val; if (!onlyIfAbsent) e.val = value; break; } // 这里检查是否遍历到链表尾部,如果到尾部了则直接插入 Node<K,V> pred = e; if ((e = e.next) == null) { pred.next = new Node<K,V>(hash, key, value); break; } } } // 如果是红黑树 else if (f instanceof TreeBin) { Node<K,V> p; binCount = 2; // 在红黑树里进行查找 if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key, value)) != null) { oldVal = p.val; if (!onlyIfAbsent) p.val = value; } } // 这个后面讲 else if (f instanceof ReservationNode) throw new IllegalStateException("Recursive update"); } } // 这里判断是否需要将链表树化 if (binCount != 0) { if (binCount >= TREEIFY_THRESHOLD) treeifyBin(tab, i); if (oldVal != null) return oldVal; break; } } } // 进行统计,同时在这里判断是否需要扩容 addCount(1L, binCount); return null; }
java

2.3 addCount

这个方法比较复杂,先讲一些其它的小方法。

2.3.1 resizeStamp

static final int resizeStamp(int n) { return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1)); }
java

这个方法主要是获取n最高位前面有几个0,然后和后面的值相与。RESIZE_STAMP_BITS为常量:16

那么右边整体就是个常量:(1 << 15) 即1的右边15个0。

那么这个方法有什么用呢?其实这里是用作来生成一个扩容标记的,相当于一个版本号。

至于这么玩有什么用,正常看源码到这里是不知道的,我们先接着往后看。

2.3.2 transfer

Moves and/or copies the nodes in each bin to new table. See above for explanation.

这个方法大致就是将节点从旧哈希表复制或者移动到新的哈希表中,方法很长。

源码里又出现了两个新的类变量:

  • nextTable:表示新的哈希表,仅在扩容时非空。
  • transferIndex:The next table index (plus one) to split while resizing. 这里不是很好理解,就不翻译了,看源码就能懂。

同时这里用到了一个新的节点:ForwardingNode,这个节点继承了基础的Node节点,但是其hash值永远为MOVED,即为-1。同时,内部还保存了新的哈希表nextTable。根据文档翻译,这个节点是用作一个头结点,作为新哈希表的表头(A node inserted at head of bins during transfer operations.)。

transfer是并发扩容的实现,对于每个线程,每次会分配一块固定长度大小的区域来让线程对tab进行重新hash,这个区域的大小与CPU核心数成反比,但最小为16。

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) { int n = tab.length, stride; // 将n/8/CPU核心数当做区域大小,最小值为16 if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE) stride = MIN_TRANSFER_STRIDE; // subdivide range // nextTab为空,表示当前线程是第一个进行扩容的线程 if (nextTab == null) { // initiating try { @SuppressWarnings("unchecked") Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1]; nextTab = nt; } catch (Throwable ex) { // try to cope with OOME // OOM的fallback sizeCtl = Integer.MAX_VALUE; return; } nextTable = nextTab; transferIndex = n; } int nextn = nextTab.length; // 用于占位。当别的线程发现这个槽位中是 fwd 类型的节点,则跳过这个节点。 ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab); boolean advance = true; boolean finishing = false; // to ensure sweep before committing nextTab for (int i = 0, bound = 0;;) { Node<K,V> f; int fh; while (advance) { // 1. 因为advance一开始肯定为true,所以进入这个循环 int nextIndex, nextBound; // --i相当于从大到小去分配区域 if (--i >= bound || finishing) advance = false; // 获取当前已经分配到的索引 else if ((nextIndex = transferIndex) <= 0) { // 小于0表示已经扩容完了 i = -1; advance = false; } // 尝试接下这块区域 else if (U.compareAndSetInt (this, TRANSFERINDEX, nextIndex, nextBound = (nextIndex > stride ? nextIndex - stride : 0))) { // 到这里,说明当前线程已经接下了[nextBound, nextIndex - 1]这块区域重新分配的任务 bound = nextBound; i = nextIndex - 1; advance = false; } } // i小于0,说明线程没拿到任务 // 至于后面那俩,我们仔细观察即可发现:这两个发生的条件是某一个线程进行扩容时,其它线程已经扩容完了 // 并且又开启了新一轮的扩容 if (i < 0 || i >= n || i + n >= nextn) { int sc; // 判断扩容已经完成 if (finishing) { nextTable = null; table = nextTab; // 这里重新分配扩容阈值,负载因子为0.75 sizeCtl = (n << 1) - (n >>> 1); return; } // 将sizeCtl - 1,在开头我们已经说了,sizeCtl低16位保存当前正在进行扩容的线程数量 if (U.compareAndSetInt(this, SIZECTL, sc = sizeCtl, sc - 1)) { // 这里判断是否只有一个线程在扩容 if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT) // 只有一个线程扩容,且没有接到任务,说明扩容完成了 return; // 还有别的线程在扩容,给个标记后再重新检查一遍。。 finishing = advance = true; i = n; // recheck before commit } } // 如果任务区间最后一个为空 else if ((f = tabAt(tab, i)) == null) // 尝试CAS将其赋值为占位节点 advance = casTabAt(tab, i, null, fwd); // 如果为true,表示有其它人跟自己一样分配到了一样的任务,需要重新分配任务. // 注意这里有--i,不用担心和上一次一样取到一样的hash else if ((fh = f.hash) == MOVED) advance = true; // already processed else { // 锁住尾节点 synchronized (f) { // 再检查一遍 if (tabAt(tab, i) == f) { Node<K,V> ln, hn; if (fh >= 0) { // 链表转移. 这里貌似用的是尾插法. int runBit = fh & n; Node<K,V> lastRun = f; for (Node<K,V> p = f.next; p != null; p = p.next) { int b = p.hash & n; if (b != runBit) { runBit = b; lastRun = p; } } if (runBit == 0) { ln = lastRun; hn = null; } else { hn = lastRun; ln = null; } for (Node<K,V> p = f; p != lastRun; p = p.next) { int ph = p.hash; K pk = p.key; V pv = p.val; if ((ph & n) == 0) ln = new Node<K,V>(ph, pk, pv, ln); else hn = new Node<K,V>(ph, pk, pv, hn); } setTabAt(nextTab, i, ln); setTabAt(nextTab, i + n, hn); setTabAt(tab, i, fwd); advance = true; } else if (f instanceof TreeBin) { // 树节点转移 TreeBin<K,V> t = (TreeBin<K,V>)f; TreeNode<K,V> lo = null, loTail = null; TreeNode<K,V> hi = null, hiTail = null; int lc = 0, hc = 0; for (Node<K,V> e = t.first; e != null; e = e.next) { int h = e.hash; TreeNode<K,V> p = new TreeNode<K,V> (h, e.key, e.val, null, null); if ((h & n) == 0) { if ((p.prev = loTail) == null) lo = p; else loTail.next = p; loTail = p; ++lc; } else { if ((p.prev = hiTail) == null) hi = p; else hiTail.next = p; hiTail = p; ++hc; } } ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) : (hc != 0) ? new TreeBin<K,V>(lo) : t; hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) : (lc != 0) ? new TreeBin<K,V>(hi) : t; setTabAt(nextTab, i, ln); setTabAt(nextTab, i + n, hn); setTabAt(tab, i, fwd); advance = true; } } } } } }
java

2.3.3 总结

了解过上面两个方法后,再来看addCount就清晰多了

private final void addCount(long x, int check) { CounterCell[] cs; long b, s; if ((cs = counterCells) != null || // 这里去cas设置总节点数量 !U.compareAndSetLong(this, BASECOUNT, b = baseCount, s = b + x)) { // 这里就使用了类似LongAdder和Striped64的设计,将自增分散到多个格子里 CounterCell c; long v; int m; boolean uncontended = true; if (cs == null || (m = cs.length - 1) < 0 || (c = cs[ThreadLocalRandom.getProbe() & m]) == null || !(uncontended = U.compareAndSetLong(c, CELLVALUE, v = c.value, v + x))) { fullAddCount(x, uncontended); return; } if (check <= 1) return; // 在这里重新获取总节点数量 s = sumCount(); } // 这s和b已经赋过值了 // 这里主要判断是否需要当前线程去扩容或协助扩容 if (check >= 0) { Node<K,V>[] tab, nt; int n, sc; // 又是循环,CAS的小曲 // 这里主要判断容量是否大于等于sizeCtl,然后进行扩容 while (s >= (long)(sc = sizeCtl) && (tab = table) != null && (n = tab.length) < MAXIMUM_CAPACITY) { // 后面讲 int rs = resizeStamp(n); // 判断是否有其它线程正在修改 if (sc < 0) { // 如果有,判断当前容量是否已经发生改变 // sc == rs + 1是用来判断当前是否没有线程在进行扩容 // 后面的都是用来判断扩容是否已经完成了,不需要当前线程进行协助扩容 if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 || sc == rs + MAX_RESIZERS || (nt = nextTable) == null || transferIndex <= 0) break; // 协助进行扩容,将sizeCtl + 1,表示多了一个线程在进行扩容 if (U.compareAndSetInt(this, SIZECTL, sc, sc + 1)) transfer(tab, nt); } // 这里CAS修改sizeCtl为(rs << RESIZE_STAMP_SHIFT) + 2),表示当前线程是第一个发起扩容的 else if (U.compareAndSetInt(this, SIZECTL, sc, (rs << RESIZE_STAMP_SHIFT) + 2)) transfer(tab, null); s = sumCount(); } } }
java

关于rs << RESIZE_STAMP_SHIFT可能不好理解,这里我们举个例子:

int n = 16; int i = Integer.numberOfLeadingZeros(n) | (1 << 15); // System.out.println(Integer.toBinaryString()); System.out.println(Integer.toBinaryString(i)); System.out.println(Integer.toBinaryString(i << 16));
java

输出:

1000000000011011 10000000000110110000000000000000
text

我们在开头也说过了,ctl高16位代表一个版本号,第16位然后再减一代表当前正在扩容的线程数,所有这里就代表当前有一个线程正在进行扩容。

2.4 sumCount

这里就是类似于LongAdder一样,获取元素总数量,size方法也是调用的这个:

final long sumCount() { CounterCell[] cs = counterCells; long sum = baseCount; if (cs != null) { for (CounterCell c : cs) if (c != null) sum += c.value; } return sum; } public int size() { long n = sumCount(); return ((n < 0L) ? 0 : (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int)n); }
java