ThreadLocal原理

ThreadLocal是一个线程内部的存储类,可以在指定线程内存储数据,数据存储以后,只有指定线程可以得到存储数据。提供了线程内存储变量的能力,这些变量不同之处在于每一个线程读取的变量是对应的互相独立的

单看ThreadLocal类的源码其实很简单,对外提供的方法也很少。复杂的点在于内部静态类ThreadLocalMap每个线程持有一个ThreadLocalMap对象,每一个新的线程Thread都会实例化一个ThreadLocalMap并赋值给成员变量threadLocals

原理

HASH_INCREMENT魔数的选取与斐波那契散列有关,用0x61c88647作为魔数累加为每个ThreadLocal分配各自的ID也就是threadLocalHashCode再与2的幂取模,得到的结果分布很均匀。ThreadLocalMap使用的是线性探测法,均匀分布的好处在于很快就能探测到下一个临近的可用slot,从而保证效率。

1
2
3
4
5
6
7
private final int threadLocalHashCode = nextHashCode();
private static AtomicInteger nextHashCode = new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;

private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}

initialValuewithInitial两个方法是使用的时候用于重写赋初始值。withInitial通过lambda表达式的方式来重写initialValue方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
protected T initialValue() {
return null;
}

public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}

static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}

@Override
protected T initialValue() {
return supplier.get();
}
}

通过ThreadLocalget源码可以看到,数据是存储在ThreadLocalMap中,而具体的ThreadLocalMap实例并不是ThreadLocal保持,而是保持在每个Thread持有的成员变量threadLocals中。不同的Thread持有不同的ThreadLocalMap实例,因此它们是不存在线程竞争。每次线程死亡,所有map中引用到的对象都会随着这个Thread的死亡而被垃圾收集器一起收集。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

set方法跟上面的setInitialValue差不多。如果当前线程的ThreadLocalMap为空就创建一个新的ThreadLocalMap并赋值给当前线程的成员变量threadLocals,否则set当前值。

1
2
3
4
5
6
7
8
9
10
11
12
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap

ThreadLocalMapThreadLocal的静态内部类为每个Thread都维护了一个数组tableThreadLocal确定了一个数组下标,而这个下标就是value存储的对应位置。

实例化ThreadLocalMap时创建了一个初始长度16Entry数组,且数组长度始终为2的幂。与HashMap类似通过hashCodelength位运算确定数组下标。结合此处的构造方法可以理解成每个线程Thread都持有一个Entry型的数组table,而一切的读取过程都是通过操作这个数组table完成的。

1
2
3
4
private static final int INITIAL_CAPACITY = 16;
private Entry[] table;
private int size = 0;
private int threshold;

为了解决内存回收,这里的Entry继承了WeakReference弱引用。ThreadLocalMap使用ThreadLocal的弱引用作为key,如果一个ThreadLocal没有外部强引用引用他,系统gc的时候,该ThreadLocal势必会被回收。

Entrykey是对ThreadLocal的弱引用,当抛弃掉ThreadLocal对象时,垃圾收集器会忽略这个key的引用而清理掉ThreadLocal对象, 防止了内存泄漏。

1
2
3
4
5
6
7
8
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

i是通过对threadLocalHashCode的取模得到数组的下标,将构建的Entry放到table数组中。并设置thresholdThreadLocalMap有两个方法用于得到上一个/下一个索引,从nextIndexprevIndex两个方法的实现上来看,Entry数组在程序逻辑上是作为一个环形存在的,这也是由于ThreadLocalMap是使用线性探测法来解决散列冲突

线性探测法:往哈希表中插入数据时,通过哈希函数计算该值的哈希地址,若发现该位置已有数据,则找紧跟着的下一个位置,若无数据则插入,若有数据则继续探测下一个位置。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}

private void setThreshold(int len) {
threshold = len * 2 / 3;
}

private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

