Fix race in Flow.asPublisher (#2124)

The race was leading to emitting more items via onNext than requested, the corresponding stress-test was added, too

Fixes #2109
diff --git a/reactive/kotlinx-coroutines-jdk9/test/FlowAsPublisherTest.kt b/reactive/kotlinx-coroutines-jdk9/test/FlowAsPublisherTest.kt
index 8017ee5..488695d 100644
--- a/reactive/kotlinx-coroutines-jdk9/test/FlowAsPublisherTest.kt
+++ b/reactive/kotlinx-coroutines-jdk9/test/FlowAsPublisherTest.kt
@@ -16,10 +16,10 @@
     fun testErrorOnCancellationIsReported() {
         expect(1)
         flow<Int> {
-            emit(2)
             try {
-                hang { expect(3) }
+                emit(2)
             } finally {
+                expect(3)
                 throw TestException()
             }
         }.asPublisher().subscribe(object : JFlow.Subscriber<Int> {
@@ -52,12 +52,11 @@
         expect(1)
         flow<Int>    {
             emit(2)
-            hang { expect(3) }
         }.asPublisher().subscribe(object : JFlow.Subscriber<Int> {
             private lateinit var subscription: JFlow.Subscription
 
             override fun onComplete() {
-                expect(4)
+                expect(3)
             }
 
             override fun onSubscribe(s: JFlow.Subscription?) {
@@ -74,6 +73,6 @@
                 expectUnreached()
             }
         })
-        finish(5)
+        finish(4)
     }
 }
diff --git a/reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt b/reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt
index 96ae628..efa9c9c 100644
--- a/reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt
+++ b/reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt
@@ -166,11 +166,12 @@
 public class FlowSubscription<T>(
     @JvmField public val flow: Flow<T>,
     @JvmField public val subscriber: Subscriber<in T>
-) : Subscription, AbstractCoroutine<Unit>(Dispatchers.Unconfined, false) {
+) : Subscription, AbstractCoroutine<Unit>(Dispatchers.Unconfined, true) {
     private val requested = atomic(0L)
-    private val producer = atomic<CancellableContinuation<Unit>?>(null)
+    private val producer = atomic<Continuation<Unit>?>(createInitialContinuation())
 
-    override fun onStart() {
+    // This code wraps startCoroutineCancellable into continuation
+    private fun createInitialContinuation(): Continuation<Unit> = Continuation(coroutineContext) {
         ::flowProcessing.startCoroutineCancellable(this)
     }
 
@@ -197,19 +198,17 @@
      */
     private suspend fun consumeFlow() {
         flow.collect { value ->
-            /*
-             * Flow is scopeless, thus if it's not active, its subscription was cancelled.
-             * No intermediate "child failed, but flow coroutine is not" states are allowed.
-             */
-            coroutineContext.ensureActive()
-            if (requested.value <= 0L) {
+            // Emit the value
+            subscriber.onNext(value)
+            // Suspend if needed before requesting the next value
+            if (requested.decrementAndGet() <= 0) {
                 suspendCancellableCoroutine<Unit> {
                     producer.value = it
-                    if (requested.value != 0L) it.resumeSafely()
                 }
+            } else {
+                // check for cancellation if we don't suspend
+                coroutineContext.ensureActive()
             }
-            requested.decrementAndGet()
-            subscriber.onNext(value)
         }
     }
 
@@ -218,22 +217,19 @@
     }
 
     override fun request(n: Long) {
-        if (n <= 0) {
-            return
-        }
-        start()
-        requested.update { value ->
+        if (n <= 0) return
+        val old = requested.getAndUpdate { value ->
             val newValue = value + n
             if (newValue <= 0L) Long.MAX_VALUE else newValue
         }
-        val producer = producer.getAndSet(null) ?: return
-        producer.resumeSafely()
-    }
-
-    private fun CancellableContinuation<Unit>.resumeSafely() {
-        val token = tryResume(Unit)
-        if (token != null) {
-            completeResume(token)
+        if (old <= 0L) {
+            assert(old == 0L)
+            // Emitter is not started yet or has suspended -- spin on race with suspendCancellableCoroutine
+            while(true) {
+                val producer = producer.getAndSet(null) ?: continue // spin if not set yet
+                producer.resume(Unit)
+                break
+            }
         }
     }
 }
diff --git a/reactive/kotlinx-coroutines-reactive/test/FlowAsPublisherTest.kt b/reactive/kotlinx-coroutines-reactive/test/FlowAsPublisherTest.kt
index 8633492..c044d92 100644
--- a/reactive/kotlinx-coroutines-reactive/test/FlowAsPublisherTest.kt
+++ b/reactive/kotlinx-coroutines-reactive/test/FlowAsPublisherTest.kt
@@ -16,10 +16,10 @@
     fun testErrorOnCancellationIsReported() {
         expect(1)
         flow<Int> {
-            emit(2)
             try {
-                hang { expect(3) }
+                emit(2)
             } finally {
+                expect(3)
                 throw TestException()
             }
         }.asPublisher().subscribe(object : Subscriber<Int> {
@@ -52,12 +52,11 @@
         expect(1)
         flow<Int>    {
             emit(2)
-            hang { expect(3) }
         }.asPublisher().subscribe(object : Subscriber<Int> {
             private lateinit var subscription: Subscription
 
             override fun onComplete() {
-                expect(4)
+                expect(3)
             }
 
             override fun onSubscribe(s: Subscription?) {
@@ -74,6 +73,6 @@
                 expectUnreached()
             }
         })
-        finish(5)
+        finish(4)
     }
 }
diff --git a/reactive/kotlinx-coroutines-reactive/test/PublisherRequestStressTest.kt b/reactive/kotlinx-coroutines-reactive/test/PublisherRequestStressTest.kt
new file mode 100644
index 0000000..736a664
--- /dev/null
+++ b/reactive/kotlinx-coroutines-reactive/test/PublisherRequestStressTest.kt
@@ -0,0 +1,141 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+package kotlinx.coroutines.reactive
+
+import kotlinx.coroutines.*
+import kotlinx.coroutines.flow.*
+import kotlinx.coroutines.flow.Flow
+import org.junit.*
+import org.reactivestreams.*
+import java.util.concurrent.*
+import java.util.concurrent.atomic.*
+import kotlin.coroutines.*
+import kotlin.random.*
+
+/**
+ * This stress-test is self-contained reproducer for the race in [Flow.asPublisher] extension
+ * that was originally reported in the issue
+ * [#2109](https://github.com/Kotlin/kotlinx.coroutines/issues/2109).
+ * The original reproducer used a flow that loads a file using AsynchronousFileChannel
+ * (that issues completion callbacks from multiple threads)
+ * and uploads it to S3 via Amazon SDK, which internally uses netty for I/O
+ * (which uses a single thread for connection-related callbacks).
+ *
+ * This stress-test essentially mimics the logic in multiple interacting threads: several emitter threads that form
+ * the flow and a single requesting thread works on the subscriber's side to periodically request more
+ * values when the number of items requested drops below the threshold.
+ */
+@Suppress("ReactiveStreamsSubscriberImplementation")
+class PublisherRequestStressTest : TestBase() {
+    private val testDurationSec = 3 * stressTestMultiplier
+
+    // Original code in Amazon SDK uses 4 and 16 as low/high watermarks.
+    // There constants were chosen so that problem reproduces asap with particular this code.
+    private val minDemand = 8L
+    private val maxDemand = 16L
+    
+    private val nEmitThreads = 4
+
+    private val emitThreadNo = AtomicInteger()
+
+    private val emitPool = Executors.newFixedThreadPool(nEmitThreads) { r ->
+        Thread(r, "PublisherRequestStressTest-emit-${emitThreadNo.incrementAndGet()}")
+    }
+
+    private val reqPool = Executors.newSingleThreadExecutor { r ->
+        Thread(r, "PublisherRequestStressTest-req")
+    }
+    
+    private val nextValue = AtomicLong(0)
+
+    @After
+    fun tearDown() {
+        emitPool.shutdown()
+        reqPool.shutdown()
+        emitPool.awaitTermination(10, TimeUnit.SECONDS)
+        reqPool.awaitTermination(10, TimeUnit.SECONDS)
+    }
+
+    private lateinit var subscription: Subscription
+
+    @Test
+    fun testRequestStress() {
+        val expectedValue = AtomicLong(0)
+        val requestedTill = AtomicLong(0)
+        val completionLatch = CountDownLatch(1)
+        val callingOnNext = AtomicInteger()
+
+        val publisher = mtFlow().asPublisher()
+        var error = false
+        
+        publisher.subscribe(object : Subscriber<Long> {
+            private var demand = 0L // only updated from reqPool
+
+            override fun onComplete() {
+                completionLatch.countDown()
+            }
+
+            override fun onSubscribe(sub: Subscription) {
+                subscription = sub
+                maybeRequestMore()
+            }
+
+            private fun maybeRequestMore() {
+                if (demand >= minDemand) return
+                val nextDemand = Random.nextLong(minDemand + 1..maxDemand)
+                val more = nextDemand - demand
+                demand = nextDemand
+                requestedTill.addAndGet(more)
+                subscription.request(more)
+            }
+
+            override fun onNext(value: Long) {
+                check(callingOnNext.getAndIncrement() == 0) // make sure it is not concurrent
+                // check for expected value
+                check(value == expectedValue.get())
+                // check that it does not exceed requested values
+                check(value < requestedTill.get())
+                val nextExpected = value + 1
+                expectedValue.set(nextExpected)
+                // send more requests from request thread
+                reqPool.execute {
+                    demand-- // processed an item
+                    maybeRequestMore()
+                }
+                callingOnNext.decrementAndGet()
+            }
+
+            override fun onError(ex: Throwable?) {
+                error = true
+                error("Failed", ex)
+            }
+        })
+        var prevExpected = -1L
+        for (second in 1..testDurationSec) {
+            if (error) break
+            Thread.sleep(1000)
+            val expected = expectedValue.get()
+            println("$second: expectedValue = $expected")
+            check(expected > prevExpected) // should have progress
+            prevExpected = expected
+        }
+        if (!error) {
+            subscription.cancel()
+            completionLatch.await()
+        }
+    }
+
+    private fun mtFlow(): Flow<Long> = flow {
+        while (currentCoroutineContext().isActive) {
+            emit(aWait())
+        }
+    }
+
+    private suspend fun aWait(): Long = suspendCancellableCoroutine { cont ->
+        emitPool.execute(Runnable {
+            cont.resume(nextValue.getAndIncrement())
+        })
+    }
+}
\ No newline at end of file