Introduce cancelling state for AbstractContinuation, improve exception handling, make tests stricter
diff --git a/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/AbstractContinuation.kt b/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/AbstractContinuation.kt
index c76e304..c8e4c11 100644
--- a/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/AbstractContinuation.kt
+++ b/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/AbstractContinuation.kt
@@ -34,15 +34,23 @@
 ) : Continuation<T>, DispatchedTask<T> {
 
     /*
+     * Implementation notes
+     *
      * AbstractContinuation is a subset of Job with following limitations:
      * 1) It can have only cancellation listeners
      * 2) It always invokes cancellation listener if it's cancelled (no 'invokeImmediately')
      * 3) It can have at most one cancellation listener
-     * 4) It cannot be in cancelling state, only active/finished/cancelled
-     * 5) Its cancellation listeners cannot be deregistered
-     *
+     * 4) Its cancellation listeners cannot be deregistered
      * As a consequence it has much simpler state machine, more lightweight machinery and
      * less dependencies.
+     *
+     * Cancelling state
+     * If useCancellingState is true, then this continuation can have additional cancelling state,
+     * which is transition from Active to Cancelled. This is specific state to support withContext(ctx)
+     * construction: block in withContext can be cancelled from withing or even before stepping into withContext,
+     * but we still want to properly run it (e.g. when it has atomic cancellation mode) and run its completion listener
+     * after.
+     * During cancellation all pending exceptions are aggregated and thrown during transition to final state
      */
 
     /* decision state machine
@@ -67,6 +75,7 @@
        ------      ------------         ------------    -----------
        ACTIVE      Active               : Active        active, no listeners
        SINGLE_A    CancellationHandler  : Active        active, one cancellation listener
+       CANCELLING  Cancelling           : Active        in the process of cancellation due to cancellation of parent job
        CANCELLED   Cancelled            : Cancelled     cancelled (final state)
        COMPLETED   any                  : Completed     produced some result or threw an exception (final state)
      */
@@ -83,6 +92,8 @@
 
     public val isCancelled: Boolean get() = state is CancelledContinuation
 
+    protected open val useCancellingState: Boolean get() = false
+
     internal fun initParentJobInternal(parent: Job?) {
         check(parentHandle == null)
         if (parent == null) {
@@ -105,6 +116,7 @@
     public fun cancel(cause: Throwable?): Boolean {
         loopOnState { state ->
             if (state !is NotCompleted) return false // quit if already complete
+            if (state is Cancelling) return false // someone else succeeded
             if (updateStateCancelled(state, cause)) return true
         }
     }
@@ -174,8 +186,15 @@
         }
     }
 
-    private fun updateStateCancelled(state: NotCompleted, cause: Throwable?): Boolean =
-        updateState(state, CancelledContinuation(this, cause), mode = MODE_ATOMIC_DEFAULT)
+    private fun updateStateCancelled(state: NotCompleted, cause: Throwable?): Boolean {
+        val update: Any = if (useCancellingState) {
+            Cancelling(CancelledContinuation(this, cause))
+        } else {
+            CancelledContinuation(this, cause)
+        }
+
+        return updateState(state, update, mode = MODE_ATOMIC_DEFAULT)
+    }
 
     private fun onCompletionInternal(mode: Int) {
         if (tryResume()) return // completed before getResult invocation -- bail out
@@ -192,50 +211,102 @@
     protected fun resumeImpl(proposedUpdate: Any?, resumeMode: Int) {
         loopOnState { state ->
             when (state) {
+                is Cancelling -> { // withContext() support
+                    /*
+                     * If already cancelled block is resumed with non-exception,
+                     * resume it with cancellation exception.
+                     * E.g.
+                     * ```
+                     * val value = withContext(ctx) {
+                     *   outerJob.cancel() // -> cancelling
+                     *   42 // -> cancelled
+                     * }
+                     * ```
+                     * should throw cancellation exception instead of returning 42
+                     */
+                    if (proposedUpdate !is CompletedExceptionally) {
+                        val update = state.cancel
+                        if (updateState(state, update, resumeMode)) return
+                    } else {
+                        /*
+                         * If already cancelled block is resumed with an exception,
+                         * then we should properly merge them to avoid information loss
+                         */
+                        val update: CompletedExceptionally
+
+                        /*
+                         * Proposed update is another CancellationException.
+                         * e.g.
+                         * ```
+                         * T1: ctxJob.cancel(e1) // -> cancelling
+                         * T2:
+                         * withContext(ctx) {
+                         *   // -> resumed with cancellation exception
+                         * }
+                         * ```
+                         */
+                        if (proposedUpdate.exception is CancellationException) {
+                            // Keep original cancellation cause and try add to suppressed exception from proposed cancel
+                            update = state.cancel
+                            coerceWithCancellation(state, proposedUpdate, update)
+                        } else {
+                            /*
+                             * Proposed update is exception => transition to terminal state
+                             * E.g.
+                             * ```
+                             * withContext(ctx) {
+                             *   outerJob.cancel() // -> cancelling
+                             *   throw Exception() // -> completed exceptionally
+                             * }
+                             * ```
+                             */
+                            val exception = proposedUpdate.exception
+                            // TODO clashes with await all
+                            val currentException = state.cancel.cause
+                            // Add to suppressed if original cancellation differs from proposed exception
+                            if (currentException != null && (currentException !is CancellationException || currentException.cause !== exception)) {
+                                exception.addSuppressedThrowable(currentException)
+                            }
+
+                            update = CompletedExceptionally(exception)
+                        }
+
+                        if (updateState(state, update, resumeMode)) {
+                            return
+                        }
+                    }
+                }
+
                 is NotCompleted -> {
                     if (updateState(state, proposedUpdate, resumeMode)) return
                 }
-
                 is CancelledContinuation -> {
-                    if (proposedUpdate !is CompletedExceptionally) {
-                        return // Cancelled continuation completed, do nothing
+                    if (proposedUpdate is NotCompleted || proposedUpdate is CompletedExceptionally) {
+                        error("Unexpected update, state: $state, update: $proposedUpdate")
                     }
-
-                    /*
-                     * Coerce concurrent cancellation and pending thrown exception.
-                     * E.g. for linear history `T1: cancel() T2 (job): throw e T3: job.await()`
-                     * we'd like to see actual exception in T3, not JobCancellationException.
-                     * So thrown exception overwrites cancellation exception, but
-                     * suppresses its non-null cause.
-                     */
-                    if (state.exception is CancellationException && state.exception.cause == null) {
-                        return // Do not add to suppressed regular cancellation
-                    }
-
-                    if (state.exception is CancellationException && state.exception.cause === proposedUpdate.exception) {
-                        return // Do not add to suppressed cancellation with the same cause
-                    }
-
-                    if (state.exception === proposedUpdate.exception) {
-                        return // Do no add to suppressed the same exception
-                    }
-
-                    val exception = proposedUpdate.exception
-                    val update = CompletedExceptionally(exception)
-                    if (state.cause != null) {
-                        exception.addSuppressedThrowable(state.cause)
-                    }
-
-                    if (_state.compareAndSet(state, update)) {
-                        return
-                    }
+                    // Coroutine is dispatched normally (e.g.via `delay()`) after cancellation
+                    return
                 }
-
                 else -> error("Already resumed, but proposed with update $proposedUpdate")
             }
         }
     }
 
+    // Coerce current cancelling state with proposed cancellation
+    private fun coerceWithCancellation(state: Cancelling, proposedUpdate: CompletedExceptionally, update: CompletedExceptionally) {
+        val originalCancellation = state.cancel
+        val originalException = originalCancellation.exception
+        val updateCause = proposedUpdate.cause
+
+        // Cause of proposed update is present and differs from one in current state TODO clashes with await all
+        val isSameCancellation = originalCancellation.exception is CancellationException
+                && originalException.cause === updateCause?.cause
+
+        if (!isSameCancellation && updateCause !== null && originalException.cause !== updateCause) {
+            update.exception.addSuppressedThrowable(updateCause)
+        }
+    }
+
     private fun makeHandler(handler: CompletionHandler): CancellationHandlerImpl<*> {
         if (handler is CancellationHandlerImpl<*>) {
             require(handler.continuation === this) { "Handler has non-matching continuation ${handler.continuation}, current: $this" }
@@ -250,18 +321,17 @@
     }
 
     private fun updateState(expect: NotCompleted, proposedUpdate: Any?, mode: Int): Boolean {
+        // TODO slow path for cancelling?
         if (!tryUpdateState(expect, proposedUpdate)) {
             return false
         }
 
-        completeUpdateState(expect, proposedUpdate, mode)
+        if (proposedUpdate !is Cancelling) {
+            completeUpdateState(expect, proposedUpdate, mode)
+        }
         return true
     }
 
-    /**
-     * Completes update of the current [state] of this job.
-     * @suppress **This is unstable API and it is subject to change.**
-     */
     protected fun completeUpdateState(expect: NotCompleted, update: Any?, mode: Int) {
         val exceptionally = update as? CompletedExceptionally
         onCompletionInternal(mode)
@@ -276,17 +346,16 @@
         }
     }
 
-    /**
-     * Tries to initiate update of the current [state] of this job.
-     * @suppress **This is unstable API and it is subject to change.**
-     */
     protected fun tryUpdateState(expect: NotCompleted, update: Any?): Boolean {
-        require(update !is NotCompleted) // only NotCompleted -> completed transition is allowed
         if (!_state.compareAndSet(expect, update)) return false
-        // Unregister from parent job
-        parentHandle?.let {
-            it.dispose() // volatile read parentHandle _after_ state was updated
-            parentHandle = NonDisposableHandle // release it just in case, to aid GC
+
+        // TODO separate code path?
+        if (update !is Cancelling) {
+            // Unregister from parent job
+            parentHandle?.let {
+                it.dispose() // volatile read parentHandle _after_ state was updated
+                parentHandle = NonDisposableHandle // release it just in case, to aid GC
+            }
         }
         return true // continues in completeUpdateState
     }
@@ -313,9 +382,11 @@
 internal interface NotCompleted
 
 private class Active : NotCompleted
-
 private val ACTIVE: Active = Active()
 
+// In progress of cancellation
+internal class Cancelling(@JvmField val cancel: CancelledContinuation) : NotCompleted
+
 internal abstract class CancellationHandlerImpl<out C : AbstractContinuation<*>>(@JvmField val continuation: C) :
     CancellationHandler(), NotCompleted
 
diff --git a/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/Builders.common.kt b/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/Builders.common.kt
index 5c6d1f2..55f8cad 100644
--- a/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/Builders.common.kt
+++ b/common/kotlinx-coroutines-core-common/src/main/kotlin/kotlinx/coroutines/experimental/Builders.common.kt
@@ -194,4 +194,7 @@
     override val context: CoroutineContext,
     delegate: Continuation<T>,
     resumeMode: Int
-) : AbstractContinuation<T>(delegate, resumeMode)
+) : AbstractContinuation<T>(delegate, resumeMode) {
+
+    override val useCancellingState: Boolean get() = true
+}
diff --git a/common/kotlinx-coroutines-core-common/src/test/kotlin/kotlinx/coroutines/experimental/WithContextTest.kt b/common/kotlinx-coroutines-core-common/src/test/kotlin/kotlinx/coroutines/experimental/WithContextTest.kt
index f59aa6b..bb5af79 100644
--- a/common/kotlinx-coroutines-core-common/src/test/kotlin/kotlinx/coroutines/experimental/WithContextTest.kt
+++ b/common/kotlinx-coroutines-core-common/src/test/kotlin/kotlinx/coroutines/experimental/WithContextTest.kt
@@ -118,18 +118,22 @@
     }
 
     @Test
-    fun testRunAtomicTryCancel() = runTest(
-        expected = { it is JobCancellationException }
-    ) {
+    fun testRunAtomicTryCancel() = runTest {
         expect(1)
         val job = Job()
         job.cancel() // try to cancel before it has a chance to run
-        withContext(job + wrapperDispatcher(coroutineContext), CoroutineStart.ATOMIC) { // but start atomically
-            // TODO here behaviour changed
-            require(!isActive)
-            finish(2)
-            yield() // but will cancel here
-            expectUnreached()
+
+        try {
+            withContext(job + wrapperDispatcher(coroutineContext), CoroutineStart.ATOMIC) {
+                require(isActive)
+                // but start atomically
+                expect(2)
+                yield() // but will cancel here
+                expectUnreached()
+            }
+        } catch (e: JobCancellationException) {
+            // This block should be invoked *after* context body
+            finish(3)
         }
     }
 
@@ -158,12 +162,14 @@
                 withContext(wrapperDispatcher(coroutineContext)) {
                     require(isActive)
                     expect(5)
-                    job!!.cancel() // cancel itself
+                    require(job!!.cancel()) // cancel itself
+                    require(!job!!.cancel(AssertionError())) // cancel again, no success here
+                    require(!isActive)
                     throw TestException() // but throw a different exception
                 }
             } catch (e: Throwable) {
                 expect(7)
-                // make sure TestException, not CancellationException is thrown!
+                // make sure TestException, not CancellationException or AssertionError is thrown
                 assertTrue(e is TestException, "Caught $e")
             }
         }
@@ -184,12 +190,15 @@
             try {
                 expect(3)
                 withContext(wrapperDispatcher(coroutineContext)) {
+                    require(isActive)
                     expect(5)
-                    job!!.cancel() // cancel itself
+                    require(job!!.cancel()) // cancel itself
+                    require(!job!!.cancel(AssertionError())) // cancel again, no success here
+                    require(!isActive)
                 }
             } catch (e: Throwable) {
                 expect(7)
-                // make sure TestException, not CancellationException is thrown!
+                // make sure TestException, not CancellationException or AssertionError is thrown!
                 assertTrue(e is JobCancellationException, "Caught $e")
             }
         }
