Fixed Publisher/Observable/Flowable.openSubscription in presence of selects;
support an optional `request` parameter in openSubscription to specify
how many elements are requested from publisher in advance on subscription.
Fixes #197
diff --git a/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/AbstractChannel.kt b/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/AbstractChannel.kt
index 367a3b2..16862d1 100644
--- a/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/AbstractChannel.kt
+++ b/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/AbstractChannel.kt
@@ -283,7 +283,7 @@
* Retrieves first receiving waiter from the queue or returns closed token.
* @suppress **This is unstable API and it is subject to change.**
*/
- protected fun takeFirstReceiveOrPeekClosed(): ReceiveOrClosed<E>? =
+ protected open fun takeFirstReceiveOrPeekClosed(): ReceiveOrClosed<E>? =
queue.removeFirstIfIsInstanceOfOrPeekIf<ReceiveOrClosed<E>>({ it is Closed<*> })
// ------ registerSelectSend ------
@@ -520,7 +520,7 @@
val result = if (isBufferAlwaysEmpty)
queue.addLastIfPrev(receive, { it !is Send }) else
queue.addLastIfPrevAndIf(receive, { it !is Send }, { isBufferEmpty })
- if (result) onEnqueuedReceive()
+ if (result) onReceiveEnqueued()
return result
}
@@ -640,7 +640,7 @@
override fun finishOnSuccess(affected: LockFreeLinkedListNode, next: LockFreeLinkedListNode) {
super.finishOnSuccess(affected, next)
// notify the there is one more receiver
- onEnqueuedReceive()
+ onReceiveEnqueued()
// we can actually remove on select start, but this is also Ok (it'll get removed if discovered there)
node.removeOnSelectCompletion()
}
@@ -724,29 +724,36 @@
// ------ protected ------
- /**
- * Invoked when receiver is successfully enqueued to the queue of waiting receivers.
- */
- protected open fun onEnqueuedReceive() {}
+ override fun takeFirstReceiveOrPeekClosed(): ReceiveOrClosed<E>? =
+ super.takeFirstReceiveOrPeekClosed().also {
+ if (it != null && it !is Closed<*>) onReceiveDequeued()
+ }
/**
- * Invoked when enqueued receiver was successfully cancelled.
+ * Invoked when receiver is successfully enqueued to the queue of waiting receivers.
+ * @suppress **This is unstable API and it is subject to change.**
*/
- protected open fun onCancelledReceive() {}
+ protected open fun onReceiveEnqueued() {}
+
+ /**
+ * Invoked when enqueued receiver was successfully removed from the queue of waiting receivers.
+ * @suppress **This is unstable API and it is subject to change.**
+ */
+ protected open fun onReceiveDequeued() {}
// ------ private ------
private fun removeReceiveOnCancel(cont: CancellableContinuation<*>, receive: Receive<*>) {
cont.invokeOnCompletion {
if (cont.isCancelled && receive.remove())
- onCancelledReceive()
+ onReceiveDequeued()
}
}
private class Itr<E>(val channel: AbstractChannel<E>) : ChannelIterator<E> {
var result: Any? = POLL_FAILED // E | POLL_FAILED | Closed
- suspend override fun hasNext(): Boolean {
+ override suspend fun hasNext(): Boolean {
// check for repeated hasNext
if (result !== POLL_FAILED) return hasNextResult(result)
// fast path -- try poll non-blocking
@@ -790,7 +797,7 @@
}
@Suppress("UNCHECKED_CAST")
- suspend override fun next(): E {
+ override suspend fun next(): E {
val result = this.result
if (result is Closed<*>) throw result.receiveException
if (result !== POLL_FAILED) {
@@ -889,7 +896,7 @@
override fun dispose() { // invoked on select completion
if (remove())
- onCancelledReceive() // notify cancellation of receive
+ onReceiveDequeued() // notify cancellation of receive
}
override fun toString(): String = "ReceiveSelect[$select,nullOnClose=$nullOnClose]"
diff --git a/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt b/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt
index bec6916..8be8da5 100644
--- a/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt
+++ b/reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt
@@ -27,9 +27,11 @@
/**
* Subscribes to this [Publisher] and returns a channel to receive elements emitted by it.
* The resulting channel shall be [closed][SubscriptionReceiveChannel.close] to unsubscribe from this publisher.
+ * @param request how many items to request from publisher in advance (optional, on-demand request by default).
*/
-public fun <T> Publisher<T>.openSubscription(): SubscriptionReceiveChannel<T> {
- val channel = SubscriptionChannel<T>()
+@JvmOverloads // for binary compatibility
+public fun <T> Publisher<T>.openSubscription(request: Int = 0): SubscriptionReceiveChannel<T> {
+ val channel = SubscriptionChannel<T>(request)
subscribe(channel)
return channel
}
@@ -47,7 +49,6 @@
* This is a shortcut for `open().iterator()`. See [openSubscription] if you need an ability to manually
* unsubscribe from the observable.
*/
-
@Suppress("DeprecatedCallableAddReplaceWith")
@Deprecated(message =
"This iteration operator for `for (x in source) { ... }` loop is deprecated, " +
@@ -58,7 +59,7 @@
/**
* Subscribes to this [Publisher] and performs the specified action for each received element.
*/
-public inline suspend fun <T> Publisher<T>.consumeEach(action: (T) -> Unit) {
+public suspend inline fun <T> Publisher<T>.consumeEach(action: (T) -> Unit) {
openSubscription().use { channel ->
for (x in channel) action(x)
}
@@ -71,36 +72,39 @@
public suspend fun <T> Publisher<T>.consumeEach(action: suspend (T) -> Unit) =
consumeEach { action(it) }
-private class SubscriptionChannel<T> : LinkedListChannel<T>(), SubscriptionReceiveChannel<T>, Subscriber<T> {
- @Volatile
- @JvmField
- var subscription: Subscription? = null
+private class SubscriptionChannel<T>(
+ private val request: Int
+) : LinkedListChannel<T>(), SubscriptionReceiveChannel<T>, Subscriber<T> {
+ init {
+ require(request >= 0) { "Invalid request size: $request" }
+ }
- // request balance from cancelled receivers, balance is negative if we have receivers, but no subscription yet
- val _balance = atomic(0)
+ @Volatile
+ private var subscription: Subscription? = null
+
+ // requested from subscription minus number of received minus number of enqueued receivers,
+ // can be negative if we have receivers, but no subscription yet
+ private val _requested = atomic(0)
// AbstractChannel overrides
- override fun onEnqueuedReceive() {
- _balance.loop { balance ->
+ override fun onReceiveEnqueued() {
+ _requested.loop { wasRequested ->
val subscription = this.subscription
- if (subscription != null) {
- if (balance < 0) { // receivers came before we had subscription
- // try to fixup by making request
- if (!_balance.compareAndSet(balance, 0)) return@loop // continue looping
- subscription.request(-balance.toLong())
- return
- }
- if (balance == 0) { // normal story
- subscription.request(1)
- return
- }
+ val needRequested = wasRequested - 1
+ if (subscription != null && needRequested < 0) { // need to request more from subscription
+ // try to fixup by making request
+ if (wasRequested != request && !_requested.compareAndSet(wasRequested, request))
+ return@loop // continue looping if failed
+ subscription.request((request - needRequested).toLong())
+ return
}
- if (_balance.compareAndSet(balance, balance - 1)) return
+ // just do book-keeping
+ if (_requested.compareAndSet(wasRequested, needRequested)) return
}
}
- override fun onCancelledReceive() {
- _balance.incrementAndGet()
+ override fun onReceiveDequeued() {
+ _requested.incrementAndGet()
}
override fun afterClose(cause: Throwable?) {
@@ -110,22 +114,23 @@
// Subscriber overrides
override fun onSubscribe(s: Subscription) {
subscription = s
- while (true) { // lock-free loop on balance
+ while (true) { // lock-free loop on _requested
if (isClosedForSend) {
s.cancel()
return
}
- val balance = _balance.value
- if (balance >= 0) return // ok -- normal story
- // otherwise, receivers came before we had subscription
+ val wasRequested = _requested.value
+ if (wasRequested >= request) return // ok -- normal story
+ // otherwise, receivers came before we had subscription or need to make initial request
// try to fixup by making request
- if (!_balance.compareAndSet(balance, 0)) continue
- s.request(-balance.toLong())
+ if (!_requested.compareAndSet(wasRequested, request)) continue
+ s.request((request - wasRequested).toLong())
return
}
}
override fun onNext(t: T) {
+ _requested.decrementAndGet()
offer(t)
}
diff --git a/reactive/kotlinx-coroutines-reactive/src/test/kotlin/kotlinx/coroutines/experimental/reactive/PublisherSubscriptionSelectTest.kt b/reactive/kotlinx-coroutines-reactive/src/test/kotlin/kotlinx/coroutines/experimental/reactive/PublisherSubscriptionSelectTest.kt
new file mode 100644
index 0000000..647d662
--- /dev/null
+++ b/reactive/kotlinx-coroutines-reactive/src/test/kotlin/kotlinx/coroutines/experimental/reactive/PublisherSubscriptionSelectTest.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.reactive
+
+import kotlinx.coroutines.experimental.*
+import kotlinx.coroutines.experimental.selects.*
+import org.junit.*
+import org.junit.Assert.*
+import org.junit.runner.*
+import org.junit.runners.*
+
+@RunWith(Parameterized::class)
+class PublisherSubscriptionSelectTest(val request: Int) : TestBase() {
+ companion object {
+ @Parameterized.Parameters(name = "request = {0}")
+ @JvmStatic
+ fun params(): Collection<Array<Any>> = listOf(0, 1, 10).map { arrayOf<Any>(it) }
+ }
+
+ @Test
+ fun testSelect() = runTest {
+ // source with n ints
+ val n = 1000 * stressTestMultiplier
+ val source = publish(coroutineContext) { repeat(n) { send(it) } }
+ var a = 0
+ var b = 0
+ // open two subs
+ source.openSubscription(request).use { channelA ->
+ source.openSubscription(request).use { channelB ->
+ loop@ while (true) {
+ val done: Int = select {
+ channelA.onReceiveOrNull {
+ if (it != null) assertEquals(a++, it)
+ if (it == null) 0 else 1
+ }
+ channelB.onReceiveOrNull {
+ if (it != null) assertEquals(b++, it)
+ if (it == null) 0 else 2
+ }
+ }
+ when (done) {
+ 0 -> break@loop
+ 1 -> {
+ val r = channelB.receiveOrNull()
+ if (r != null) assertEquals(b++, r)
+ }
+ 2 -> {
+ val r = channelA.receiveOrNull()
+ if (r != null) assertEquals(a++, r)
+ }
+ }
+ }
+ }
+ }
+ // should receive one of them fully
+ assertTrue(a == n || b == n)
+ }
+}
\ No newline at end of file
diff --git a/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt b/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt
index 1a6b7c0..ad87de9 100644
--- a/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt
+++ b/reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt
@@ -27,9 +27,11 @@
/**
* Subscribes to this [Observable] and returns a channel to receive elements emitted by it.
* The resulting channel shall be [closed][SubscriptionReceiveChannel.close] to unsubscribe from this observable.
+ * @param request how many items to request from publisher in advance (optional, on-demand request by default).
*/
-public fun <T> Observable<T>.openSubscription(): SubscriptionReceiveChannel<T> {
- val channel = SubscriptionChannel<T>()
+@JvmOverloads // for binary compatibility
+public fun <T> Observable<T>.openSubscription(request: Int = 0): SubscriptionReceiveChannel<T> {
+ val channel = SubscriptionChannel<T>(request)
val subscription = subscribe(channel.subscriber)
channel.subscription = subscription
if (channel.isClosedForSend) subscription.unsubscribe()
@@ -59,7 +61,7 @@
/**
* Subscribes to this [Observable] and performs the specified action for each received element.
*/
-public inline suspend fun <T> Observable<T>.consumeEach(action: (T) -> Unit) {
+public suspend inline fun <T> Observable<T>.consumeEach(action: (T) -> Unit) {
openSubscription().use { channel ->
for (x in channel) action(x)
}
@@ -72,45 +74,58 @@
public suspend fun <T> Observable<T>.consumeEach(action: suspend (T) -> Unit) =
consumeEach { action(it) }
-private class SubscriptionChannel<T> : LinkedListChannel<T>(), SubscriptionReceiveChannel<T> {
+private class SubscriptionChannel<T>(
+ private val request: Int
+) : LinkedListChannel<T>(), SubscriptionReceiveChannel<T> {
+ init {
+ require(request >= 0) { "Invalid request size: $request" }
+ }
+
@JvmField
- val subscriber: ChannelSubscriber = ChannelSubscriber()
+ val subscriber: ChannelSubscriber = ChannelSubscriber(request)
@Volatile
@JvmField
var subscription: Subscription? = null
- val _balance = atomic(0) // request balance from cancelled receivers
+ // requested from subscription minus number of received minus number of enqueued receivers,
+ private val _requested = atomic(request)
// AbstractChannel overrides
- override fun onEnqueuedReceive() {
- _balance.loop { balance ->
- if (balance == 0) {
- subscriber.requestOne()
+ override fun onReceiveEnqueued() {
+ _requested.loop { wasRequested ->
+ val needRequested = wasRequested - 1
+ if (needRequested < 0) { // need to request more from subscriber
+ // try to fixup by making request
+ if (wasRequested != request && !_requested.compareAndSet(wasRequested, request))
+ return@loop // continue looping if failed
+ subscriber.makeRequest((request - needRequested).toLong())
return
}
- if (_balance.compareAndSet(balance, balance - 1)) return
+ // just do book-keeping
+ if (_requested.compareAndSet(wasRequested, needRequested)) return
}
}
- override fun onCancelledReceive() {
- _balance.incrementAndGet()
+ override fun onReceiveDequeued() {
+ _requested.incrementAndGet()
}
override fun afterClose(cause: Throwable?) {
subscription?.unsubscribe()
}
- inner class ChannelSubscriber: Subscriber<T>() {
- fun requestOne() {
- request(1)
+ inner class ChannelSubscriber(private val request: Int): Subscriber<T>() {
+ fun makeRequest(n: Long) {
+ request(n)
}
override fun onStart() {
- request(0) // init backpressure, but don't request anything yet
+ request(request.toLong()) // init backpressure
}
override fun onNext(t: T) {
+ _requested.decrementAndGet()
offer(t)
}
diff --git a/reactive/kotlinx-coroutines-rx1/src/test/kotlin/kotlinx/coroutines/experimental/rx1/ObservableSubscriptionSelectTest.kt b/reactive/kotlinx-coroutines-rx1/src/test/kotlin/kotlinx/coroutines/experimental/rx1/ObservableSubscriptionSelectTest.kt
new file mode 100644
index 0000000..49e7abf
--- /dev/null
+++ b/reactive/kotlinx-coroutines-rx1/src/test/kotlin/kotlinx/coroutines/experimental/rx1/ObservableSubscriptionSelectTest.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.rx1
+
+import kotlinx.coroutines.experimental.*
+import kotlinx.coroutines.experimental.selects.*
+import org.junit.*
+import org.junit.Assert.*
+import org.junit.runner.*
+import org.junit.runners.*
+
+@RunWith(Parameterized::class)
+class ObservableSubscriptionSelectTest(val request: Int) : TestBase() {
+ companion object {
+ @Parameterized.Parameters(name = "request = {0}")
+ @JvmStatic
+ fun params(): Collection<Array<Any>> = listOf(0, 1, 10).map { arrayOf<Any>(it) }
+ }
+
+ @Test
+ fun testSelect() = runTest {
+ // source with n ints
+ val n = 1000 * stressTestMultiplier
+ val source = rxObservable(coroutineContext) { repeat(n) { send(it) } }
+ var a = 0
+ var b = 0
+ // open two subs
+ source.openSubscription(request).use { channelA ->
+ source.openSubscription(request).use { channelB ->
+ loop@ while (true) {
+ val done: Int = select {
+ channelA.onReceiveOrNull {
+ if (it != null) assertEquals(a++, it)
+ if (it == null) 0 else 1
+ }
+ channelB.onReceiveOrNull {
+ if (it != null) assertEquals(b++, it)
+ if (it == null) 0 else 2
+ }
+ }
+ when (done) {
+ 0 -> break@loop
+ 1 -> {
+ val r = channelB.receiveOrNull()
+ if (r != null) assertEquals(b++, r)
+ }
+ 2 -> {
+ val r = channelA.receiveOrNull()
+ if (r != null) assertEquals(a++, r)
+ }
+ }
+ }
+ }
+ }
+ // should receive one of them fully
+ assertTrue(a == n || b == n)
+ }
+}
\ No newline at end of file
diff --git a/reactive/kotlinx-coroutines-rx2/src/test/kotlin/kotlinx/coroutines/experimental/rx2/ObservableSubscriptionSelectTest.kt b/reactive/kotlinx-coroutines-rx2/src/test/kotlin/kotlinx/coroutines/experimental/rx2/ObservableSubscriptionSelectTest.kt
new file mode 100644
index 0000000..24c31af
--- /dev/null
+++ b/reactive/kotlinx-coroutines-rx2/src/test/kotlin/kotlinx/coroutines/experimental/rx2/ObservableSubscriptionSelectTest.kt
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.rx2
+
+import kotlinx.coroutines.experimental.*
+import kotlinx.coroutines.experimental.selects.*
+import org.junit.*
+import org.junit.Assert.*
+
+class ObservableSubscriptionSelectTest() : TestBase() {
+ @Test
+ fun testSelect() = runTest {
+ // source with n ints
+ val n = 1000 * stressTestMultiplier
+ val source = rxObservable(coroutineContext) { repeat(n) { send(it) } }
+ var a = 0
+ var b = 0
+ // open two subs
+ source.openSubscription().use { channelA ->
+ source.openSubscription().use { channelB ->
+ loop@ while (true) {
+ val done: Int = select {
+ channelA.onReceiveOrNull {
+ if (it != null) assertEquals(a++, it)
+ if (it == null) 0 else 1
+ }
+ channelB.onReceiveOrNull {
+ if (it != null) assertEquals(b++, it)
+ if (it == null) 0 else 2
+ }
+ }
+ when (done) {
+ 0 -> break@loop
+ 1 -> {
+ val r = channelB.receiveOrNull()
+ if (r != null) assertEquals(b++, r)
+ }
+ 2 -> {
+ val r = channelA.receiveOrNull()
+ if (r != null) assertEquals(a++, r)
+ }
+ }
+ }
+ }
+ }
+ // should receive one of them fully
+ assertTrue(a == n || b == n)
+ }
+}
\ No newline at end of file