ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

JDK 1.7 ConcurrentHashMap 源码解析

2022-06-09 14:34:15  阅读:220  来源: 互联网

标签:ConcurrentHashMap 1.7 int segment 源码 HashEntry key null Segment


作用

HashMap 在多线程环境中,扩容的时候可能会死循环;HashTable 只是简单粗暴的在方法上用 synchronized 进行同步,同一时刻,只会有一个线程获取到锁,其他线程全部阻塞(也有可能自旋),性能堪忧。所以 ConcurrentHashMap 诞生了。

结构

ConcurrentHashMap 是由 Segment 数组结构和 HashEntry 数组结构组成,HashEntry 类似于 HashMap 的内部结构(如果你还不了解 HashMap,建议先看看)。Segment 继承自 ReentrantLock(看这里),那本身也就是锁了。ConcurrentHashMap 结构图如下

image

我们再看看 HashMap 结构图

image

我们可以发现 Segment 的 HashEntry 数组类似于 HashMap 的 Entry 数组,这也难怪有人说 ConcurrentHashMap 就是一个个的 HashTable。

初始化

接下来看看 ConcurrentHashMap 的初始化逻辑

public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
    // Find power-of-two sizes best matching arguments
    int sshift = 0;
    int ssize = 1;
    // 找到最小的 2 的 N 次方值作为 segments 数组的长度
    while (ssize < concurrencyLevel) {
        // sshift 等于 ssize 从 1 向左移位的次数
        ++sshift;
        ssize <<= 1;
    }
    // 这两个全局变量需要在定位 segment 时的散列算法里使用
    this.segmentShift = 32 - sshift;
    this.segmentMask = ssize - 1;
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    // 根据初始化容量和 segments 数组大小计算 HashEntry 数组大小    
    int c = initialCapacity / ssize;
    if (c * ssize < initialCapacity)
        ++c;
    // 默认 MIN_SEGMENT_TABLE_CAPACITY = 2;
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    while (cap < c)
        cap <<= 1;
    // 创建第一个 Segment 和 Segment 数组
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]);
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
    UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    this.segments = ss;
}

初始化的过程做了一下几件事:

  • 根据 concurrencyLevel 计算出 segments 数组大小,为大于或等于 concurrencyLevel 最小的 2 的 N 次方
  • 根据 ssize 计算出 segmentShift 和 segmentMask,用于定位 segment 使用的散列算法
  • 计算出 HashEntry 数组容量大小,默认为 2
  • 初始化第一个 Segment 和 Segments 数组

操作

get

get 方法代码如下

/**
 * key 不能为 null,否则就 NPE
 * @throws NullPointerException if the specified key is null
 */
public V get(Object key) {
    Segment<K,V> s; // manually integrate access methods to reduce overhead
    HashEntry<K,V>[] tab;
    int h = hash(key);
    // 定位 segment
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    // 判断 segment 和 segment 的 HashEntry 数组是否已存在
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        // 定位到 HashEntry[] 中具体的 HashEntry,循环该位置上的链表
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            // 判断是否找到
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}

get 方法查找逻辑如下:

  1. 根据给定的 key 定位 segment 位置,判断该位置 segment 是否存在以及 segment 的 HashEntry 数组是否存在;
  2. 定位到 segment 后,再定位 HashEntry[] 中具体的 HashEntry,然后循环该位置上的链表,直到找到指定的 key,然后返回 value;
  3. 以上都不满足,返回 null。
  4. 另外,如果 key == null,则抛 NullPointerException

get 方法还是有些需要注意的地方,当定位到 segment 后,从 segments 数组中获取该 segment 的时候用到了 UNSAFE.getObjectVolatile

s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null

顾名思义,getObjectVolatile 方法具有 volatile 的内存语义(可见性)。然后获取 segment 中 HashEntry 数组的时候:(tab = s.table) != null 并没有使用该方法,应该可以猜到,table 被 volatile 修饰了,所以能够保证获取到的是最新的

transient volatile HashEntry<K,V>[] table;

put

put 方法代码如下

/**
 * key and key 都不能为 null,否则 NPE
 * @throws NullPointerException if the specified key or value is null
 */
public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
        throw new NullPointerException();
    int hash = hash(key);
    int j = (hash >>> segmentShift) & segmentMask;
    // 如果定位 segments 数组索引处 segment 还没初始化,则先初始化
    if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
         (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
        s = ensureSegment(j);
    return s.put(key, hash, value, false);
}

如果指定索引处的 segment 还没初始化,则先调用 ensureSegment() 方法初始化

/**
 * 创建 Segment
 */
private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset
    Segment<K,V> seg;
    // 判断指定索引处的 segment 是否初始化
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        // 把第一个 segment 当做原型,其它的segment属性都参考这个
        Segment<K,V> proto = ss[0]; // use segment 0 as prototype
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        // 再次判断 segment 是否创建
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // recheck
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                // 通过 CAS 初始化该索引处的 segment
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}

初始化 segments 数组中某一个元素的时候,reeeeeeecheck 了 N 次,最后通过 CAS 初始化。

到这里,定位 segment 的任务完成了,接下来就是真正执行 put 的时候了

return s.put(key, hash, value, false);

可以看到,执行的是定义在 Segment 里的 put 方法,代码如下

