Make the list of segments more abstract (#1563)
* Make the list of segments more abstract, so that it can be used for other synchronization and communication primitives
Co-authored-by: Roman Elizarov <elizarov@gmail.com>
diff --git a/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt
new file mode 100644
index 0000000..128a199
--- /dev/null
+++ b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt
@@ -0,0 +1,240 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+package kotlinx.coroutines.internal
+
+import kotlinx.atomicfu.*
+import kotlinx.coroutines.*
+import kotlin.native.concurrent.SharedImmutable
+
+/**
+ * Returns the first segment `s` with `s.id >= id` or `CLOSED`
+ * if all the segments in this linked list have lower `id`, and the list is closed for further segment additions.
+ */
+private inline fun <S : Segment<S>> S.findSegmentInternal(
+ id: Long,
+ createNewSegment: (id: Long, prev: S?) -> S
+): SegmentOrClosed<S> {
+ /*
+ Go through `next` references and add new segments if needed, similarly to the `push` in the Michael-Scott
+ queue algorithm. The only difference is that "CAS failure" means that the required segment has already been
+ added, so the algorithm just uses it. This way, only one segment with each id can be added.
+ */
+ var cur: S = this
+ while (cur.id < id || cur.removed) {
+ val next = cur.nextOrIfClosed { return SegmentOrClosed(CLOSED) }
+ if (next != null) { // there is a next node -- move there
+ cur = next
+ continue
+ }
+ val newTail = createNewSegment(cur.id + 1, cur)
+ if (cur.trySetNext(newTail)) { // successfully added new node -- move there
+ if (cur.removed) cur.remove()
+ cur = newTail
+ }
+ }
+ return SegmentOrClosed(cur)
+}
+
+/**
+ * Returns `false` if the segment `to` is logically removed, `true` on a successful update.
+ */
+@Suppress("NOTHING_TO_INLINE") // Must be inline because it is an AtomicRef extension
+private inline fun <S : Segment<S>> AtomicRef<S>.moveForward(to: S): Boolean = loop { cur ->
+ if (cur.id >= to.id) return true
+ if (!to.tryIncPointers()) return false
+ if (compareAndSet(cur, to)) { // the segment is moved
+ if (cur.decPointers()) cur.remove()
+ return true
+ }
+ if (to.decPointers()) to.remove() // undo tryIncPointers
+}
+
+/**
+ * Tries to find a segment with the specified [id] following by next references from the
+ * [startFrom] segment and creating new ones if needed. The typical use-case is reading this `AtomicRef` values,
+ * doing some synchronization, and invoking this function to find the required segment and update the pointer.
+ * At the same time, [Segment.cleanPrev] should also be invoked if the previous segments are no longer needed
+ * (e.g., queues should use it in dequeue operations).
+ *
+ * Since segments can be removed from the list, or it can be closed for further segment additions.
+ * Returns the segment `s` with `s.id >= id` or `CLOSED` if all the segments in this linked list have lower `id`,
+ * and the list is closed.
+ */
+internal inline fun <S : Segment<S>> AtomicRef<S>.findSegmentAndMoveForward(
+ id: Long,
+ startFrom: S,
+ createNewSegment: (id: Long, prev: S?) -> S
+): SegmentOrClosed<S> {
+ while (true) {
+ val s = startFrom.findSegmentInternal(id, createNewSegment)
+ if (s.isClosed || moveForward(s.segment)) return s
+ }
+}
+
+/**
+ * Closes this linked list of nodes by forbidding adding new ones,
+ * returns the last node in the list.
+ */
+internal fun <N : ConcurrentLinkedListNode<N>> N.close(): N {
+ var cur: N = this
+ while (true) {
+ val next = cur.nextOrIfClosed { return cur }
+ if (next === null) {
+ if (cur.markAsClosed()) return cur
+ } else {
+ cur = next
+ }
+ }
+}
+
+internal abstract class ConcurrentLinkedListNode<N : ConcurrentLinkedListNode<N>>(prev: N?) {
+ // Pointer to the next node, updates similarly to the Michael-Scott queue algorithm.
+ private val _next = atomic<Any?>(null)
+ // Pointer to the previous node, updates in [remove] function.
+ private val _prev = atomic(prev)
+
+ private val nextOrClosed get() = _next.value
+
+ /**
+ * Returns the next segment or `null` of the one does not exist,
+ * and invokes [onClosedAction] if this segment is marked as closed.
+ */
+ @Suppress("UNCHECKED_CAST")
+ inline fun nextOrIfClosed(onClosedAction: () -> Nothing): N? = nextOrClosed.let {
+ if (it === CLOSED) {
+ onClosedAction()
+ } else {
+ it as N?
+ }
+ }
+
+ val next: N? get() = nextOrIfClosed { return null }
+
+ /**
+ * Tries to set the next segment if it is not specified and this segment is not marked as closed.
+ */
+ fun trySetNext(value: N): Boolean = _next.compareAndSet(null, value)
+
+ /**
+ * Checks whether this node is the physical tail of the current linked list.
+ */
+ val isTail: Boolean get() = next == null
+
+ val prev: N? get() = _prev.value
+
+ /**
+ * Cleans the pointer to the previous node.
+ */
+ fun cleanPrev() { _prev.lazySet(null) }
+
+ /**
+ * Tries to mark the linked list as closed by forbidding adding new nodes after this one.
+ */
+ fun markAsClosed() = _next.compareAndSet(null, CLOSED)
+
+ /**
+ * This property indicates whether the current node is logically removed.
+ * The expected use-case is removing the node logically (so that [removed] becomes true),
+ * and invoking [remove] after that. Note that this implementation relies on the contract
+ * that the physical tail cannot be logically removed. Please, do not break this contract;
+ * otherwise, memory leaks and unexpected behavior can occur.
+ */
+ abstract val removed: Boolean
+
+ /**
+ * Removes this node physically from this linked list. The node should be
+ * logically removed (so [removed] returns `true`) at the point of invocation.
+ */
+ fun remove() {
+ assert { removed } // The node should be logically removed at first.
+ assert { !isTail } // The physical tail cannot be removed.
+ while (true) {
+ // Read `next` and `prev` pointers ignoring logically removed nodes.
+ val prev = leftmostAliveNode
+ val next = rightmostAliveNode
+ // Link `next` and `prev`.
+ next._prev.value = prev
+ if (prev !== null) prev._next.value = next
+ // Checks that prev and next are still alive.
+ if (next.removed) continue
+ if (prev !== null && prev.removed) continue
+ // This node is removed.
+ return
+ }
+ }
+
+ private val leftmostAliveNode: N? get() {
+ var cur = prev
+ while (cur !== null && cur.removed)
+ cur = cur._prev.value
+ return cur
+ }
+
+ private val rightmostAliveNode: N get() {
+ assert { !isTail } // Should not be invoked on the tail node
+ var cur = next!!
+ while (cur.removed)
+ cur = cur.next!!
+ return cur
+ }
+}
+
+/**
+ * Each segment in the list has a unique id and is created by the provided to [findSegmentAndMoveForward] method.
+ * Essentially, this is a node in the Michael-Scott queue algorithm,
+ * but with maintaining [prev] pointer for efficient [remove] implementation.
+ */
+internal abstract class Segment<S : Segment<S>>(val id: Long, prev: S?, pointers: Int): ConcurrentLinkedListNode<S>(prev) {
+ /**
+ * This property should return the maximal number of slots in this segment,
+ * it is used to define whether the segment is logically removed.
+ */
+ abstract val maxSlots: Int
+
+ /**
+ * Numbers of cleaned slots (the lowest bits) and AtomicRef pointers to this segment (the highest bits)
+ */
+ private val cleanedAndPointers = atomic(pointers shl POINTERS_SHIFT)
+
+ /**
+ * The segment is considered as removed if all the slots are cleaned.
+ * There are no pointers to this segment from outside, and
+ * it is not a physical tail in the linked list of segments.
+ */
+ override val removed get() = cleanedAndPointers.value == maxSlots && !isTail
+
+ // increments the number of pointers if this segment is not logically removed.
+ internal fun tryIncPointers() = cleanedAndPointers.addConditionally(1 shl POINTERS_SHIFT) { it != maxSlots || isTail }
+
+ // returns `true` if this segment is logically removed after the decrement.
+ internal fun decPointers() = cleanedAndPointers.addAndGet(-(1 shl POINTERS_SHIFT)) == maxSlots && !isTail
+
+ /**
+ * Invoked on each slot clean-up; should not be invoked twice for the same slot.
+ */
+ fun onSlotCleaned() {
+ if (cleanedAndPointers.incrementAndGet() == maxSlots && !isTail) remove()
+ }
+}
+
+private inline fun AtomicInt.addConditionally(delta: Int, condition: (cur: Int) -> Boolean): Boolean {
+ while (true) {
+ val cur = this.value
+ if (!condition(cur)) return false
+ if (this.compareAndSet(cur, cur + delta)) return true
+ }
+}
+
+@Suppress("EXPERIMENTAL_FEATURE_WARNING") // We are using inline class only internally, so it is Ok
+internal inline class SegmentOrClosed<S : Segment<S>>(private val value: Any?) {
+ val isClosed: Boolean get() = value === CLOSED
+ @Suppress("UNCHECKED_CAST")
+ val segment: S get() = if (value === CLOSED) error("Does not contain segment") else value as S
+}
+
+private const val POINTERS_SHIFT = 16
+
+@SharedImmutable
+private val CLOSED = Symbol("CLOSED")
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt b/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt
deleted file mode 100644
index 0091d13..0000000
--- a/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
- */
-
-package kotlinx.coroutines.internal
-
-import kotlinx.atomicfu.*
-import kotlinx.coroutines.*
-
-/**
- * Essentially, this segment queue is an infinite array of segments, which is represented as
- * a Michael-Scott queue of them. All segments are instances of [Segment] class and
- * follow in natural order (see [Segment.id]) in the queue.
- */
-internal abstract class SegmentQueue<S: Segment<S>>() {
- private val _head: AtomicRef<S>
- /**
- * Returns the first segment in the queue.
- */
- protected val head: S get() = _head.value
-
- private val _tail: AtomicRef<S>
- /**
- * Returns the last segment in the queue.
- */
- protected val tail: S get() = _tail.value
-
- init {
- val initialSegment = newSegment(0)
- _head = atomic(initialSegment)
- _tail = atomic(initialSegment)
- }
-
- /**
- * The implementation should create an instance of segment [S] with the specified id
- * and initial reference to the previous one.
- */
- abstract fun newSegment(id: Long, prev: S? = null): S
-
- /**
- * Finds a segment with the specified [id] following by next references from the
- * [startFrom] segment. The typical use-case is reading [tail] or [head], doing some
- * synchronization, and invoking [getSegment] or [getSegmentAndMoveHead] correspondingly
- * to find the required segment.
- */
- protected fun getSegment(startFrom: S, id: Long): S? {
- // Go through `next` references and add new segments if needed,
- // similarly to the `push` in the Michael-Scott queue algorithm.
- // The only difference is that `CAS failure` means that the
- // required segment has already been added, so the algorithm just
- // uses it. This way, only one segment with each id can be in the queue.
- var cur: S = startFrom
- while (cur.id < id) {
- var curNext = cur.next
- if (curNext == null) {
- // Add a new segment.
- val newTail = newSegment(cur.id + 1, cur)
- curNext = if (cur.casNext(null, newTail)) {
- if (cur.removed) {
- cur.remove()
- }
- moveTailForward(newTail)
- newTail
- } else {
- cur.next!!
- }
- }
- cur = curNext
- }
- if (cur.id != id) return null
- return cur
- }
-
- /**
- * Invokes [getSegment] and replaces [head] with the result if its [id] is greater.
- */
- protected fun getSegmentAndMoveHead(startFrom: S, id: Long): S? {
- @Suppress("LeakingThis")
- if (startFrom.id == id) return startFrom
- val s = getSegment(startFrom, id) ?: return null
- moveHeadForward(s)
- return s
- }
-
- /**
- * Updates [head] to the specified segment
- * if its `id` is greater.
- */
- private fun moveHeadForward(new: S) {
- _head.loop { curHead ->
- if (curHead.id > new.id) return
- if (_head.compareAndSet(curHead, new)) {
- new.prev.value = null
- return
- }
- }
- }
-
- /**
- * Updates [tail] to the specified segment
- * if its `id` is greater.
- */
- private fun moveTailForward(new: S) {
- _tail.loop { curTail ->
- if (curTail.id > new.id) return
- if (_tail.compareAndSet(curTail, new)) return
- }
- }
-}
-
-/**
- * Each segment in [SegmentQueue] has a unique id and is created by [SegmentQueue.newSegment].
- * Essentially, this is a node in the Michael-Scott queue algorithm, but with
- * maintaining [prev] pointer for efficient [remove] implementation.
- */
-internal abstract class Segment<S: Segment<S>>(val id: Long, prev: S?) {
- // Pointer to the next segment, updates similarly to the Michael-Scott queue algorithm.
- private val _next = atomic<S?>(null)
- val next: S? get() = _next.value
- fun casNext(expected: S?, value: S?): Boolean = _next.compareAndSet(expected, value)
- // Pointer to the previous segment, updates in [remove] function.
- val prev = atomic<S?>(null)
-
- /**
- * Returns `true` if this segment is logically removed from the queue.
- * The [remove] function should be called right after it becomes logically removed.
- */
- abstract val removed: Boolean
-
- init {
- this.prev.value = prev
- }
-
- /**
- * Removes this segment physically from the segment queue. The segment should be
- * logically removed (so [removed] returns `true`) at the point of invocation.
- */
- fun remove() {
- assert { removed } // The segment should be logically removed at first
- // Read `next` and `prev` pointers.
- var next = this._next.value ?: return // tail cannot be removed
- var prev = prev.value ?: return // head cannot be removed
- // Link `next` and `prev`.
- prev.moveNextToRight(next)
- while (prev.removed) {
- prev = prev.prev.value ?: break
- prev.moveNextToRight(next)
- }
- next.movePrevToLeft(prev)
- while (next.removed) {
- next = next.next ?: break
- next.movePrevToLeft(prev)
- }
- }
-
- /**
- * Updates [next] pointer to the specified segment if
- * the [id] of the specified segment is greater.
- */
- private fun moveNextToRight(next: S) {
- while (true) {
- val curNext = this._next.value as S
- if (next.id <= curNext.id) return
- if (this._next.compareAndSet(curNext, next)) return
- }
- }
-
- /**
- * Updates [prev] pointer to the specified segment if
- * the [id] of the specified segment is lower.
- */
- private fun movePrevToLeft(prev: S) {
- while (true) {
- val curPrev = this.prev.value ?: return
- if (curPrev.id <= prev.id) return
- if (this.prev.compareAndSet(curPrev, prev)) return
- }
- }
-}
diff --git a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt
index aa7ed63..7cdc736 100644
--- a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt
+++ b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt
@@ -8,9 +8,8 @@
import kotlinx.coroutines.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
-import kotlin.jvm.*
import kotlin.math.*
-import kotlin.native.concurrent.*
+import kotlin.native.concurrent.SharedImmutable
/**
* A counting semaphore for coroutines that logically maintains a number of available permits.
@@ -84,16 +83,26 @@
}
}
-private class SemaphoreImpl(
- private val permits: Int, acquiredPermits: Int
-) : Semaphore, SegmentQueue<SemaphoreSegment>() {
+private class SemaphoreImpl(private val permits: Int, acquiredPermits: Int) : Semaphore {
+
+ // The queue of waiting acquirers is essentially an infinite array based on the list of segments
+ // (see `SemaphoreSegment`); each segment contains a fixed number of slots. To determine a slot for each enqueue
+ // and dequeue operation, we increment the corresponding counter at the beginning of the operation
+ // and use the value before the increment as a slot number. This way, each enqueue-dequeue pair
+ // works with an individual cell.We use the corresponding segment pointer to find the required ones.
+ private val head: AtomicRef<SemaphoreSegment>
+ private val deqIdx = atomic(0L)
+ private val tail: AtomicRef<SemaphoreSegment>
+ private val enqIdx = atomic(0L)
+
init {
require(permits > 0) { "Semaphore should have at least 1 permit, but had $permits" }
require(acquiredPermits in 0..permits) { "The number of acquired permits should be in 0..$permits" }
+ val s = SemaphoreSegment(0, null, 2)
+ head = atomic(s)
+ tail = atomic(s)
}
- override fun newSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev)
-
/**
* This counter indicates a number of available permits if it is non-negative,
* or the size with minus sign otherwise. Note, that 32-bit counter is enough here
@@ -104,14 +113,6 @@
private val _availablePermits = atomic(permits - acquiredPermits)
override val availablePermits: Int get() = max(_availablePermits.value, 0)
- // The queue of waiting acquirers is essentially an infinite array based on `SegmentQueue`;
- // each segment contains a fixed number of slots. To determine a slot for each enqueue
- // and dequeue operation, we increment the corresponding counter at the beginning of the operation
- // and use the value before the increment as a slot number. This way, each enqueue-dequeue pair
- // works with an individual cell.
- private val enqIdx = atomic(0L)
- private val deqIdx = atomic(0L)
-
override fun tryAcquire(): Boolean {
_availablePermits.loop { p ->
if (p <= 0) return false
@@ -136,12 +137,13 @@
cur + 1
}
- private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutineReusable<Unit> sc@ { cont ->
- val last = this.tail
+ private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutineReusable<Unit> sc@{ cont ->
+ val curTail = this.tail.value
val enqIdx = enqIdx.getAndIncrement()
- val segment = getSegment(last, enqIdx / SEGMENT_SIZE)
+ val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail,
+ createNewSegment = ::createSegment).run { segment } // cannot be closed
val i = (enqIdx % SEGMENT_SIZE).toInt()
- if (segment === null || segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
+ if (segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
// already resumed
cont.resume(Unit)
return@sc
@@ -151,10 +153,17 @@
@Suppress("UNCHECKED_CAST")
internal fun resumeNextFromQueue() {
- try_again@while (true) {
- val first = this.head
+ try_again@ while (true) {
+ val curHead = this.head.value
val deqIdx = deqIdx.getAndIncrement()
- val segment = getSegmentAndMoveHead(first, deqIdx / SEGMENT_SIZE) ?: continue@try_again
+ val id = deqIdx / SEGMENT_SIZE
+ val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
+ createNewSegment = ::createSegment).run { segment } // cannot be closed
+ segment.cleanPrev()
+ if (segment.id > id) {
+ this.deqIdx.updateIfLower(segment.id * SEGMENT_SIZE)
+ continue@try_again
+ }
val i = (deqIdx % SEGMENT_SIZE).toInt()
val cont = segment.getAndSet(i, RESUMED)
if (cont === null) return // just resumed
@@ -165,6 +174,10 @@
}
}
+private inline fun AtomicLong.updateIfLower(value: Long): Unit = loop { cur ->
+ if (cur >= value || compareAndSet(cur, value)) return
+}
+
private class CancelSemaphoreAcquisitionHandler(
private val semaphore: SemaphoreImpl,
private val segment: SemaphoreSegment,
@@ -180,10 +193,11 @@
override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $index]"
}
-private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment<SemaphoreSegment>(id, prev) {
+private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)
+
+private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) : Segment<SemaphoreSegment>(id, prev, pointers) {
val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
- private val cancelledSlots = atomic(0)
- override val removed get() = cancelledSlots.value == SEGMENT_SIZE
+ override val maxSlots: Int get() = SEGMENT_SIZE
@Suppress("NOTHING_TO_INLINE")
inline fun get(index: Int): Any? = acquirers[index].value
@@ -200,8 +214,7 @@
// Try to cancel the slot
val cancelled = getAndSet(index, CANCELLED) !== RESUMED
// Remove this segment if needed
- if (cancelledSlots.incrementAndGet() == SEGMENT_SIZE)
- remove()
+ onSlotCleaned()
return cancelled
}
diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt
index 293be7a..3d1305c 100644
--- a/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt
+++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt
@@ -1,9 +1,13 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
package kotlinx.coroutines.internal
-import kotlinx.atomicfu.atomic
+import kotlinx.atomicfu.*
/**
- * This queue implementation is based on [SegmentQueue] for testing purposes and is organized as follows. Essentially,
+ * This queue implementation is based on [SegmentList] for testing purposes and is organized as follows. Essentially,
* the [SegmentBasedQueue] is represented as an infinite array of segments, each stores one element (see [OneElementSegment]).
* Both [enqueue] and [dequeue] operations increment the corresponding global index ([enqIdx] for [enqueue] and
* [deqIdx] for [dequeue]) and work with the indexed by this counter cell. Since both operations increment the indices
@@ -13,60 +17,89 @@
* the cell with [BROKEN] token and retry the operation, [enqueue] at the same time should restart as well; this way,
* the queue is obstruction-free.
*/
-internal class SegmentBasedQueue<T> : SegmentQueue<OneElementSegment<T>>() {
- override fun newSegment(id: Long, prev: OneElementSegment<T>?): OneElementSegment<T> = OneElementSegment(id, prev)
+internal class SegmentBasedQueue<T> {
+ private val head: AtomicRef<OneElementSegment<T>>
+ private val tail: AtomicRef<OneElementSegment<T>>
private val enqIdx = atomic(0L)
private val deqIdx = atomic(0L)
- // Returns the segments associated with the enqueued element.
- fun enqueue(element: T): OneElementSegment<T> {
+ init {
+ val s = OneElementSegment<T>(0, null, 2)
+ head = atomic(s)
+ tail = atomic(s)
+ }
+
+ // Returns the segments associated with the enqueued element, or `null` if the queue is closed.
+ fun enqueue(element: T): OneElementSegment<T>? {
while (true) {
- var tail = this.tail
+ val curTail = this.tail.value
val enqIdx = this.enqIdx.getAndIncrement()
- tail = getSegment(tail, enqIdx) ?: continue
- if (tail.element.value === BROKEN) continue
- if (tail.element.compareAndSet(null, element)) return tail
+ val segmentOrClosed = this.tail.findSegmentAndMoveForward(id = enqIdx, startFrom = curTail, createNewSegment = ::createSegment)
+ if (segmentOrClosed.isClosed) return null
+ val s = segmentOrClosed.segment
+ if (s.element.value === BROKEN) continue
+ if (s.element.compareAndSet(null, element)) return s
}
}
fun dequeue(): T? {
while (true) {
if (this.deqIdx.value >= this.enqIdx.value) return null
- var firstSegment = this.head
+ val curHead = this.head.value
val deqIdx = this.deqIdx.getAndIncrement()
- firstSegment = getSegmentAndMoveHead(firstSegment, deqIdx) ?: continue
- var el = firstSegment.element.value
+ val segmentOrClosed = this.head.findSegmentAndMoveForward(id = deqIdx, startFrom = curHead, createNewSegment = ::createSegment)
+ if (segmentOrClosed.isClosed) return null
+ val s = segmentOrClosed.segment
+ s.cleanPrev()
+ if (s.id > deqIdx) continue
+ var el = s.element.value
if (el === null) {
- if (firstSegment.element.compareAndSet(null, BROKEN)) continue
- else el = firstSegment.element.value
+ if (s.element.compareAndSet(null, BROKEN)) continue
+ else el = s.element.value
}
- if (el === REMOVED) continue
+ if (el === BROKEN) continue
+ @Suppress("UNCHECKED_CAST")
return el as T
}
}
+ // `enqueue` should return `null` after the queue is closed
+ fun close(): OneElementSegment<T> {
+ val s = this.tail.value.close()
+ var cur = s
+ while (true) {
+ cur.element.compareAndSet(null, BROKEN)
+ cur = cur.prev ?: break
+ }
+ return s
+ }
+
val numberOfSegments: Int get() {
- var s: OneElementSegment<T>? = head
- var i = 0
- while (s != null) {
- s = s.next
+ var cur = head.value
+ var i = 1
+ while (true) {
+ cur = cur.next ?: return i
i++
}
- return i
+ }
+
+ fun checkHeadPrevIsCleaned() {
+ check(head.value.prev === null)
}
}
-internal class OneElementSegment<T>(id: Long, prev: OneElementSegment<T>?) : Segment<OneElementSegment<T>>(id, prev) {
+private fun <T> createSegment(id: Long, prev: OneElementSegment<T>?) = OneElementSegment(id, prev, 0)
+
+internal class OneElementSegment<T>(id: Long, prev: OneElementSegment<T>?, pointers: Int) : Segment<OneElementSegment<T>>(id, prev, pointers) {
val element = atomic<Any?>(null)
- override val removed get() = element.value === REMOVED
+ override val maxSlots get() = 1
fun removeSegment() {
- element.value = REMOVED
- remove()
+ element.value = BROKEN
+ onSlotCleaned()
}
}
-private val BROKEN = Symbol("BROKEN")
-private val REMOVED = Symbol("REMOVED")
\ No newline at end of file
+private val BROKEN = Symbol("BROKEN")
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentListTest.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentListTest.kt
new file mode 100644
index 0000000..ff6a346
--- /dev/null
+++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentListTest.kt
@@ -0,0 +1,41 @@
+package kotlinx.coroutines.internal
+
+import kotlinx.atomicfu.*
+import org.junit.Test
+import kotlin.test.*
+
+class SegmentListTest {
+ @Test
+ fun testRemoveTail() {
+ val initialSegment = TestSegment(0, null, 2)
+ val head = AtomicRefHolder(initialSegment)
+ val tail = AtomicRefHolder(initialSegment)
+ val s1 = tail.ref.findSegmentAndMoveForward(1, tail.ref.value, ::createTestSegment).segment
+ assertFalse(s1.removed)
+ tail.ref.value.onSlotCleaned()
+ assertFalse(s1.removed)
+ head.ref.findSegmentAndMoveForward(2, head.ref.value, ::createTestSegment)
+ assertFalse(s1.removed)
+ tail.ref.findSegmentAndMoveForward(2, head.ref.value, ::createTestSegment)
+ assertTrue(s1.removed)
+ }
+
+ @Test
+ fun testClose() {
+ val initialSegment = TestSegment(0, null, 2)
+ val head = AtomicRefHolder(initialSegment)
+ val tail = AtomicRefHolder(initialSegment)
+ tail.ref.findSegmentAndMoveForward(1, tail.ref.value, ::createTestSegment)
+ assertEquals(tail.ref.value, tail.ref.value.close())
+ assertTrue(head.ref.findSegmentAndMoveForward(2, head.ref.value, ::createTestSegment).isClosed)
+ }
+}
+
+private class AtomicRefHolder<T>(initialValue: T) {
+ val ref = atomic(initialValue)
+}
+
+private class TestSegment(id: Long, prev: TestSegment?, pointers: Int) : Segment<TestSegment>(id, prev, pointers) {
+ override val maxSlots: Int get() = 1
+}
+private fun createTestSegment(id: Long, prev: TestSegment?) = TestSegment(id, prev, 0)
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt
index b59a648..fd2d329 100644
--- a/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt
@@ -33,7 +33,7 @@
val s = q.enqueue(2)
q.enqueue(3)
assertEquals(3, q.numberOfSegments)
- s.removeSegment()
+ s!!.removeSegment()
assertEquals(2, q.numberOfSegments)
assertEquals(1, q.dequeue())
assertEquals(3, q.dequeue())
@@ -47,12 +47,22 @@
val s = q.enqueue(2)
assertEquals(1, q.dequeue())
q.enqueue(3)
- s.removeSegment()
+ s!!.removeSegment()
assertEquals(3, q.dequeue())
assertNull(q.dequeue())
}
@Test
+ fun testClose() {
+ val q = SegmentBasedQueue<Int>()
+ q.enqueue(1)
+ assertEquals(0, q.close().id)
+ assertEquals(null, q.enqueue(2))
+ assertEquals(1, q.dequeue())
+ assertEquals(null, q.dequeue())
+ }
+
+ @Test
fun stressTest() {
val q = SegmentBasedQueue<Int>()
val expectedQueue = ArrayDeque<Int>()
@@ -64,6 +74,7 @@
expectedQueue.add(el)
} else { // remove
assertEquals(expectedQueue.poll(), q.dequeue())
+ q.checkHeadPrevIsCleaned()
}
}
}
@@ -78,7 +89,7 @@
val N = 100_000 * stressTestMultiplier
val T = 10
val q = SegmentBasedQueue<Int>()
- val segments = (1..N).map { q.enqueue(it) }.toMutableList()
+ val segments = (1..N).map { q.enqueue(it)!! }.toMutableList()
if (random) segments.shuffle()
assertEquals(N, q.numberOfSegments)
val nextSegmentIndex = AtomicInteger()
diff --git a/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt b/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt
index 1bb51a5..89bf8df 100644
--- a/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt
@@ -18,12 +18,17 @@
private val q = SegmentBasedQueue<Int>()
@Operation
- fun add(@Param(name = "value") value: Int) {
- q.enqueue(value)
+ fun enqueue(@Param(name = "value") x: Int): Boolean {
+ return q.enqueue(x) !== null
}
@Operation
- fun poll(): Int? = q.dequeue()
+ fun dequeue(): Int? = q.dequeue()
+
+ @Operation
+ fun close() {
+ q.close()
+ }
override fun extractState(): Any {
val elements = ArrayList<Int>()
@@ -31,8 +36,8 @@
val x = q.dequeue() ?: break
elements.add(x)
}
-
- return elements
+ val closed = q.enqueue(0) === null
+ return elements to closed
}
@Test