Properly check identity of caught AbortFlowException in Flow.first op… (#2057)
It fixes two problems:
* NoSuchElementException can be thrown during cancellation sequence (see FirstJvmTest that reproduces this problem with explanation)
* Cancellation can be accidentally suppressed and flow activity can be prolonged
Fixes #2051
diff --git a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt
index 98e665f..d99ae52 100644
--- a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt
+++ b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt
@@ -8,9 +8,7 @@
package kotlinx.coroutines.flow
-import kotlinx.coroutines.*
import kotlinx.coroutines.flow.internal.*
-import kotlinx.coroutines.flow.internal.unsafeFlow as flow
import kotlin.jvm.*
/**
@@ -84,15 +82,10 @@
*/
public suspend fun <T> Flow<T>.first(): T {
var result: Any? = NULL
- try {
- collect { value ->
- result = value
- throw AbortFlowException(NopCollector)
- }
- } catch (e: AbortFlowException) {
- // Do nothing
+ collectUntil {
+ result = it
+ true
}
-
if (result === NULL) throw NoSuchElementException("Expected at least one element")
return result as T
}
@@ -103,17 +96,14 @@
*/
public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
var result: Any? = NULL
- try {
- collect { value ->
- if (predicate(value)) {
- result = value
- throw AbortFlowException(NopCollector)
- }
+ collectUntil {
+ if (predicate(it)) {
+ result = it
+ true
+ } else {
+ false
}
- } catch (e: AbortFlowException) {
- // Do nothing
}
-
if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate")
return result as T
}
@@ -124,13 +114,9 @@
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
var result: T? = null
- try {
- collect { value ->
- result = value
- throw AbortFlowException(NopCollector)
- }
- } catch (e: AbortFlowException) {
- // Do nothing
+ collectUntil {
+ result = it
+ true
}
return result
}
@@ -141,15 +127,28 @@
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(predicate: suspend (T) -> Boolean): T? {
var result: T? = null
- try {
- collect { value ->
- if (predicate(value)) {
- result = value
- throw AbortFlowException(NopCollector)
- }
+ collectUntil {
+ if (predicate(it)) {
+ result = it
+ true
+ } else {
+ false
}
- } catch (e: AbortFlowException) {
- // Do nothing
}
return result
}
+
+internal suspend inline fun <T> Flow<T>.collectUntil(crossinline block: suspend (value: T) -> Boolean) {
+ val collector = object : FlowCollector<T> {
+ override suspend fun emit(value: T) {
+ if (block(value)) {
+ throw AbortFlowException(this)
+ }
+ }
+ }
+ try {
+ collect(collector)
+ } catch (e: AbortFlowException) {
+ e.checkOwnership(collector)
+ }
+}
diff --git a/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt b/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt
index f737a1d..edb9f00 100644
--- a/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt
+++ b/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt
@@ -6,6 +6,7 @@
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
+import kotlinx.coroutines.flow.internal.*
import kotlin.test.*
class FirstTest : TestBase() {
@@ -160,4 +161,13 @@
assertSame(instance, flow.first { true })
assertSame(instance, flow.firstOrNull { true })
}
+
+ @Test
+ fun testAbortFlowException() = runTest {
+ val flow = flow<Int> {
+ throw AbortFlowException(NopCollector) // Emulate cancellation
+ }
+
+ assertFailsWith<CancellationException> { flow.first() }
+ }
}
diff --git a/kotlinx-coroutines-core/jvm/test/flow/FirstJvmTest.kt b/kotlinx-coroutines-core/jvm/test/flow/FirstJvmTest.kt
new file mode 100644
index 0000000..77ad083
--- /dev/null
+++ b/kotlinx-coroutines-core/jvm/test/flow/FirstJvmTest.kt
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+package kotlinx.coroutines.flow
+
+import kotlinx.coroutines.*
+import org.junit.Test
+import kotlin.test.*
+
+class FirstJvmTest : TestBase() {
+
+ @Test
+ fun testTakeInterference() = runBlocking(Dispatchers.Default) {
+ /*
+ * This test tests a racy situation when outer channelFlow is being cancelled,
+ * inner flow starts atomically in "CANCELLING" state, sends one element and completes
+ * (=> cancels and drops element away), triggering NSEE in Flow.first operator
+ */
+ val values = (0..10000).asFlow().flatMapMerge(Int.MAX_VALUE) {
+ channelFlow {
+ val value = channelFlow { send(1) }.first()
+ send(value)
+ }
+ }.take(1).toList()
+ assertEquals(listOf(1), values)
+ }
+}
\ No newline at end of file