diff --git a/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/JoinStressTest.kt b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/JoinStressTest.kt
index d58f27f..665d0ce 100644
--- a/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/JoinStressTest.kt
+++ b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/JoinStressTest.kt
@@ -8,8 +8,7 @@
 
 class JoinStressTest : TestBase() {
 
-    val iterations = 50_000 * stressTestMultiplier
-
+    private val iterations = 50_000 * stressTestMultiplier
     private val pool = newFixedThreadPoolContext(3, "JoinStressTest")
 
     @After
diff --git a/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithContextCancellationStressTest.kt b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithContextCancellationStressTest.kt
new file mode 100644
index 0000000..e028fbf
--- /dev/null
+++ b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/WithContextCancellationStressTest.kt
@@ -0,0 +1,80 @@
+package kotlinx.coroutines.experimental
+
+import org.junit.*
+import java.io.*
+import java.util.concurrent.*
+import kotlin.coroutines.experimental.*
+
+class WithContextCancellationStressTest : TestBase() {
+
+    private val iterations = 150_000 * stressTestMultiplier
+    private val pool = newFixedThreadPoolContext(3, "WithContextCancellationStressTest")
+
+    @After
+    fun tearDown() {
+        pool.close()
+    }
+
+    @Test
+    fun testConcurrentCancellation() = runBlocking {
+        var ioException = 0
+        var arithmeticException = 0
+        var aiobException = 0
+
+        repeat(iterations) {
+            val barrier = CyclicBarrier(4)
+            val jobWithContext = async(pool) {
+                barrier.await()
+                withContext(wrapperDispatcher(coroutineContext)) {
+                    throw IOException()
+                }
+            }
+
+            val cancellerJob = async(pool) {
+                barrier.await()
+                jobWithContext.cancel(ArithmeticException())
+            }
+
+            val cancellerJob2 = async(pool) {
+                barrier.await()
+                jobWithContext.cancel(ArrayIndexOutOfBoundsException())
+            }
+
+            barrier.await()
+            val c1 = cancellerJob.await()
+            val c2 = cancellerJob2.await()
+            require(!(c1 && c2)) { "Same job cannot be cancelled twice" }
+
+            try {
+                jobWithContext.await()
+            } catch (e: Exception) {
+                when (e) {
+                    is IOException -> ++ioException
+                    is JobCancellationException -> {
+                        val cause = e.cause
+                        when (cause) {
+                            is ArithmeticException -> ++arithmeticException
+                            is ArrayIndexOutOfBoundsException -> ++aiobException
+                            else -> error("Unexpected exception")
+                        }
+                    }
+                    else -> error("Unexpected exception")
+                }
+            }
+        }
+
+        // Backward compatibility, no exceptional code paths were lost
+        require(ioException > 0) { "At least one IOException expected" }
+        require(arithmeticException > 0) { "At least one ArithmeticException expected" }
+        require(aiobException > 0) { "At least one ArrayIndexOutOfBoundsException expected" }
+    }
+
+    private fun wrapperDispatcher(context: CoroutineContext): CoroutineContext {
+        val dispatcher = context[ContinuationInterceptor] as CoroutineDispatcher
+        return object : CoroutineDispatcher() {
+            override fun dispatch(context: CoroutineContext, block: Runnable) {
+                dispatcher.dispatch(context, block)
+            }
+        }
+    }
+}
diff --git a/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/channels/ProduceTest.kt b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/channels/ProduceTest.kt
index 641eff0..6fa7f76 100644
--- a/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/channels/ProduceTest.kt
+++ b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/channels/ProduceTest.kt
@@ -18,6 +18,7 @@
 
 import kotlinx.coroutines.experimental.*
 import org.junit.*
+import java.io.*
 import kotlin.coroutines.experimental.*
 
 class ProduceTest : TestBase() {