blob: 34e5955fd79e5de9dc06d916de04625293d78eb0 [file] [log] [blame]
/*
 * Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4
5package kotlinx.coroutines
6
7import kotlin.coroutines.*
8import kotlin.test.*
9
/**
 * Tests [CopyableThreadContextElement] when the thread-local state it installs is a *mutable* object.
 *
 * Every coroutine must observe its own list instance:
 * - an element inherited by a child is duplicated with [CopyableThreadContextElement.copyForChild];
 * - an element of the same key passed explicitly to a builder or `withContext` is combined with
 *   the inherited one via [CopyableThreadContextElement.mergeForChild].
 */
class ThreadContextMutableCopiesTest : TestBase() {

    /**
     * Publishes [mutableData] into [threadLocalData] while a coroutine carrying this element runs.
     *
     * [copyForChild] hands an inheriting child a fresh list with the same contents, and
     * [mergeForChild] builds a new list holding this element's distinct entries first, followed by
     * entries of the overwriting element that are not already present.
     */
    class MyMutableElement(
        val mutableData: MutableList<String>
    ) : CopyableThreadContextElement<MutableList<String>> {

        override val key: CoroutineContext.Key<*>
            get() = Key

        /** Installs [mutableData] on the current thread and returns the displaced value for later restoration. */
        override fun updateThreadContext(context: CoroutineContext): MutableList<String> {
            val previous = threadLocalData.get()
            threadLocalData.set(mutableData)
            return previous
        }

        /** Puts back the value that [updateThreadContext] displaced. */
        override fun restoreThreadContext(context: CoroutineContext, oldState: MutableList<String>) {
            threadLocalData.set(oldState)
        }

        /** Gives the child its own list so that its mutations stay invisible to the parent. */
        override fun copyForChild(): MyMutableElement = MyMutableElement(ArrayList(mutableData))

        /**
         * Combines this (inherited) element with [overwritingElement] into a new element whose list
         * contains this element's distinct entries followed by the new entries of [overwritingElement].
         */
        override fun mergeForChild(overwritingElement: CoroutineContext.Element): MyMutableElement {
            // App-specific cast: a real implementation may need to handle other subtypes sharing this key.
            overwritingElement as MyMutableElement
            // `Set + Iterable` keeps insertion order and drops duplicates.
            return MyMutableElement((mutableData.toSet() + overwritingElement.mutableData).toMutableList())
        }

        companion object Key : CoroutineContext.Key<MyMutableElement>
    }

    @Test
    fun testDataIsCopied() = runTest {
        val root = MyMutableElement(ArrayList())
        runBlocking(root) {
            val data = threadLocalData.get()
            expect(1)
            launch(root) {
                // Equal contents, but a distinct list instance.
                assertNotSame(data, threadLocalData.get())
                assertEquals(data, threadLocalData.get())
                finish(2)
            }
        }
    }

    @Test
    fun testChildMutationsDoNotLeakToParent() = runTest {
        val root = MyMutableElement(ArrayList())
        runBlocking(root) {
            val parentData = threadLocalData.get()
            launch {
                // The child mutates its own copy of the list only.
                threadLocalData.get().add("child")
            }.join()
            assertEquals(emptyList<String>(), parentData)
            assertEquals(emptyList<String>(), threadLocalData.get())
        }
    }

    @Test
    fun testDataIsNotOverwritten() = runTest {
        val root = MyMutableElement(ArrayList())
        runBlocking(root) {
            expect(1)
            val originalData = threadLocalData.get()
            threadLocalData.get().add("X")
            launch {
                threadLocalData.get().add("Y")
                // Passing `root` again would replace the element under plain `CoroutineContext.plus`
                // semantics; mergeForChild keeps the inherited "X" and "Y" instead.
                launch(Dispatchers.Default + root) {
                    assertEquals(listOf("X", "Y"), threadLocalData.get())
                    assertNotSame(originalData, threadLocalData.get())
                    finish(2)
                }
            }
        }
    }

    @Test
    fun testDataIsMerged() = runTest {
        val root = MyMutableElement(ArrayList())
        runBlocking(root) {
            expect(1)
            val originalData = threadLocalData.get()
            threadLocalData.get().add("X")
            launch {
                threadLocalData.get().add("Y")
                // The explicitly passed element ("Z") is merged with the inherited one ("X", "Y")
                // rather than replacing it.
                launch(Dispatchers.Default + MyMutableElement(mutableListOf("Z"))) {
                    assertEquals(listOf("X", "Y", "Z"), threadLocalData.get())
                    assertNotSame(originalData, threadLocalData.get())
                    finish(2)
                }
            }
        }
    }

    @Test
    fun testDataIsNotOverwrittenWithContext() = runTest {
        val root = MyMutableElement(ArrayList())
        runBlocking(root) {
            expect(1)
            val originalData = threadLocalData.get()
            threadLocalData.get().add("X")
            launch {
                threadLocalData.get().add("Y")
                // Same as testDataIsNotOverwritten, but through withContext: `root` is merged with
                // the inherited element instead of replacing it.
                withContext(Dispatchers.Default + root) {
                    assertEquals(listOf("X", "Y"), threadLocalData.get())
                    assertNotSame(originalData, threadLocalData.get())
                    finish(2)
                }
            }
        }
    }

    @Test
    fun testDataIsCopiedForRunBlocking() = runTest {
        val root = MyMutableElement(ArrayList())
        val originalData = root.mutableData
        runBlocking(root) {
            assertNotSame(originalData, threadLocalData.get())
        }
    }

    @Test
    fun testDataIsCopiedForCoroutine() = runTest {
        val root = MyMutableElement(ArrayList())
        val originalData = root.mutableData
        expect(1)
        launch(root) {
            assertNotSame(originalData, threadLocalData.get())
            finish(2)
        }
    }

    companion object {
        /** Thread-local slot that [MyMutableElement] swaps in and out; each thread starts with an empty list. */
        val threadLocalData: ThreadLocal<MutableList<String>> = ThreadLocal.withInitial { ArrayList() }
    }
}