getEntry方法在ThreadLocal中的get方法中被调用,首先从ThreadLocal直接索引位置获取Entry e,若e不为null并且key相同则返回e,若enull直接返回null,若e不为nullkey不一致则向下一个位置查询,如果下一个位置的key和当前需要查询的key相等,则返回对应的Entry,否则若key值为null,则擦除该位置的Entry,返回null,否则继续向下一个位置查询,直到enull还没有找到对应的Entry则返回null

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}

// 由于使用的是线性探索,往后还可能找到目标Entry
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

Entrykey为空时通过expungeStaleEntry方法擦除该位置的Entry。防止该Entryvalue被一直强引用从而导致内存泄露。ThreadLocal引用示意图如下,实线表示强引用,虚线表示弱引用

ThreadLocal引用示意图

从图中可看到,虽然EntrykeyThreadLocal的弱引用,key在其他地方没有强引用时即会被回收,但是Entryvalue会一直被引用,不能得到释放。当然若线程执行结束threadLocalthreadRef会断掉,因此threadLocalthreadLocalMapentry都会被回收,但实际中为了线程复用我们会使用线程池,threadRef可能永远不会断掉,可能导致value永远无法回收。所以这里是直接将tableEntryEntryvalue的引用置空。

for循环是往后环形查找,直到遇到table[i] == null时结束,k == null表示再次遇到脏Entry同样将其清理掉。k != nullh != i表示处理rehash的情况,将起挪到hash下标为h开始的第一个为空的位置。

注:脏Entry仅仅是keynull,而不是通过table[i]获取的Entry为空。这也是为什么遇到tab[i] == null就退出搜索了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
/*
* 对于还没有被回收的情况,需要做一次rehash。
* 如果对应的ThreadLocal的ID对len取模出来的索引h不为当前位置i,
* 则从h向后线性探测到第一个空的slot,把当前的entry给挪过去。
*/
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
/*
* ThreadLocalMap因为使用了弱引用,所以其实每个slot的状态有三种也即
* 有效(value未回收),无效(value已回收),空(entry==null)。
* 正是因为ThreadLocalMap的entry有三种状态,所以不能完全套高德纳原书的R算法。
*
* 因为expungeStaleEntry函数在扫描过程中还会对无效slot清理将之转为空slot,
* 如果直接套用R算法,可能会出现具有相同哈希值的entry之间断开(中间有空entry)。
*/
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

若当前table[i] != null说明hash冲突就需要向后环形查找,若查找过程中遇到脏entry就通过replaceStaleEntry进行处理;若当前table[i] == null说明新的entry可以直接插入,但是插入后会调用cleanSomeSlots方法检测并清除脏entry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

if (k == key) {
e.value = value;
return;
}

// 替换失效的entry
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}

tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

replaceStaleEntry并不仅仅局限于处理当前已知的Entry,首先通过for循环向前找到第一个Entry,这里的第一个是指向前查找遇到的最靠近table[i] == nullEntry,它认为在出现脏Entry相邻位置也有很大概率出现脏Entry,为了一次处理到位,就需要向前环形搜索,找到前面的脏Entry

根据向前搜索中是否有脏Entry以及在for循环向后环形查找中是否找到可覆盖的Entry,,最后使用cleanSomeSlots方法从slotToExpunge为起点开始进行清理脏Entry,可分四种情况:

  • 前向Entry,向后环形查找找到可覆盖的Entry

  • 前向Entry,向后环形查找未找到可覆盖的Entry

  • 前向Entry,向后环形查找找到可覆盖的Entry

  • 前向Entry,向后环形查找未找到可覆盖的Entry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

// 向前扫描,查找最前的一个无效slot
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

// 向后遍历table
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