/**
 * put 方法返回的指定 key 对应的旧值(oldValue)
 */
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    // 先要获取锁
    HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        int index = (tab.length - 1) & hash;
        // 定位 HashEntry
        HashEntry<K,V> first = entryAt(tab, index);
        for (HashEntry<K,V> e = first;;) {
            // 定位到的 HashEntry 数组位置上已存在 HashEntry
            // 则循环链表,检索 key
            if (e != null) {
                K k;
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    if (!onlyIfAbsent) {
                        // 找到匹配的 key,并且 onlyIfAbsent 为 false
                        // 设置找到的 HashEntry value 为新值
                        e.value = value;
                        // 修改数 + 1
                        ++modCount;
                    }
                    break;
                }
                e = e.next;
            }
            // 如果定位到的位置上还没有 HashEntry
            else {
                // node != null,说明是之前获取锁失败的并且定位到的索引位还没 HashEntry
                if (node != null)
                    node.setNext(first);
                else
                    node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1;
                // 如果 HashEntry 数组元素数量达到阈值,则进行扩容
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                    rehash(node);
                else
                    // 否则就直接初始化指定索引处的 HashEntry
                    setEntryAt(tab, index, node);
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        // 释放锁
        unlock();
    }
    return oldValue;
}

put 方法的流程逻辑如下:

  1. 定位 segment,如果没有初始化,则先调用 ensureSegment() 方法进行初始化;
  2. 获取锁,定位 HashEntry;
  3. HashEntry 已经初始化,循环链表,检索 key,替换 value,返回旧值;
  4. HashEntry 还没初始化,则先初始化,如果 HashEntry 数组元素数量达到阈值,则先扩容,新数组大小为之前的 2 倍,并且扩容的时候,将新的 HashEntry 加到新的数组中。

上面第 2 步中,如果线程获取锁失败,将执行 scanAndLockForPut() 方法

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    // 链表首节点
    HashEntry<K,V> first = entryForHash(this, hash);
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1; // negative while locating node
    // 不停地重试获取锁
    while (!tryLock()) {
        HashEntry<K,V> f; // to recheck first below
        if (retries < 0) {
            // 如果首节点为 null,则新建 HashEntry
            if (e == null) {
                if (node == null) // speculatively create node
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            // 检索 key
            else if (key.equals(e.key))
                retries = 0;
            else
                e = e.next;
        }
        // 当重试次数达到一定的数量(单 CPU 1 次,其它 64 次)
        // 调用 lock():再一次获取锁,获取失败则阻塞当前线程
        else if (++retries > MAX_SCAN_RETRIES) {
            lock();
            break;
        }
        // 如果首节点发生改变,重新检索
        else if ((retries & 1) == 0 &&
                 (f = entryForHash(this, hash)) != first) {
            e = first = f; // re-traverse if entry changed
            retries = -1;
        }
    }
    return node;
}

也就是说,当线程 put 获取锁失败,则不停地尝试获取锁,直到重试的次数达到上限,如果还没获取到锁,那就被阻塞。

size

ConcurrentHashMap 的分段锁设计能够很好的支持并发操作,如果想要统计元素总数,那肯定就是将每个 Segment 里的元素个数加起来。但有个问题,累加的过程中,已累加的 segment 的元素个数可能已发生了改变,那到最后计算的总数肯定就不准确了。所以得有个参考的东西,来表示 segment 个数有没有发生变化,那就是 modCount 属性了。每次统计完总数,再比较下 modCount 是否发生改变。

size() 方法代码如下

public int size() {
    // Try a few times to get accurate count. On failure due to
    // continuous async changes in table, resort to locking.
    final Segment<K,V>[] segments = this.segments;
    int size;
    boolean overflow; // true if size overflows 32 bits
    long sum;         // sum of modCounts
    long last = 0L;   // previous sum
    int retries = -1; // first iteration isn't retry
    try {
        for (;;) {
            // 重试 3 次后,将每个 segment 加锁
            if (retries++ == RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    ensureSegment(j).lock(); // force creation
            }
            sum = 0L;
            size = 0;
            overflow = false;
            for (int j = 0; j < segments.length; ++j) {
                Segment<K,V> seg = segmentAt(segments, j);
                if (seg != null) {
                    // 统计 modCount
                    sum += seg.modCount;
                    int c = seg.count;
                    // 累加 count
                    if (c < 0 || (size += c) < 0)
                        overflow = true;
                }
            }
            // 统计期间,modCount 是否发生变化
            if (sum == last)
                break;
            last = sum;
        }
    } finally {
        // 释放锁
        if (retries > RETRIES_BEFORE_LOCK) {
            for (int j = 0; j < segments.length; ++j)
                segmentAt(segments, j).unlock();
        }
    }
    return overflow ? Integer.MAX_VALUE : size;
}

size 逻辑如下:

  1. 采用不加锁的方式统计个数,统计期间,segment 元素个数没有发生变化,则返回统计值,最多重试 3 次。
  2. 如果 3 次还不能准确统计,则对每个 segment 加锁,再次统计。

总结

ConcurrentHashMap 主要方法基本就分析完了,可以发现,其中 put 和 get 的核心思想适合 HashMap 类似的,所以在看 ConcurrentHashMap 源码之前,建议还是看下 HashMap 代码。ConcurrentHashMap 分段锁是一个很经典的设计,但是 JDK 1.8 中又完全摒弃了这种思想,所以下一篇应该就是 JDK 1.8 源码解析了~

标签:ConcurrentHashMap,1.7,int,segment,源码,HashEntry,key,null,Segment
来源: https://www.cnblogs.com/tailife/p/16359229.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有