Unwrap exception on CompletionStage#await slow-path to provide consistent results
Fixes #375
diff --git a/integration/kotlinx-coroutines-guava/src/test/kotlin/kotlinx/coroutines/experimental/guava/ListenableFutureExceptionsTest.kt b/integration/kotlinx-coroutines-guava/src/test/kotlin/kotlinx/coroutines/experimental/guava/ListenableFutureExceptionsTest.kt
new file mode 100644
index 0000000..2646dfa
--- /dev/null
+++ b/integration/kotlinx-coroutines-guava/src/test/kotlin/kotlinx/coroutines/experimental/guava/ListenableFutureExceptionsTest.kt
@@ -0,0 +1,81 @@
+package kotlinx.coroutines.experimental.guava
+
+import com.google.common.util.concurrent.*
+import kotlinx.coroutines.experimental.*
+import org.junit.Test
+import java.io.*
+import java.util.concurrent.*
+import kotlin.coroutines.experimental.*
+import kotlin.test.*
+
+class ListenableFutureExceptionsTest : TestBase() {
+
+ @Test
+ fun testAwait() {
+ testException(IOException(), { it is IOException })
+ }
+
+ @Test
+ fun testAwaitChained() {
+ testException(IOException(), { it is IOException }, { i -> i!! + 1 })
+ }
+
+ @Test
+ fun testAwaitCompletionException() {
+ testException(CompletionException("test", IOException()), { it is CompletionException })
+ }
+
+ @Test
+ fun testAwaitChainedCompletionException() {
+ testException(
+ CompletionException("test", IOException()),
+ { it is CompletionException },
+ { i -> i!! + 1 })
+ }
+
+ @Test
+ fun testAwaitTestException() {
+ testException(TestException(), { it is TestException })
+ }
+
+ @Test
+ fun testAwaitChainedTestException() {
+ testException(TestException(), { it is TestException }, { i -> i!! + 1 })
+ }
+
+ class TestException : CompletionException("test2")
+
+ private fun testException(
+ exception: Exception,
+ expected: ((Throwable) -> Boolean),
+ transformer: ((Int?) -> Int?)? = null
+ ) {
+
+ // Fast path
+ runTest {
+ val future = SettableFuture.create<Int>()
+ val chained = if (transformer == null) future else Futures.transform(future, transformer)
+ future.setException(exception)
+ try {
+ chained.await()
+ } catch (e: Exception) {
+ assertTrue(expected(e))
+ }
+ }
+
+ // Slow path
+ runTest {
+ val future = SettableFuture.create<Int>()
+ val chained = if (transformer == null) future else Futures.transform(future, transformer)
+ launch(coroutineContext) {
+ future.setException(exception)
+ }
+
+ try {
+ chained.await()
+ } catch (e: Exception) {
+ assertTrue(expected(e))
+ }
+ }
+ }
+}
diff --git a/integration/kotlinx-coroutines-jdk8/src/main/kotlin/kotlinx/coroutines/experimental/future/Future.kt b/integration/kotlinx-coroutines-jdk8/src/main/kotlin/kotlinx/coroutines/experimental/future/Future.kt
index d476990..357ac1c 100644
--- a/integration/kotlinx-coroutines-jdk8/src/main/kotlin/kotlinx/coroutines/experimental/future/Future.kt
+++ b/integration/kotlinx-coroutines-jdk8/src/main/kotlin/kotlinx/coroutines/experimental/future/Future.kt
@@ -188,10 +188,13 @@
@Suppress("UNCHECKED_CAST")
override fun accept(result: T?, exception: Throwable?) {
val cont = this.cont ?: return // atomically read current value unless null
- if (exception == null) // the future has been completed normally
+ if (exception == null) {
+ // the future has been completed normally
cont.resume(result as T)
- else // the future has completed with an exception
- cont.resumeWithException(exception)
+ } else {
+ // the future has completed with an exception, unwrap it to provide consistent view of .await() result and to propagate only original exception
+ cont.resumeWithException((exception as? CompletionException)?.cause ?: exception)
+ }
}
}
@@ -214,4 +217,4 @@
public fun <T> future(
context: CoroutineContext = CommonPool,
block: suspend () -> T
-): CompletableFuture<T> = future(context=context) { block() }
+): CompletableFuture<T> = future(context = context) { block() }
diff --git a/integration/kotlinx-coroutines-jdk8/src/test/kotlin/kotlinx/coroutines/experimental/future/FutureExceptionsTest.kt b/integration/kotlinx-coroutines-jdk8/src/test/kotlin/kotlinx/coroutines/experimental/future/FutureExceptionsTest.kt
new file mode 100644
index 0000000..e3a2126
--- /dev/null
+++ b/integration/kotlinx-coroutines-jdk8/src/test/kotlin/kotlinx/coroutines/experimental/future/FutureExceptionsTest.kt
@@ -0,0 +1,86 @@
+package kotlinx.coroutines.experimental.future
+
+import kotlinx.coroutines.experimental.*
+import org.junit.Test
+import java.io.*
+import java.util.concurrent.*
+import kotlin.coroutines.experimental.*
+import kotlin.test.*
+
+class FutureExceptionsTest : TestBase() {
+
+ @Test
+ fun testAwait() {
+ testException(IOException(), { it is IOException })
+ }
+
+ @Test
+ fun testAwaitChained() {
+ testException(IOException(), { it is IOException }, { f -> f.thenApply { it + 1 } })
+ }
+
+ @Test
+ fun testAwaitDeepChain() {
+ testException(IOException(), { it is IOException },
+ { f -> f
+ .thenApply { it + 1 }
+ .thenApply { it + 2 } })
+ }
+
+ @Test
+ fun testAwaitCompletionException() {
+ testException(CompletionException("test", IOException()), { it is IOException })
+ }
+
+ @Test
+ fun testAwaitChainedCompletionException() {
+ testException(CompletionException("test", IOException()), { it is IOException }, { f -> f.thenApply { it + 1 } })
+ }
+
+ @Test
+ fun testAwaitTestException() {
+ testException(TestException(), { it is TestException })
+ }
+
+ @Test
+ fun testAwaitChainedTestException() {
+ testException(TestException(), { it is TestException }, { f -> f.thenApply { it + 1 } })
+ }
+
+ class TestException : CompletionException("test2")
+
+ private fun testException(
+ exception: Exception,
+ expected: ((Throwable) -> Boolean),
+ transformer: (CompletableFuture<Int>) -> CompletableFuture<Int> = { it }
+ ) {
+
+ // Fast path
+ runTest {
+ val future = CompletableFuture<Int>()
+ val chained = transformer(future)
+ future.completeExceptionally(exception)
+ try {
+ chained.await()
+ } catch (e: Exception) {
+ assertTrue(expected(e))
+ }
+ }
+
+ // Slow path
+ runTest {
+ val future = CompletableFuture<Int>()
+ val chained = transformer(future)
+
+ launch(coroutineContext) {
+ future.completeExceptionally(exception)
+ }
+
+ try {
+ chained.await()
+ } catch (e: Exception) {
+ assertTrue(expected(e))
+ }
+ }
+ }
+}