if (k == key) {
e.value = value;

tab[i] = tab[staleSlot];
tab[staleSlot] = e;

/*
* 如果在整个扫描过程中(包括函数一开始的向前扫描与i之前的向后扫描)
* 找到了之前的无效slot则以那个位置作为清理的起点,否则以当前的i作为清理起点
*/
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 从slotToExpunge开始做一次连续段的清理,再做一次启发式清理
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// 如果当前的slot已经无效,并且向前扫描过程中没有无效slot,则更新slotToExpunge为当前位置
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// 如果key在table中不存在,则在原地放一个即可
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

// 在探测过程中如果发现任何无效slot,则做一次清理(连续段清理+启发式清理)
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

参数n主要用于扫描次数控制,若没有遇到脏Entry,整个扫描过程持续log2(n)次,若遇到脏Entryn重置为当前hash表的长度,再扫描log2(n)次,注意这里的nextIndex获取的数组下标是一个环。遇到脏Entry时通过expungeStaleEntry清理脏Entry

cleanSomeSlots执行情景图

若当前n等于hash表的sizen=10i=1,在第一趟搜索过程中通过nextIndexi指向索引为2的位置,此时table[2]null,则第一趟未发现脏Entry,进行第二趟搜索。

第二趟搜索先通过nextIndex方法,table[3] != null找到脏Entry,先将n置为哈希表的长度len,然后继续调用expungeStaleEntry方法,将当前索引为3的脏Entry给清除掉,令valuenulltable[3]null,但该方法会继续往后环形搜索,往后发现索引4、5的位置的Entry同样为脏Entry,索引6Entry不是脏Entry保持不变,直至i=7时此处table[7]null,返回索引7

继续向后环形搜索,直到在整个搜索范围里都未发现脏EntrycleanSomeSlot方法执行结束退出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 启发式地清理slot
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
// 扩大扫描控制因子
n = len;
removed = true;
// 清理一个连续段
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

当调用set方法时发现需要扩容时,会调用rehash方法对table进行扩容。扩容前回先清理掉hash表中所有的脏Entry,然后在进行扩容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
private void rehash() {
// 做一次全量清理
expungeStaleEntries();

/*
* 因为做了一次清理,所以size很可能会变小。
* ThreadLocalMap这里的实现是调低阈值来判断是否需要扩容,
* threshold默认为len*2/3,所以这里的threshold - threshold / 4相当于len/2
*/
if (size >= threshold - threshold / 4)
resize();
}

private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
// 线性探测来存放Entry
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

setThreshold(newLen);
size = count;
table = newTab;
}

private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}

调用threadLocal.remove方法时候,实际上会调用threadLocalMapremove方法,当遇到了keynull的脏entry的时候,也会调用expungeStaleEntry清理掉脏entry

1
2
3
4
5
6
7
8
9
10
11
12
13
14
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}

threadLocal生命周期里,针对threadLocal存在的内存泄漏的问题,都会通过expungeStaleEntrycleanSomeSlots,replaceStaleEntry这三个方法清理掉keynull的脏entry

InheritableThreadLocal

1
2
3
4
5
6
7
8
9
10
11
12
13
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
protected T childValue(T parentValue) {
return parentValue;
}

ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}

void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}

InheritableThreadLocal提供了一种父子线程之间的数据共享机制,在线程初始化init时,会调用ThreadLocalcreateInheritedMap从父线程的inheritableThreadLocals中把有效的entry都拷过来。InheritableThreadLocal只是在子线程创建时会去拷一份父线程的inheritableThreadLocals。若父线程是在子线程创建后再set某个InheritableThreadLocal对象的值,对子线程是不可见的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}

private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

应用

在使用ThreadLocal时很可能出现数据错乱,这是由于通过线程池时,线程池对线程进行了复用,从而导致ThreadLocal中的数据串了。用完及时清理数据。在Web环境中可以自定义HandlerInterceptorAdapter,在preHandler中去设置ThreadLocal,在afterCompletion时区remove

重写initialValue方法

重写initialValue赋初值方法。

1
2
3
4
5
6
7
8
static ThreadLocal<Long> threadLocal = new ThreadLocal<Long>() {
@Override
protected Long initialValue() {
return Thread.currentThread().getId();
}
};

static ThreadLocal<Long> threadLocal = ThreadLocal.withInitial(() -> Thread.currentThread().getId());

MDC

MDC主要是用于将某个或某些所有日志中都需要打印的字符串设置于MDC中,这样就不需要每次打印日志时专门写出来,这里也是通过ThreadLocal来实现的。

