Fix withTimeout/OrNull bug with spurious concurrency on cancellation
diff --git a/coroutines-guide.md b/coroutines-guide.md
index 0c7bae7..469a335 100644
--- a/coroutines-guide.md
+++ b/coroutines-guide.md
@@ -542,12 +542,13 @@
I'm sleeping 0 ...
I'm sleeping 1 ...
I'm sleeping 2 ...
-Exception in thread "main" java.util.concurrent.CancellationException: Timed out waiting for 1300 MILLISECONDS
+Exception in thread "main" kotlinx.coroutines.experimental.TimeoutException: Timed out waiting for 1300 MILLISECONDS
```
<!--- TEST STARTS_WITH -->
-We have not seen the [CancellationException] stack trace printed on the console before. That is because
+The `TimeoutException` that is thrown by [withTimeout] is a private subclass of [CancellationException].
+We have not seen its stack trace printed on the console before. That is because
inside a cancelled coroutine `CancellationException` is considered to be a normal reason for coroutine completion.
However, in this example we have used `withTimeout` right inside the `main` function.
diff --git a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CancellableContinuation.kt b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CancellableContinuation.kt
index 2a56b45..cfabad1 100644
--- a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CancellableContinuation.kt
+++ b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CancellableContinuation.kt
@@ -241,12 +241,7 @@
when (mode) {
MODE_DISPATCHED -> delegate.resumeWithException(exception)
MODE_UNDISPATCHED -> (delegate as DispatchedContinuation).resumeUndispatchedWithException(exception)
- MODE_DIRECT -> {
- if (delegate is DispatchedContinuation)
- delegate.continuation.resumeWithException(exception)
- else
- delegate.resumeWithException(exception)
- }
+ MODE_DIRECT -> delegate.resumeDirectWithException(exception)
else -> error("Invalid mode $mode")
}
} else {
@@ -254,12 +249,7 @@
when (mode) {
MODE_DISPATCHED -> delegate.resume(value)
MODE_UNDISPATCHED -> (delegate as DispatchedContinuation).resumeUndispatched(value)
- MODE_DIRECT -> {
- if (delegate is DispatchedContinuation)
- delegate.continuation.resume(value)
- else
- delegate.resume(value)
- }
+ MODE_DIRECT -> delegate.resumeDirect(value)
else -> error("Invalid mode $mode")
}
}
diff --git a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CoroutineDispatcher.kt b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CoroutineDispatcher.kt
index 04f77d8..e81d4ca 100644
--- a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CoroutineDispatcher.kt
+++ b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/CoroutineDispatcher.kt
@@ -153,3 +153,13 @@
})
}
}
+
+internal fun <T> Continuation<T>.resumeDirect(value: T) = when (this) {
+ is DispatchedContinuation -> continuation.resume(value)
+ else -> resume(value)
+}
+
+internal fun <T> Continuation<T>.resumeDirectWithException(exception: Throwable) = when (this) {
+ is DispatchedContinuation -> continuation.resumeWithException(exception)
+ else -> resumeWithException(exception)
+}
diff --git a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Job.kt b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Job.kt
index 912ff34..c6bf0bc 100644
--- a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Job.kt
+++ b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Job.kt
@@ -353,7 +353,10 @@
}
internal open fun onParentCompletion(cause: Throwable?) {
- cancel()
+ // if parent was completed with CancellationException then use it as the cause of our cancellation, too.
+ // however, we shall not use application specific exceptions here. So if parent crashes due to IOException,
+ // we cannot and should not cancel the child with IOException
+ cancel(cause as? CancellationException)
}
/**
diff --git a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Scheduled.kt b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Scheduled.kt
index 8083ce3..c58c233 100644
--- a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Scheduled.kt
+++ b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/Scheduled.kt
@@ -16,7 +16,6 @@
package kotlinx.coroutines.experimental
-import kotlinx.coroutines.experimental.intrinsics.startCoroutineUndispatched
import kotlinx.coroutines.experimental.selects.SelectBuilder
import kotlinx.coroutines.experimental.selects.select
import java.util.concurrent.ScheduledExecutorService
@@ -24,6 +23,8 @@
import java.util.concurrent.TimeUnit
import kotlin.coroutines.experimental.Continuation
import kotlin.coroutines.experimental.ContinuationInterceptor
+import kotlin.coroutines.experimental.CoroutineContext
+import kotlin.coroutines.experimental.intrinsics.startCoroutineUninterceptedOrReturn
import kotlin.coroutines.experimental.intrinsics.suspendCoroutineOrReturn
private val KEEP_ALIVE = java.lang.Long.getLong("kotlinx.coroutines.ScheduledExecutor.keepAlive", 1000L)
@@ -68,9 +69,10 @@
* Runs a given suspending [block] of code inside a coroutine with a specified timeout and throws
* [CancellationException] if timeout was exceeded.
*
- * The code that is executing inside the [block] is cancelled on timeout and throws [CancellationException]
- * exception inside of it, too. However, even the code in the block suppresses the exception,
- * this `withTimeout` function invocation still throws [CancellationException].
+ * The code that is executing inside the [block] is cancelled on timeout and the active or next invocation of
+ * cancellable suspending function inside the block throws [CancellationException], so normally that exception,
+ * if uncaught, also gets thrown by `withTimeout` as a result.
+ * However, the code in the block can suppresses [CancellationException].
*
* The sibling function that does not throw exception on timeout is [withTimeoutOrNull].
* Note, that timeout action can be specified for [select] invocation with [onTimeout][SelectBuilder.onTimeout] clause.
@@ -84,27 +86,40 @@
public suspend fun <T> withTimeout(time: Long, unit: TimeUnit = TimeUnit.MILLISECONDS, block: suspend () -> T): T {
require(time >= 0) { "Timeout time $time cannot be negative" }
if (time <= 0L) throw CancellationException("Timed out immediately")
- return suspendCoroutineOrReturn sc@ { delegate: Continuation<T> ->
- // schedule cancellation of this continuation on time
- val cont = TimeoutExceptionContinuation(time, unit, delegate)
- val delay = cont.context[ContinuationInterceptor] as? Delay
+ return suspendCoroutineOrReturn { cont: Continuation<T> ->
+ val context = cont.context
+ val coroutine = TimeoutExceptionCoroutine(time, unit, cont)
+ val delay = context[ContinuationInterceptor] as? Delay
+ // schedule cancellation of this coroutine on time
if (delay != null)
- cont.disposeOnCompletion(delay.invokeOnTimeout(time, unit, cont)) else
- cont.cancelFutureOnCompletion(scheduledExecutor.schedule(cont, time, unit))
- // restart block using cancellable context of this continuation,
+ coroutine.disposeOnCompletion(delay.invokeOnTimeout(time, unit, coroutine)) else
+ coroutine.cancelFutureOnCompletion(scheduledExecutor.schedule(coroutine, time, unit))
+ coroutine.initParentJob(context[Job])
+ // restart block using new coroutine with new job,
// however start it as undispatched coroutine, because we are already in the proper context
- block.startCoroutineUndispatched(cont)
- cont.getResult()
+ block.startCoroutineUninterceptedOrReturn(coroutine)
}
}
+private class TimeoutExceptionCoroutine<in T>(
+ private val time: Long,
+ private val unit: TimeUnit,
+ private val cont: Continuation<T>
+) : JobSupport(active = true), Runnable, Continuation<T> {
+ override val context: CoroutineContext = cont.context + this // mix in this Job into the context
+ override fun run() { cancel(TimeoutException(time, unit)) }
+ override fun resume(value: T) { cont.resumeDirect(value) }
+ override fun resumeWithException(exception: Throwable) { cont.resumeDirectWithException(exception) }
+}
+
/**
* Runs a given suspending block of code inside a coroutine with a specified timeout and returns
* `null` if timeout was exceeded.
*
- * The code that is executing inside the [block] is cancelled on timeout and throws [CancellationException]
- * exception inside of it. However, even the code in the block does not catch the cancellation exception,
- * this `withTimeoutOrNull` function invocation still returns `null` on timeout.
+ * The code that is executing inside the [block] is cancelled on timeout and the active or next invocation of
+ * cancellable suspending function inside the block throws [CancellationException]. Normally that exception,
+ * if uncaught by the block, gets converted into the `null` result of `withTimeoutOrNull`.
+ * However, the code in the block can suppresses [CancellationException].
*
* The sibling function that throws exception on timeout is [withTimeout].
* Note, that timeout action can be specified for [select] invocation with [onTimeout][SelectBuilder.onTimeout] clause.
@@ -118,33 +133,39 @@
public suspend fun <T> withTimeoutOrNull(time: Long, unit: TimeUnit = TimeUnit.MILLISECONDS, block: suspend () -> T): T? {
require(time >= 0) { "Timeout time $time cannot be negative" }
if (time <= 0L) return null
- return suspendCoroutineOrReturn sc@ { delegate: Continuation<T?> ->
- // schedule cancellation of this continuation on time
- val cont = TimeoutNullContinuation<T>(delegate)
- val delay = cont.context[ContinuationInterceptor] as? Delay
+ return suspendCoroutineOrReturn { cont: Continuation<T?> ->
+ val context = cont.context
+ val coroutine = TimeoutNullCoroutine(time, unit, cont)
+ val delay = context[ContinuationInterceptor] as? Delay
+ // schedule cancellation of this coroutine on time
if (delay != null)
- cont.disposeOnCompletion(delay.invokeOnTimeout(time, unit, cont)) else
- cont.cancelFutureOnCompletion(scheduledExecutor.schedule(cont, time, unit))
- // restart block using cancellable context of this continuation,
+ coroutine.disposeOnCompletion(delay.invokeOnTimeout(time, unit, coroutine)) else
+ coroutine.cancelFutureOnCompletion(scheduledExecutor.schedule(coroutine, time, unit))
+ coroutine.initParentJob(context[Job])
+ // restart block using new coroutine with new job,
// however start it as undispatched coroutine, because we are already in the proper context
- block.startCoroutineUndispatched(cont)
- cont.getResult()
+ try {
+ block.startCoroutineUninterceptedOrReturn(coroutine)
+ } catch (e: TimeoutException) {
+ null // replace inner timeout exception with null result
+ }
}
}
-private class TimeoutExceptionContinuation<in T>(
+private class TimeoutNullCoroutine<in T>(
private val time: Long,
private val unit: TimeUnit,
- delegate: Continuation<T>
-) : CancellableContinuationImpl<T>(delegate, active = true), Runnable {
- override val defaultResumeMode get() = MODE_DIRECT
- override fun run() { cancel(CancellationException("Timed out waiting for $time $unit")) }
+ private val cont: Continuation<T?>
+) : JobSupport(active = true), Runnable, Continuation<T> {
+ override val context: CoroutineContext = cont.context + this // mix in this Job into the context
+ override fun run() { cancel(TimeoutException(time, unit)) }
+ override fun resume(value: T) { cont.resumeDirect(value) }
+ override fun resumeWithException(exception: Throwable) {
+ // suppress inner timeout exception and replace it with null
+ if (exception is TimeoutException)
+ cont.resumeDirect(null) else
+ cont.resumeDirectWithException(exception)
+ }
}
-private class TimeoutNullContinuation<in T>(
- delegate: Continuation<T?>
-) : CancellableContinuationImpl<T?>(delegate, active = true), Runnable {
- override val defaultResumeMode get() = MODE_DIRECT
- override val ignoreRepeatedResume: Boolean get() = true
- override fun run() { resume(null, mode = 0) /* dispatch resume */ }
-}
+private class TimeoutException(time: Long, unit: TimeUnit) : CancellationException("Timed out waiting for $time $unit")
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/src/test/kotlin/guide/test/GuideTest.kt b/kotlinx-coroutines-core/src/test/kotlin/guide/test/GuideTest.kt
index 5e4ac5c..ee95ad2 100644
--- a/kotlinx-coroutines-core/src/test/kotlin/guide/test/GuideTest.kt
+++ b/kotlinx-coroutines-core/src/test/kotlin/guide/test/GuideTest.kt
@@ -120,7 +120,7 @@
"I'm sleeping 0 ...",
"I'm sleeping 1 ...",
"I'm sleeping 2 ...",
- "Exception in thread \"main\" java.util.concurrent.CancellationException: Timed out waiting for 1300 MILLISECONDS"
+ "Exception in thread \"main\" kotlinx.coroutines.experimental.TimeoutException: Timed out waiting for 1300 MILLISECONDS"
)
}
diff --git a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullTest.kt b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullTest.kt
index 453fd67..84109ac 100644
--- a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullTest.kt
+++ b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullTest.kt
@@ -20,6 +20,7 @@
import org.hamcrest.core.IsNull
import org.junit.Assert.assertThat
import org.junit.Test
+import java.io.IOException
class WithTimeoutOrNullTest : TestBase() {
/**
@@ -47,6 +48,51 @@
finish(8)
}
+ @Test
+ fun testNullOnTimeout() = runBlocking {
+ expect(1)
+ val result = withTimeoutOrNull(100) {
+ expect(2)
+ delay(1000)
+ expectUnreached()
+ "OK"
+ }
+ assertThat(result, IsNull())
+ finish(3)
+ }
+
+ @Test
+ fun testSuppressException() = runBlocking {
+ expect(1)
+ val result = withTimeoutOrNull(100) {
+ expect(2)
+ try {
+ delay(1000)
+ } catch (e: CancellationException) {
+ expect(3)
+ }
+ "OK"
+ }
+ assertThat(result, IsEqual("OK"))
+ finish(4)
+ }
+
+ @Test(expected = IOException::class)
+ fun testReplaceException() = runBlocking {
+ expect(1)
+ withTimeoutOrNull(100) {
+ expect(2)
+ try {
+ delay(1000)
+ } catch (e: CancellationException) {
+ finish(3)
+ throw IOException(e)
+ }
+ "OK"
+ }
+ expectUnreached()
+ }
+
/**
* Tests that a 100% CPU-consuming loop will react on timeout if it has yields.
*/
diff --git a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullThreadDispatchTest.kt b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullThreadDispatchTest.kt
index d6d726f..e7a1f77 100644
--- a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullThreadDispatchTest.kt
+++ b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutOrNullThreadDispatchTest.kt
@@ -20,11 +20,11 @@
import org.hamcrest.core.IsNull
import org.junit.After
import org.junit.Assert
-import org.junit.Assert.assertThat
import org.junit.Test
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
+import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.experimental.CoroutineContext
class WithTimeoutOrNullThreadDispatchTest : TestBase() {
@@ -51,16 +51,25 @@
}
}
+
@Test
fun testCancellationDispatchCustomNoDelay() {
+ // it also checks that there is at most once scheduled request in flight (no spurious concurrency)
+ var error: String? = null
checkCancellationDispatch {
executor = Executors.newSingleThreadExecutor(it)
+ val scheduled = AtomicInteger(0)
object : CoroutineDispatcher() {
override fun dispatch(context: CoroutineContext, block: Runnable) {
- executor!!.execute(block)
+ if (scheduled.incrementAndGet() > 1) error = "Two requests are scheduled concurrently"
+ executor!!.execute {
+ scheduled.decrementAndGet()
+ block.run()
+ }
}
}
}
+ error?.let { error(it) }
}
private fun checkCancellationDispatch(factory: (ThreadFactory) -> CoroutineDispatcher) = runBlocking {
@@ -70,22 +79,21 @@
run(dispatcher) {
expect(2)
Assert.assertThat(Thread.currentThread(), IsEqual(thread))
- val result =
- withTimeoutOrNull(100) {
- try {
- expect(3)
- delay(1000)
- expectUnreached()
- } catch (e: CancellationException) {
- expect(4)
- Assert.assertThat(Thread.currentThread(), IsEqual(thread))
- }
- expect(5)
- "FAIL"
+ val result = withTimeoutOrNull(100) {
+ try {
+ expect(3)
+ delay(1000)
+ expectUnreached()
+ } catch (e: CancellationException) {
+ expect(4)
+ Assert.assertThat(Thread.currentThread(), IsEqual(thread))
+ throw e // rethrow
}
- assertThat(result, IsNull())
- expect(6)
+ }
+ Assert.assertThat(Thread.currentThread(), IsEqual(thread))
+ Assert.assertThat(result, IsNull())
+ expect(5)
}
- finish(7)
+ finish(6)
}
}
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutTest.kt b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutTest.kt
index 3ae3930..785570c 100644
--- a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutTest.kt
+++ b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutTest.kt
@@ -19,10 +19,11 @@
import org.hamcrest.core.IsEqual
import org.junit.Assert.assertThat
import org.junit.Test
+import java.io.IOException
class WithTimeoutTest : TestBase() {
/**
- * Tests property dispatching of `withTimeout` blocks
+ * Tests proper dispatching of `withTimeout` blocks
*/
@Test
fun testDispatch() = runBlocking {
@@ -46,6 +47,55 @@
finish(8)
}
+
+ @Test
+ fun testExceptionOnTimeout() = runBlocking<Unit> {
+ expect(1)
+ try {
+ withTimeout(100) {
+ expect(2)
+ delay(1000)
+ expectUnreached()
+ "OK"
+ }
+ } catch (e: CancellationException) {
+ assertThat(e.message, IsEqual("Timed out waiting for 100 MILLISECONDS"))
+ finish(3)
+ }
+ }
+
+ @Test
+ fun testSuppressException() = runBlocking {
+ expect(1)
+ val result = withTimeout(100) {
+ expect(2)
+ try {
+ delay(1000)
+ } catch (e: CancellationException) {
+ expect(3)
+ }
+ "OK"
+ }
+ assertThat(result, IsEqual("OK"))
+ finish(4)
+ }
+
+ @Test(expected = IOException::class)
+ fun testReplaceException() = runBlocking {
+ expect(1)
+ withTimeout(100) {
+ expect(2)
+ try {
+ delay(1000)
+ } catch (e: CancellationException) {
+ finish(3)
+ throw IOException(e)
+ }
+ "OK"
+ }
+ expectUnreached()
+ }
+
/**
* Tests that a 100% CPU-consuming loop will react on timeout if it has yields.
*/
diff --git a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutThreadDispatchTest.kt b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutThreadDispatchTest.kt
index f1bfc78..cb065af 100644
--- a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutThreadDispatchTest.kt
+++ b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithTimeoutThreadDispatchTest.kt
@@ -23,6 +23,7 @@
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
+import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.experimental.CoroutineContext
class WithTimeoutThreadDispatchTest : TestBase() {
@@ -51,14 +52,22 @@
@Test
fun testCancellationDispatchCustomNoDelay() {
+ // it also checks that there is at most once scheduled request in flight (no spurious concurrency)
+ var error: String? = null
checkCancellationDispatch {
executor = Executors.newSingleThreadExecutor(it)
+ val scheduled = AtomicInteger(0)
object : CoroutineDispatcher() {
override fun dispatch(context: CoroutineContext, block: Runnable) {
- executor!!.execute(block)
+ if (scheduled.incrementAndGet() > 1) error = "Two requests are scheduled concurrently"
+ executor!!.execute {
+ scheduled.decrementAndGet()
+ block.run()
+ }
}
}
}
+ error?.let { error(it) }
}
private fun checkCancellationDispatch(factory: (ThreadFactory) -> CoroutineDispatcher) = runBlocking {
@@ -77,15 +86,15 @@
} catch (e: CancellationException) {
expect(4)
Assert.assertThat(Thread.currentThread(), IsEqual(thread))
+ throw e // rethrow
}
- expect(5)
}
} catch (e: CancellationException) {
- expect(6)
+ expect(5)
Assert.assertThat(Thread.currentThread(), IsEqual(thread))
}
- expect(7)
+ expect(6)
}
- finish(8)
+ finish(7)
}
}
\ No newline at end of file