Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods (#1043)
Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods
Fixes #1028
diff --git a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt
index 2705acc..21c473a 100644
--- a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt
+++ b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt
@@ -481,6 +481,8 @@
public final class kotlinx/coroutines/ThreadContextElementKt {
public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
+ public static final fun ensurePresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
+ public static final fun isPresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}
public final class kotlinx/coroutines/ThreadPoolDispatcherKt {
diff --git a/docs/coroutine-context-and-dispatchers.md b/docs/coroutine-context-and-dispatchers.md
index 00b0db9..29da4b4 100644
--- a/docs/coroutine-context-and-dispatchers.md
+++ b/docs/coroutine-context-and-dispatchers.md
@@ -635,7 +635,7 @@
threadLocal.set("main")
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
- println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
+ println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
yield()
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
}
@@ -664,6 +664,10 @@
<!--- TEST FLEXIBLE_THREAD -->
+Note how easily one may forget the corresponding context element and then still safely access thread local.
+To avoid such situations, it is recommended to use [ensurePresent] method
+and fail-fast on improper usages.
+
`ThreadLocal` has first-class support and can be used with any primitive `kotlinx.coroutines` provides.
It has one key limitation: when thread-local is mutated, a new value is not propagated to the coroutine caller
(as context element cannot track all `ThreadLocal` object accesses) and updated value is lost on the next suspension.
@@ -701,5 +705,6 @@
[MainScope()]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-main-scope.html
[Dispatchers.Main]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-dispatchers/-main.html
[asContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/as-context-element.html
+[ensurePresent]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/ensure-present.html
[ThreadContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-thread-context-element/index.html
<!--- END -->
diff --git a/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt b/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt
index bfb5cfd..cf82318 100644
--- a/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt
+++ b/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt
@@ -11,9 +11,7 @@
import org.junit.*
import org.junit.Assert.*
import org.junit.Test
-import java.io.*
import java.util.concurrent.*
-import kotlin.test.assertFailsWith
class ListenableFutureTest : TestBase() {
@Before
diff --git a/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt b/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt
index 7038363..7d128c6 100644
--- a/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt
+++ b/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt
@@ -16,7 +16,6 @@
import kotlin.concurrent.*
import kotlin.coroutines.*
import kotlin.reflect.*
-import kotlin.test.assertFailsWith
class FutureTest : TestBase() {
@Before
diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
index c68ee45..4e8b6cc 100644
--- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
+++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
@@ -135,3 +135,40 @@
*/
public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> =
ThreadLocalElement(value, this)
+
+/**
+ * Return `true` when current thread local is present in the coroutine context, `false` otherwise.
+ * Thread local can be present in the context only if it was added via [asContextElement] to the context.
+ *
+ * Example of usage:
+ * ```
+ * suspend fun processRequest() {
+ * if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing
+ * // Do some heavy-weight tracing
+ * }
+ * // Process request regularly
+ * }
+ * ```
+ */
+public suspend inline fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] !== null
+
+/**
+ * Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not.
+ * It is a good practice to validate that thread local is present in the context, especially in large code-bases,
+ * to avoid stale thread-local values and to have a strict invariants.
+ *
+ * E.g. one may use the following method to enforce proper use of the thread locals with coroutines:
+ * ```
+ * public suspend inline fun <T> ThreadLocal<T>.getSafely(): T {
+ * ensurePresent()
+ * return get()
+ * }
+ *
+ * // Usage
+ * withContext(...) {
+ * val value = threadLocal.getSafely() // Fail-fast in case of improper context
+ * }
+ * ```
+ */
+public suspend inline fun ThreadLocal<*>.ensurePresent(): Unit =
+ check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" }
diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
index 7dafb47..375dc60 100644
--- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
+++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
@@ -98,7 +98,8 @@
}
// top-level data class for a nicer out-of-the-box toString representation and class name
-private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
+@PublishedApi
+internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
internal class ThreadLocalElement<T>(
private val value: T,
diff --git a/kotlinx-coroutines-core/jvm/test/TestBase.kt b/kotlinx-coroutines-core/jvm/test/TestBase.kt
index db5c53a..6fef760 100644
--- a/kotlinx-coroutines-core/jvm/test/TestBase.kt
+++ b/kotlinx-coroutines-core/jvm/test/TestBase.kt
@@ -201,4 +201,10 @@
if (exCount < unhandled.size)
error("Too few unhandled exceptions $exCount, expected ${unhandled.size}")
}
+
+ protected inline fun <reified T: Throwable> assertFailsWith(block: () -> Unit): T {
+ val result = runCatching(block)
+ assertTrue(result.exceptionOrNull() is T, "Expected ${T::class}, but had $result")
+ return result.exceptionOrNull()!! as T
+ }
}
diff --git a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
index 62a340e..5d8c3d5 100644
--- a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
@@ -6,6 +6,7 @@
import org.junit.*
import org.junit.Test
+import java.lang.IllegalStateException
import kotlin.test.*
@Suppress("RedundantAsync")
@@ -22,25 +23,33 @@
@Test
fun testThreadLocal() = runTest {
assertNull(stringThreadLocal.get())
+ assertFalse(stringThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
assertEquals("value", stringThreadLocal.get())
+ assertTrue(stringThreadLocal.isPresent())
withContext(executor) {
+ assertTrue(stringThreadLocal.isPresent())
+ assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
assertEquals("value", stringThreadLocal.get())
}
+ assertTrue(stringThreadLocal.isPresent())
assertEquals("value", stringThreadLocal.get())
}
assertNull(stringThreadLocal.get())
deferred.await()
assertNull(stringThreadLocal.get())
+ assertFalse(stringThreadLocal.isPresent())
}
@Test
fun testThreadLocalInitialValue() = runTest {
intThreadLocal.set(42)
+ assertFalse(intThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
assertEquals(239, intThreadLocal.get())
withContext(executor) {
+ intThreadLocal.ensurePresent()
assertEquals(239, intThreadLocal.get())
}
assertEquals(239, intThreadLocal.get())
@@ -63,6 +72,8 @@
withContext(executor) {
assertEquals(239, intThreadLocal.get())
assertEquals("pew", stringThreadLocal.get())
+ intThreadLocal.ensurePresent()
+ stringThreadLocal.ensurePresent()
}
assertEquals(239, intThreadLocal.get())
@@ -129,6 +140,7 @@
}
deferred.await()
+ assertFalse(stringThreadLocal.isPresent())
assertEquals("main", stringThreadLocal.get())
}
@@ -212,4 +224,10 @@
assertNotSame(mainThread, Thread.currentThread())
}.await()
}
+
+ @Test
+ fun testMissingThreadLocal() = runTest {
+ assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
+ assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
+ }
}
diff --git a/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt b/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt
index 8de958e..1945495 100644
--- a/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt
+++ b/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt
@@ -14,7 +14,7 @@
threadLocal.set("main")
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
- println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
+ println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
yield()
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
}