1
2
3
4
5
6
7
8
9
10
11
12
static MDCAdapter mdcAdapter;
mdcAdapter.put(key, val);
mdcAdapter.get(key);

public class Log4jMDCAdapter implements MDCAdapter {
public Log4jMDCAdapter() {
}

public void put(String key, String val) {
ThreadContext.put(key, val);
}
}

DefaultThreadContextMap可以看到localMap其实就是一个ThreadLocal

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
public final class ThreadContext {
private static ThreadContextMap contextMap;
public static void put(String key, String value) {
contextMap.put(key, value);
}
}
public class DefaultThreadContextMap implements ThreadContextMap {
public static final String INHERITABLE_MAP = "isThreadContextMapInheritable";
private final boolean useMap;
private final ThreadLocal<Map<String, String>> localMap;

public DefaultThreadContextMap(boolean useMap) {
this.useMap = useMap;
this.localMap = createThreadLocalMap(useMap);
}

static ThreadLocal<Map<String, String>> createThreadLocalMap(final boolean isMapEnabled) {
PropertiesUtil managerProps = PropertiesUtil.getProperties();
boolean inheritable = managerProps.getBooleanProperty("isThreadContextMapInheritable");
return (ThreadLocal)(inheritable ? new InheritableThreadLocal<Map<String, String>>() {
protected Map<String, String> childValue(Map<String, String> parentValue) {
return parentValue != null && isMapEnabled ? Collections.unmodifiableMap(new HashMap(parentValue)) : null;
}
} : new ThreadLocal());
}

public void put(String key, String value) {
if (this.useMap) {
Map<String, String> map = (Map)this.localMap.get();
Map<String, String> map = map == null ? new HashMap() : new HashMap(map);
map.put(key, value);
this.localMap.set(Collections.unmodifiableMap(map));
}
}
}

数据库连接、 Session 管理

1
2
3
4
5
6
7
8
9
10
11
12
13
private static final ThreadLocal threadSession = new ThreadLocal();
public static Session getSession() throws InfrastructureException {
Session s = (Session) threadSession.get();
try {
if (s == null) {
s = getSessionFactory().openSession();
threadSession.set(s);
}
} catch (HibernateException ex) {
throw new InfrastructureException(ex);
}
return s;
}

使用ThreadLocal代替锁

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
static HashSet<Val<Integer>> set = new HashSet<>();

synchronized static void addSet(Val<Integer> val) {
set.add(val);
}

static ThreadLocal<Val<Integer>> c = new ThreadLocal<Val<Integer>>(){
@Override
protected Val<Integer> initialValue() {
Val<Integer> v = new Val<Integer>();
v.set(0);
addSet(v);
return v;
}
};
// 统计结果
set.stream().map(Val::get).reduce((a, sum) -> a + sum).get()

public class Val<T> {
T val;
public T get() {
return val;
}
public void set(T val) {
this.val = val;
}
}

自实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public class MyThreadLocal<T> {
static AtomicInteger atomic = new AtomicInteger();

private int threadLocalHash = atomic.addAndGet(0x61c88647);

static HashMap<Thread, HashMap<Integer, Object>> threadLocalHashMap = new HashMap<>();

synchronized static HashMap<Integer, Object> getMap() {
Thread thread = Thread.currentThread();
if (!threadLocalHashMap.containsKey(thread)) {
threadLocalHashMap.put(thread, new HashMap<>());
}
return threadLocalHashMap.get(thread);
}


protected T initialValue() {
return null;
}

public T get() {
HashMap<Integer, Object> map = getMap ();
if (!map.containsKey(this.threadLocalHash)) {
map.put(this.threadLocalHash, initialValue());
}
return (T) map.get(this.threadLocalHash);
}

public void set(T value) {
HashMap<Integer, Object> map = getMap ();
map.put(this.threadLocalHash, value);
}
}

与Synchronized的区别

相同:都是为了解决多线程中相同变量的访问冲突问题,

不同:Synchronized同步机制是通过以时间换空间的方式控制线程访问共享对象的顺序,而threadLocal是以空间换时间为每一个线程分配一个该对象各用各的互不影响。