Fixed Publisher/Observable/Flowable.openSubscription in the presence of selects;
added support for an optional `request` parameter in openSubscription that specifies
how many elements are requested from the publisher in advance on subscription.
Fixes #197
diff --git a/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt b/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt
index bec6916..8be8da5 100644
--- a/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt
+++ b/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt
@@ -27,9 +27,11 @@
/**
* Subscribes to this [Publisher] and returns a channel to receive elements emitted by it.
* The resulting channel shall be [closed][SubscriptionReceiveChannel.close] to unsubscribe from this publisher.
+ * @param request how many items to request from the publisher in advance; must be non-negative (optional, defaults to `0`, which requests items on demand).
*/
-public fun <T> Publisher<T>.openSubscription(): SubscriptionReceiveChannel<T> {
- val channel = SubscriptionChannel<T>()
+@JvmOverloads // for binary compatibility
+public fun <T> Publisher<T>.openSubscription(request: Int = 0): SubscriptionReceiveChannel<T> {
+ val channel = SubscriptionChannel<T>(request)
subscribe(channel)
return channel
}
@@ -47,7 +49,6 @@
* This is a shortcut for `open().iterator()`. See [openSubscription] if you need an ability to manually
* unsubscribe from the observable.
*/
-
@Suppress("DeprecatedCallableAddReplaceWith")
@Deprecated(message =
"This iteration operator for `for (x in source) { ... }` loop is deprecated, " +
@@ -58,7 +59,7 @@
/**
* Subscribes to this [Publisher] and performs the specified action for each received element.
*/
-public inline suspend fun <T> Publisher<T>.consumeEach(action: (T) -> Unit) {
+public suspend inline fun <T> Publisher<T>.consumeEach(action: (T) -> Unit) {
openSubscription().use { channel ->
for (x in channel) action(x)
}
@@ -71,36 +72,39 @@
public suspend fun <T> Publisher<T>.consumeEach(action: suspend (T) -> Unit) =
consumeEach { action(it) }
-private class SubscriptionChannel<T> : LinkedListChannel<T>(), SubscriptionReceiveChannel<T>, Subscriber<T> {
- @Volatile
- @JvmField
- var subscription: Subscription? = null
+private class SubscriptionChannel<T>(
+ private val request: Int
+) : LinkedListChannel<T>(), SubscriptionReceiveChannel<T>, Subscriber<T> {
+ init {
+ require(request >= 0) { "Invalid request size: $request" }
+ }
- // request balance from cancelled receivers, balance is negative if we have receivers, but no subscription yet
- val _balance = atomic(0)
+ @Volatile
+ private var subscription: Subscription? = null
+
+ // requested from subscription minus number of received minus number of enqueued receivers,
+ // can be negative if we have receivers, but no subscription yet
+ private val _requested = atomic(0)
// AbstractChannel overrides
- override fun onEnqueuedReceive() {
- _balance.loop { balance ->
+ override fun onReceiveEnqueued() {
+ _requested.loop { wasRequested ->
val subscription = this.subscription
- if (subscription != null) {
- if (balance < 0) { // receivers came before we had subscription
- // try to fixup by making request
- if (!_balance.compareAndSet(balance, 0)) return@loop // continue looping
- subscription.request(-balance.toLong())
- return
- }
- if (balance == 0) { // normal story
- subscription.request(1)
- return
- }
+ val needRequested = wasRequested - 1
+ if (subscription != null && needRequested < 0) { // need to request more from subscription
+ // try to fixup by making request
+ if (wasRequested != request && !_requested.compareAndSet(wasRequested, request))
+ return@loop // continue looping if failed
+ subscription.request((request - needRequested).toLong())
+ return
}
- if (_balance.compareAndSet(balance, balance - 1)) return
+ // just do book-keeping
+ if (_requested.compareAndSet(wasRequested, needRequested)) return
}
}
- override fun onCancelledReceive() {
- _balance.incrementAndGet()
+ override fun onReceiveDequeued() {
+ _requested.incrementAndGet()
}
override fun afterClose(cause: Throwable?) {
@@ -110,22 +114,23 @@
// Subscriber overrides
override fun onSubscribe(s: Subscription) {
subscription = s
- while (true) { // lock-free loop on balance
+ while (true) { // lock-free loop on _requested
if (isClosedForSend) {
s.cancel()
return
}
- val balance = _balance.value
- if (balance >= 0) return // ok -- normal story
- // otherwise, receivers came before we had subscription
+ val wasRequested = _requested.value
+ if (wasRequested >= request) return // ok -- normal story
+ // otherwise, receivers came before we had subscription or need to make initial request
// try to fixup by making request
- if (!_balance.compareAndSet(balance, 0)) continue
- s.request(-balance.toLong())
+ if (!_requested.compareAndSet(wasRequested, request)) continue
+ s.request((request - wasRequested).toLong())
return
}
}
override fun onNext(t: T) {
+ _requested.decrementAndGet()
offer(t)
}
diff --git a/reactive/kotlinx-coroutines-reactive/src/test/kotlin/kotlinx/coroutines/experimental/reactive/PublisherSubscriptionSelectTest.kt b/reactive/kotlinx-coroutines-reactive/src/test/kotlin/kotlinx/coroutines/experimental/reactive/PublisherSubscriptionSelectTest.kt
new file mode 100644
index 0000000..647d662
--- /dev/null
+++ b/reactive/kotlinx-coroutines-reactive/src/test/kotlin/kotlinx/coroutines/experimental/reactive/PublisherSubscriptionSelectTest.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.reactive
+
+import kotlinx.coroutines.experimental.*
+import kotlinx.coroutines.experimental.selects.*
+import org.junit.*
+import org.junit.Assert.*
+import org.junit.runner.*
+import org.junit.runners.*
+
+@RunWith(Parameterized::class)
+class PublisherSubscriptionSelectTest(val request: Int) : TestBase() {
+ companion object {
+ @Parameterized.Parameters(name = "request = {0}")
+ @JvmStatic
+ fun params(): Collection<Array<Any>> = listOf(0, 1, 10).map { arrayOf<Any>(it) }
+ }
+
+ @Test
+ fun testSelect() = runTest {
+ // source with n ints
+ val n = 1000 * stressTestMultiplier
+ val source = publish(coroutineContext) { repeat(n) { send(it) } }
+ var a = 0
+ var b = 0
+ // open two subs
+ source.openSubscription(request).use { channelA ->
+ source.openSubscription(request).use { channelB ->
+ loop@ while (true) {
+ val done: Int = select {
+ channelA.onReceiveOrNull {
+ if (it != null) assertEquals(a++, it)
+ if (it == null) 0 else 1
+ }
+ channelB.onReceiveOrNull {
+ if (it != null) assertEquals(b++, it)
+ if (it == null) 0 else 2
+ }
+ }
+ when (done) {
+ 0 -> break@loop
+ 1 -> {
+ val r = channelB.receiveOrNull()
+ if (r != null) assertEquals(b++, r)
+ }
+ 2 -> {
+ val r = channelA.receiveOrNull()
+ if (r != null) assertEquals(a++, r)
+ }
+ }
+ }
+ }
+ }
+ // should receive one of them fully
+ assertTrue(a == n || b == n)
+ }
+}
\ No newline at end of file
diff --git a/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt b/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt
index 1a6b7c0..ad87de9 100644
--- a/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt
+++ b/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt
@@ -27,9 +27,11 @@
/**
* Subscribes to this [Observable] and returns a channel to receive elements emitted by it.
* The resulting channel shall be [closed][SubscriptionReceiveChannel.close] to unsubscribe from this observable.
+ * @param request how many items to request from the observable in advance; must be non-negative (optional, defaults to `0`, which requests items on demand).
*/
-public fun <T> Observable<T>.openSubscription(): SubscriptionReceiveChannel<T> {
- val channel = SubscriptionChannel<T>()
+@JvmOverloads // for binary compatibility
+public fun <T> Observable<T>.openSubscription(request: Int = 0): SubscriptionReceiveChannel<T> {
+ val channel = SubscriptionChannel<T>(request)
val subscription = subscribe(channel.subscriber)
channel.subscription = subscription
if (channel.isClosedForSend) subscription.unsubscribe()
@@ -59,7 +61,7 @@
/**
* Subscribes to this [Observable] and performs the specified action for each received element.
*/
-public inline suspend fun <T> Observable<T>.consumeEach(action: (T) -> Unit) {
+public suspend inline fun <T> Observable<T>.consumeEach(action: (T) -> Unit) {
openSubscription().use { channel ->
for (x in channel) action(x)
}
@@ -72,45 +74,58 @@
public suspend fun <T> Observable<T>.consumeEach(action: suspend (T) -> Unit) =
consumeEach { action(it) }
-private class SubscriptionChannel<T> : LinkedListChannel<T>(), SubscriptionReceiveChannel<T> {
+private class SubscriptionChannel<T>(
+ private val request: Int
+) : LinkedListChannel<T>(), SubscriptionReceiveChannel<T> {
+ init {
+ require(request >= 0) { "Invalid request size: $request" }
+ }
+
@JvmField
- val subscriber: ChannelSubscriber = ChannelSubscriber()
+ val subscriber: ChannelSubscriber = ChannelSubscriber(request)
@Volatile
@JvmField
var subscription: Subscription? = null
- val _balance = atomic(0) // request balance from cancelled receivers
+ // requested from subscriber minus number of received minus number of enqueued receivers; starts at `request`, which is requested in onStart
+ private val _requested = atomic(request)
// AbstractChannel overrides
- override fun onEnqueuedReceive() {
- _balance.loop { balance ->
- if (balance == 0) {
- subscriber.requestOne()
+ override fun onReceiveEnqueued() {
+ _requested.loop { wasRequested ->
+ val needRequested = wasRequested - 1
+ if (needRequested < 0) { // need to request more from subscriber
+ // try to fixup by making request
+ if (wasRequested != request && !_requested.compareAndSet(wasRequested, request))
+ return@loop // continue looping if failed
+ subscriber.makeRequest((request - needRequested).toLong())
return
}
- if (_balance.compareAndSet(balance, balance - 1)) return
+ // just do book-keeping
+ if (_requested.compareAndSet(wasRequested, needRequested)) return
}
}
- override fun onCancelledReceive() {
- _balance.incrementAndGet()
+ override fun onReceiveDequeued() {
+ _requested.incrementAndGet()
}
override fun afterClose(cause: Throwable?) {
subscription?.unsubscribe()
}
- inner class ChannelSubscriber: Subscriber<T>() {
- fun requestOne() {
- request(1)
+ inner class ChannelSubscriber(private val request: Int): Subscriber<T>() {
+ fun makeRequest(n: Long) {
+ request(n)
}
override fun onStart() {
- request(0) // init backpressure, but don't request anything yet
+ request(request.toLong()) // init backpressure
}
override fun onNext(t: T) {
+ _requested.decrementAndGet()
offer(t)
}
diff --git a/reactive/kotlinx-coroutines-rx1/src/test/kotlin/kotlinx/coroutines/experimental/rx1/ObservableSubscriptionSelectTest.kt b/reactive/kotlinx-coroutines-rx1/src/test/kotlin/kotlinx/coroutines/experimental/rx1/ObservableSubscriptionSelectTest.kt
new file mode 100644
index 0000000..49e7abf
--- /dev/null
+++ b/reactive/kotlinx-coroutines-rx1/src/test/kotlin/kotlinx/coroutines/experimental/rx1/ObservableSubscriptionSelectTest.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.rx1
+
+import kotlinx.coroutines.experimental.*
+import kotlinx.coroutines.experimental.selects.*
+import org.junit.*
+import org.junit.Assert.*
+import org.junit.runner.*
+import org.junit.runners.*
+
+@RunWith(Parameterized::class)
+class ObservableSubscriptionSelectTest(val request: Int) : TestBase() {
+ companion object {
+ @Parameterized.Parameters(name = "request = {0}")
+ @JvmStatic
+ fun params(): Collection<Array<Any>> = listOf(0, 1, 10).map { arrayOf<Any>(it) }
+ }
+
+ @Test
+ fun testSelect() = runTest {
+ // source with n ints
+ val n = 1000 * stressTestMultiplier
+ val source = rxObservable(coroutineContext) { repeat(n) { send(it) } }
+ var a = 0
+ var b = 0
+ // open two subs
+ source.openSubscription(request).use { channelA ->
+ source.openSubscription(request).use { channelB ->
+ loop@ while (true) {
+ val done: Int = select {
+ channelA.onReceiveOrNull {
+ if (it != null) assertEquals(a++, it)
+ if (it == null) 0 else 1
+ }
+ channelB.onReceiveOrNull {
+ if (it != null) assertEquals(b++, it)
+ if (it == null) 0 else 2
+ }
+ }
+ when (done) {
+ 0 -> break@loop
+ 1 -> {
+ val r = channelB.receiveOrNull()
+ if (r != null) assertEquals(b++, r)
+ }
+ 2 -> {
+ val r = channelA.receiveOrNull()
+ if (r != null) assertEquals(a++, r)
+ }
+ }
+ }
+ }
+ }
+ // should receive one of them fully
+ assertTrue(a == n || b == n)
+ }
+}
\ No newline at end of file
diff --git a/reactive/kotlinx-coroutines-rx2/src/test/kotlin/kotlinx/coroutines/experimental/rx2/ObservableSubscriptionSelectTest.kt b/reactive/kotlinx-coroutines-rx2/src/test/kotlin/kotlinx/coroutines/experimental/rx2/ObservableSubscriptionSelectTest.kt
new file mode 100644
index 0000000..24c31af
--- /dev/null
+++ b/reactive/kotlinx-coroutines-rx2/src/test/kotlin/kotlinx/coroutines/experimental/rx2/ObservableSubscriptionSelectTest.kt
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.rx2
+
+import kotlinx.coroutines.experimental.*
+import kotlinx.coroutines.experimental.selects.*
+import org.junit.*
+import org.junit.Assert.*
+
+class ObservableSubscriptionSelectTest() : TestBase() {
+ @Test
+ fun testSelect() = runTest {
+ // source with n ints
+ val n = 1000 * stressTestMultiplier
+ val source = rxObservable(coroutineContext) { repeat(n) { send(it) } }
+ var a = 0
+ var b = 0
+ // open two subs
+ source.openSubscription().use { channelA ->
+ source.openSubscription().use { channelB ->
+ loop@ while (true) {
+ val done: Int = select {
+ channelA.onReceiveOrNull {
+ if (it != null) assertEquals(a++, it)
+ if (it == null) 0 else 1
+ }
+ channelB.onReceiveOrNull {
+ if (it != null) assertEquals(b++, it)
+ if (it == null) 0 else 2
+ }
+ }
+ when (done) {
+ 0 -> break@loop
+ 1 -> {
+ val r = channelB.receiveOrNull()
+ if (r != null) assertEquals(b++, r)
+ }
+ 2 -> {
+ val r = channelA.receiveOrNull()
+ if (r != null) assertEquals(a++, r)
+ }
+ }
+ }
+ }
+ }
+ // should receive one of them fully
+ assertTrue(a == n || b == n)
+ }
+}
\ No newline at end of file