Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods (#1043)

Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods

Fixes #1028
diff --git a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt
index 2705acc..21c473a 100644
--- a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt
+++ b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt
@@ -481,6 +481,8 @@
 public final class kotlinx/coroutines/ThreadContextElementKt {
 	public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
 	public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
+	public static final fun ensurePresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
+	public static final fun isPresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
 }
 
 public final class kotlinx/coroutines/ThreadPoolDispatcherKt {
diff --git a/docs/coroutine-context-and-dispatchers.md b/docs/coroutine-context-and-dispatchers.md
index 00b0db9..29da4b4 100644
--- a/docs/coroutine-context-and-dispatchers.md
+++ b/docs/coroutine-context-and-dispatchers.md
@@ -635,7 +635,7 @@
     threadLocal.set("main")
     println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
     val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
-       println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
+        println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
         yield()
         println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
     }
@@ -664,6 +664,10 @@
 
 <!--- TEST FLEXIBLE_THREAD -->
 
+Note how easily one may forget the corresponding context element and then still safely access thread local.
+To avoid such situations, it is recommended to use [ensurePresent] method
+and fail-fast on improper usages.
+
 `ThreadLocal` has first-class support and can be used with any primitive `kotlinx.coroutines` provides.
 It has one key limitation: when thread-local is mutated, a new value is not propagated to the coroutine caller 
 (as context element cannot track all `ThreadLocal` object accesses) and updated value is lost on the next suspension.
@@ -701,5 +705,6 @@
 [MainScope()]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-main-scope.html
 [Dispatchers.Main]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-dispatchers/-main.html
 [asContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/as-context-element.html
+[ensurePresent]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/ensure-present.html
 [ThreadContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-thread-context-element/index.html
 <!--- END -->
diff --git a/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt b/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt
index bfb5cfd..cf82318 100644
--- a/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt
+++ b/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt
@@ -11,9 +11,7 @@
 import org.junit.*
 import org.junit.Assert.*
 import org.junit.Test
-import java.io.*
 import java.util.concurrent.*
-import kotlin.test.assertFailsWith
 
 class ListenableFutureTest : TestBase() {
     @Before
diff --git a/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt b/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt
index 7038363..7d128c6 100644
--- a/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt
+++ b/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt
@@ -16,7 +16,6 @@
 import kotlin.concurrent.*
 import kotlin.coroutines.*
 import kotlin.reflect.*
-import kotlin.test.assertFailsWith
 
 class FutureTest : TestBase() {
     @Before
diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
index c68ee45..4e8b6cc 100644
--- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
+++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
@@ -135,3 +135,40 @@
  */
 public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> =
     ThreadLocalElement(value, this)
+
+/**
+ * Return `true` when current thread local is present in the coroutine context, `false` otherwise.
+ * Thread local can be present in the context only if it was added via [asContextElement] to the context.
+ *
+ * Example of usage:
+ * ```
+ * suspend fun processRequest() {
+ *   if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing
+ *      // Do some heavy-weight tracing
+ *   }
+ *   // Process request regularly
+ * }
+ * ```
+ */
+public suspend inline fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] !== null
+
+/**
+ * Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not.
+ * It is a good practice to validate that thread local is present in the context, especially in large code-bases,
+ * to avoid stale thread-local values and to have a strict invariants.
+ *
+ * E.g. one may use the following method to enforce proper use of the thread locals with coroutines:
+ * ```
+ * public suspend inline fun <T> ThreadLocal<T>.getSafely(): T {
+ *   ensurePresent()
+ *   return get()
+ * }
+ *
+ * // Usage
+ * withContext(...) {
+ *   val value = threadLocal.getSafely() // Fail-fast in case of improper context
+ * }
+ * ```
+ */
+public suspend inline fun ThreadLocal<*>.ensurePresent(): Unit =
+    check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" }
diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
index 7dafb47..375dc60 100644
--- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
+++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
@@ -98,7 +98,8 @@
 }
 
 // top-level data class for a nicer out-of-the-box toString representation and class name
-private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
+@PublishedApi
+internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
 
 internal class ThreadLocalElement<T>(
     private val value: T,
diff --git a/kotlinx-coroutines-core/jvm/test/TestBase.kt b/kotlinx-coroutines-core/jvm/test/TestBase.kt
index db5c53a..6fef760 100644
--- a/kotlinx-coroutines-core/jvm/test/TestBase.kt
+++ b/kotlinx-coroutines-core/jvm/test/TestBase.kt
@@ -201,4 +201,10 @@
         if (exCount < unhandled.size)
             error("Too few unhandled exceptions $exCount, expected ${unhandled.size}")
     }
+
+    protected inline fun <reified T: Throwable> assertFailsWith(block: () -> Unit): T {
+        val result = runCatching(block)
+        assertTrue(result.exceptionOrNull() is T, "Expected ${T::class}, but had $result")
+        return result.exceptionOrNull()!! as T
+    }
 }
diff --git a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
index 62a340e..5d8c3d5 100644
--- a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
@@ -6,6 +6,7 @@
 
 import org.junit.*
 import org.junit.Test
+import java.lang.IllegalStateException
 import kotlin.test.*
 
 @Suppress("RedundantAsync")
@@ -22,25 +23,33 @@
     @Test
     fun testThreadLocal() = runTest {
         assertNull(stringThreadLocal.get())
+        assertFalse(stringThreadLocal.isPresent())
         val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
             assertEquals("value", stringThreadLocal.get())
+            assertTrue(stringThreadLocal.isPresent())
             withContext(executor) {
+                assertTrue(stringThreadLocal.isPresent())
+                assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
                 assertEquals("value", stringThreadLocal.get())
             }
+            assertTrue(stringThreadLocal.isPresent())
             assertEquals("value", stringThreadLocal.get())
         }
 
         assertNull(stringThreadLocal.get())
         deferred.await()
         assertNull(stringThreadLocal.get())
+        assertFalse(stringThreadLocal.isPresent())
     }
 
     @Test
     fun testThreadLocalInitialValue() = runTest {
         intThreadLocal.set(42)
+        assertFalse(intThreadLocal.isPresent())
         val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
             assertEquals(239, intThreadLocal.get())
             withContext(executor) {
+                intThreadLocal.ensurePresent()
                 assertEquals(239, intThreadLocal.get())
             }
             assertEquals(239, intThreadLocal.get())
@@ -63,6 +72,8 @@
             withContext(executor) {
                 assertEquals(239, intThreadLocal.get())
                 assertEquals("pew", stringThreadLocal.get())
+                intThreadLocal.ensurePresent()
+                stringThreadLocal.ensurePresent()
             }
 
             assertEquals(239, intThreadLocal.get())
@@ -129,6 +140,7 @@
         }
 
         deferred.await()
+        assertFalse(stringThreadLocal.isPresent())
         assertEquals("main", stringThreadLocal.get())
     }
 
@@ -212,4 +224,10 @@
             assertNotSame(mainThread, Thread.currentThread())
         }.await()
     }
+
+    @Test
+    fun testMissingThreadLocal() = runTest {
+        assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
+        assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
+    }
 }
diff --git a/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt b/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt
index 8de958e..1945495 100644
--- a/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt
+++ b/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt
@@ -14,7 +14,7 @@
     threadLocal.set("main")
     println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
     val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
-       println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
+        println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
         yield()
         println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
     }