Allocate underlying buffer in ArrayChannel in on-demand manner (#1388)
* Allocate underlying buffer in ArrayChannel in on-demand manner
Rationale:
Such change will allow us to use huge buffers in various flow operators without having a serious footprint in suspension-free scenarios
diff --git a/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt b/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt
index 688125d..1e1c0d3 100644
--- a/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt
+++ b/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt
@@ -8,6 +8,7 @@
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.selects.*
import kotlin.jvm.*
+import kotlin.math.*
/**
* Channel with array buffer of a fixed [capacity].
@@ -29,10 +30,14 @@
}
private val lock = ReentrantLock()
- private val buffer: Array<Any?> = arrayOfNulls<Any?>(capacity)
+ /*
+ * Guarded by lock.
+ * Allocate minimum of capacity and 16 to avoid excess memory pressure for large channels when it's not necessary.
+ */
+ private var buffer: Array<Any?> = arrayOfNulls<Any?>(min(capacity, 8))
private var head: Int = 0
@Volatile
- private var size: Int = 0
+ private var size: Int = 0 // Invariant: size <= capacity
protected final override val isBufferAlwaysEmpty: Boolean get() = false
protected final override val isBufferEmpty: Boolean get() = size == 0
@@ -64,7 +69,8 @@
}
}
}
- buffer[(head + size) % capacity] = element // actually queue element
+ ensureCapacity(size)
+ buffer[(head + size) % buffer.size] = element // actually queue element
return OFFER_SUCCESS
}
// size == capacity: full
@@ -112,7 +118,8 @@
this.size = size // restore size
return ALREADY_SELECTED
}
- buffer[(head + size) % capacity] = element // actually queue element
+ ensureCapacity(size)
+ buffer[(head + size) % buffer.size] = element // actually queue element
return OFFER_SUCCESS
}
// size == capacity: full
@@ -123,6 +130,19 @@
return receive!!.offerResult
}
+ // Guarded by lock
+ private fun ensureCapacity(currentSize: Int) {
+ if (currentSize >= buffer.size) {
+ val newSize = min(buffer.size * 2, capacity)
+ val newBuffer = arrayOfNulls<Any?>(newSize)
+ for (i in 0 until currentSize) {
+ newBuffer[i] = buffer[(head + i) % buffer.size]
+ }
+ buffer = newBuffer
+ head = 0
+ }
+ }
+
// result is `E | POLL_FAILED | Closed`
protected override fun pollInternal(): Any? {
var send: Send? = null
@@ -149,9 +169,9 @@
}
if (replacement !== POLL_FAILED && replacement !is Closed<*>) {
this.size = size // restore size
- buffer[(head + size) % capacity] = replacement
+ buffer[(head + size) % buffer.size] = replacement
}
- head = (head + 1) % capacity
+ head = (head + 1) % buffer.size
}
// complete send the we're taken replacement from
if (token != null)
@@ -203,7 +223,7 @@
}
if (replacement !== POLL_FAILED && replacement !is Closed<*>) {
this.size = size // restore size
- buffer[(head + size) % capacity] = replacement
+ buffer[(head + size) % buffer.size] = replacement
} else {
// failed to poll or is already closed --> let's try to select receiving this element from buffer
if (!select.trySelect(null)) { // :todo: move trySelect completion outside of lock
@@ -212,7 +232,7 @@
return ALREADY_SELECTED
}
}
- head = (head + 1) % capacity
+ head = (head + 1) % buffer.size
}
// complete send the we're taken replacement from
if (token != null)
@@ -226,7 +246,7 @@
lock.withLock {
repeat(size) {
buffer[head] = 0
- head = (head + 1) % capacity
+ head = (head + 1) % buffer.size
}
size = 0
}
@@ -237,5 +257,5 @@
// ------ debug ------
override val bufferDebugString: String
- get() = "(buffer:capacity=${buffer.size},size=$size)"
+ get() = "(buffer:capacity=$capacity,size=$size)"
}
diff --git a/kotlinx-coroutines-core/common/test/channels/ArrayChannelTest.kt b/kotlinx-coroutines-core/common/test/channels/ArrayChannelTest.kt
index 2b948df..ceef21e 100644
--- a/kotlinx-coroutines-core/common/test/channels/ArrayChannelTest.kt
+++ b/kotlinx-coroutines-core/common/test/channels/ArrayChannelTest.kt
@@ -86,7 +86,7 @@
}
@Test
- fun testOfferAndPool() = runTest {
+ fun testOfferAndPoll() = runTest {
val q = Channel<Int>(1)
assertTrue(q.offer(1))
expect(1)
@@ -144,4 +144,51 @@
channel.cancel(TestCancellationException())
channel.receiveOrNull()
}
+
+ @Test
+ fun testBufferSize() = runTest {
+ val capacity = 42
+ val channel = Channel<Int>(capacity)
+ checkBufferChannel(channel, capacity)
+ }
+
+ @Test
+ fun testBufferSizeFromTheMiddle() = runTest {
+ val capacity = 42
+ val channel = Channel<Int>(capacity)
+ repeat(4) {
+ channel.offer(-1)
+ }
+ repeat(4) {
+ channel.receiveOrNull()
+ }
+ checkBufferChannel(channel, capacity)
+ }
+
+ private suspend fun CoroutineScope.checkBufferChannel(
+ channel: Channel<Int>,
+ capacity: Int
+ ) {
+ launch {
+ expect(2)
+ repeat(42) {
+ channel.send(it)
+ }
+ expect(3)
+ channel.send(42)
+ expect(5)
+ channel.close()
+ }
+
+ expect(1)
+ yield()
+
+ expect(4)
+ val result = ArrayList<Int>(42)
+ channel.consumeEach {
+ result.add(it)
+ }
+ assertEquals((0..capacity).toList(), result)
+ finish(6)
+ }
}
diff --git a/kotlinx-coroutines-core/jvm/test/channels/ArrayChannelStressTest.kt b/kotlinx-coroutines-core/jvm/test/channels/ArrayChannelStressTest.kt
index ccb0e87..74dc24c 100644
--- a/kotlinx-coroutines-core/jvm/test/channels/ArrayChannelStressTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/channels/ArrayChannelStressTest.kt
@@ -22,13 +22,13 @@
fun testStress() = runTest {
val n = 100_000 * stressTestMultiplier
val q = Channel<Int>(capacity)
- val sender = launch(coroutineContext) {
+ val sender = launch {
for (i in 1..n) {
q.send(i)
}
expect(2)
}
- val receiver = launch(coroutineContext) {
+ val receiver = launch {
for (i in 1..n) {
val next = q.receive()
check(next == i)
@@ -40,4 +40,25 @@
receiver.join()
finish(4)
}
+
+ @Test
+ fun testBurst() = runTest {
+ Assume.assumeTrue(capacity < 100_000)
+ repeat(10_000 * stressTestMultiplier) {
+ val channel = Channel<Int>(capacity)
+ val sender = launch(Dispatchers.Default) {
+ for (i in 1..capacity * 2) {
+ channel.send(i)
+ }
+ }
+ val receiver = launch(Dispatchers.Default) {
+ for (i in 1..capacity * 2) {
+ val next = channel.receive()
+ check(next == i)
+ }
+ }
+ sender.join()
+ receiver.join()
+ }
+ }
}