Make the list of segments more abstract (#1563)

* Make the list of segments more abstract, so that it can be used for other synchronization and communication primitives

Co-authored-by: Roman Elizarov <elizarov@gmail.com>
diff --git a/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt
new file mode 100644
index 0000000..128a199
--- /dev/null
+++ b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt
@@ -0,0 +1,240 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+package kotlinx.coroutines.internal
+
+import kotlinx.atomicfu.*
+import kotlinx.coroutines.*
+import kotlin.native.concurrent.SharedImmutable
+
+/**
+ * Returns the first segment `s` with `s.id >= id` or `CLOSED`
+ * if all the segments in this linked list have lower `id`, and the list is closed for further segment additions.
+ */
+private inline fun <S : Segment<S>> S.findSegmentInternal(
+    id: Long,
+    createNewSegment: (id: Long, prev: S?) -> S
+): SegmentOrClosed<S> {
+    /*
+       Go through `next` references and add new segments if needed, similarly to the `push` in the Michael-Scott
+       queue algorithm. The only difference is that "CAS failure" means that the required segment has already been
+       added, so the algorithm just uses it. This way, only one segment with each id can be added.
+     */
+    var cur: S = this
+    while (cur.id < id || cur.removed) {
+        val next = cur.nextOrIfClosed { return SegmentOrClosed(CLOSED) }
+        if (next != null) { // there is a next node -- move there
+            cur = next
+            continue
+        }
+        val newTail = createNewSegment(cur.id + 1, cur)
+        if (cur.trySetNext(newTail)) { // successfully added new node -- move there
+            if (cur.removed) cur.remove()
+            cur = newTail
+        }
+    }
+    return SegmentOrClosed(cur)
+}
+
+/**
+ * Returns `false` if the segment `to` is logically removed, `true` on a successful update.
+ */
+@Suppress("NOTHING_TO_INLINE") // Must be inline because it is an AtomicRef extension
+private inline fun <S : Segment<S>> AtomicRef<S>.moveForward(to: S): Boolean = loop { cur ->
+    if (cur.id >= to.id) return true
+    if (!to.tryIncPointers()) return false
+    if (compareAndSet(cur, to)) { // the segment is moved
+        if (cur.decPointers()) cur.remove()
+        return true
+    }
+    if (to.decPointers()) to.remove() // undo tryIncPointers
+}
+
+/**
+ * Tries to find a segment with the specified [id] following by next references from the
+ * [startFrom] segment and creating new ones if needed. The typical use-case is reading this `AtomicRef` values,
+ * doing some synchronization, and invoking this function to find the required segment and update the pointer.
+ * At the same time, [Segment.cleanPrev] should also be invoked if the previous segments are no longer needed
+ * (e.g., queues should use it in dequeue operations).
+ *
+ * Since segments can be removed from the list, or it can be closed for further segment additions.
+ * Returns the segment `s` with `s.id >= id` or `CLOSED` if all the segments in this linked list have lower `id`,
+ * and the list is closed.
+ */
+internal inline fun <S : Segment<S>> AtomicRef<S>.findSegmentAndMoveForward(
+    id: Long,
+    startFrom: S,
+    createNewSegment: (id: Long, prev: S?) -> S
+): SegmentOrClosed<S> {
+    while (true) {
+        val s = startFrom.findSegmentInternal(id, createNewSegment)
+        if (s.isClosed || moveForward(s.segment)) return s
+    }
+}
+
+/**
+ * Closes this linked list of nodes by forbidding adding new ones,
+ * returns the last node in the list.
+ */
+internal fun <N : ConcurrentLinkedListNode<N>> N.close(): N {
+    var cur: N = this
+    while (true) {
+        val next = cur.nextOrIfClosed { return cur }
+        if (next === null) {
+            if (cur.markAsClosed()) return cur
+        } else {
+            cur = next
+        }
+    }
+}
+
+internal abstract class ConcurrentLinkedListNode<N : ConcurrentLinkedListNode<N>>(prev: N?) {
+    // Pointer to the next node, updates similarly to the Michael-Scott queue algorithm.
+    private val _next = atomic<Any?>(null)
+    // Pointer to the previous node, updates in [remove] function.
+    private val _prev = atomic(prev)
+
+    private val nextOrClosed get() = _next.value
+
+    /**
+     * Returns the next segment or `null` of the one does not exist,
+     * and invokes [onClosedAction] if this segment is marked as closed.
+     */
+    @Suppress("UNCHECKED_CAST")
+    inline fun nextOrIfClosed(onClosedAction: () -> Nothing): N? = nextOrClosed.let {
+        if (it === CLOSED) {
+            onClosedAction()
+        } else {
+            it as N?
+        }
+    }
+
+    val next: N? get() = nextOrIfClosed { return null }
+
+    /**
+     * Tries to set the next segment if it is not specified and this segment is not marked as closed.
+     */
+    fun trySetNext(value: N): Boolean = _next.compareAndSet(null, value)
+
+    /**
+     * Checks whether this node is the physical tail of the current linked list.
+     */
+    val isTail: Boolean get() = next == null
+
+    val prev: N? get() = _prev.value
+
+    /**
+     * Cleans the pointer to the previous node.
+     */
+    fun cleanPrev() { _prev.lazySet(null) }
+
+    /**
+     * Tries to mark the linked list as closed by forbidding adding new nodes after this one.
+     */
+    fun markAsClosed() = _next.compareAndSet(null, CLOSED)
+
+    /**
+     * This property indicates whether the current node is logically removed.
+     * The expected use-case is removing the node logically (so that [removed] becomes true),
+     * and invoking [remove] after that. Note that this implementation relies on the contract
+     * that the physical tail cannot be logically removed. Please, do not break this contract;
+     * otherwise, memory leaks and unexpected behavior can occur.
+     */
+    abstract val removed: Boolean
+
+    /**
+     * Removes this node physically from this linked list. The node should be
+     * logically removed (so [removed] returns `true`) at the point of invocation.
+     */
+    fun remove() {
+        assert { removed } // The node should be logically removed at first.
+        assert { !isTail } // The physical tail cannot be removed.
+        while (true) {
+            // Read `next` and `prev` pointers ignoring logically removed nodes.
+            val prev = leftmostAliveNode
+            val next = rightmostAliveNode
+            // Link `next` and `prev`.
+            next._prev.value = prev
+            if (prev !== null) prev._next.value = next
+            // Checks that prev and next are still alive.
+            if (next.removed) continue
+            if (prev !== null && prev.removed) continue
+            // This node is removed.
+            return
+        }
+    }
+
+    private val leftmostAliveNode: N? get() {
+        var cur = prev
+        while (cur !== null && cur.removed)
+            cur = cur._prev.value
+        return cur
+    }
+
+    private val rightmostAliveNode: N get() {
+        assert { !isTail } // Should not be invoked on the tail node
+        var cur = next!!
+        while (cur.removed)
+            cur = cur.next!!
+        return cur
+    }
+}
+
+/**
+ * Each segment in the list has a unique id and is created by the provided to [findSegmentAndMoveForward] method.
+ * Essentially, this is a node in the Michael-Scott queue algorithm,
+ * but with maintaining [prev] pointer for efficient [remove] implementation.
+ */
+internal abstract class Segment<S : Segment<S>>(val id: Long, prev: S?, pointers: Int): ConcurrentLinkedListNode<S>(prev) {
+    /**
+     * This property should return the maximal number of slots in this segment,
+     * it is used to define whether the segment is logically removed.
+     */
+    abstract val maxSlots: Int
+
+    /**
+     * Numbers of cleaned slots (the lowest bits) and AtomicRef pointers to this segment (the highest bits)
+     */
+    private val cleanedAndPointers = atomic(pointers shl POINTERS_SHIFT)
+
+    /**
+     * The segment is considered as removed if all the slots are cleaned.
+     * There are no pointers to this segment from outside, and
+     * it is not a physical tail in the linked list of segments.
+     */
+    override val removed get() = cleanedAndPointers.value == maxSlots && !isTail
+
+    // increments the number of pointers if this segment is not logically removed.
+    internal fun tryIncPointers() = cleanedAndPointers.addConditionally(1 shl POINTERS_SHIFT) { it != maxSlots || isTail }
+
+    // returns `true` if this segment is logically removed after the decrement.
+    internal fun decPointers() = cleanedAndPointers.addAndGet(-(1 shl POINTERS_SHIFT)) == maxSlots && !isTail
+
+    /**
+     * Invoked on each slot clean-up; should not be invoked twice for the same slot.
+     */
+    fun onSlotCleaned() {
+        if (cleanedAndPointers.incrementAndGet() == maxSlots && !isTail) remove()
+    }
+}
+
+private inline fun AtomicInt.addConditionally(delta: Int, condition: (cur: Int) -> Boolean): Boolean {
+    while (true) {
+        val cur = this.value
+        if (!condition(cur)) return false
+        if (this.compareAndSet(cur, cur + delta)) return true
+    }
+}
+
+@Suppress("EXPERIMENTAL_FEATURE_WARNING") // We are using inline class only internally, so it is Ok
+internal inline class SegmentOrClosed<S : Segment<S>>(private val value: Any?) {
+    val isClosed: Boolean get() = value === CLOSED
+    @Suppress("UNCHECKED_CAST")
+    val segment: S get() = if (value === CLOSED) error("Does not contain segment") else value as S
+}
+
+private const val POINTERS_SHIFT = 16
+
+@SharedImmutable
+private val CLOSED = Symbol("CLOSED")
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt b/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt
deleted file mode 100644
index 0091d13..0000000
--- a/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
- */
-
-package kotlinx.coroutines.internal
-
-import kotlinx.atomicfu.*
-import kotlinx.coroutines.*
-
-/**
- * Essentially, this segment queue is an infinite array of segments, which is represented as
- * a Michael-Scott queue of them. All segments are instances of [Segment] class and
- * follow in natural order (see [Segment.id]) in the queue.
- */
-internal abstract class SegmentQueue<S: Segment<S>>() {
-    private val _head: AtomicRef<S>
-    /**
-     * Returns the first segment in the queue.
-     */
-    protected val head: S get() = _head.value
-
-    private val _tail: AtomicRef<S>
-    /**
-     * Returns the last segment in the queue.
-     */
-    protected val tail: S get() = _tail.value
-
-    init {
-        val initialSegment = newSegment(0)
-        _head = atomic(initialSegment)
-        _tail = atomic(initialSegment)
-    }
-
-    /**
-     * The implementation should create an instance of segment [S] with the specified id
-     * and initial reference to the previous one.
-     */
-    abstract fun newSegment(id: Long, prev: S? = null): S
-
-    /**
-     * Finds a segment with the specified [id] following by next references from the
-     * [startFrom] segment. The typical use-case is reading [tail] or [head], doing some
-     * synchronization, and invoking [getSegment] or [getSegmentAndMoveHead] correspondingly
-     * to find the required segment.
-     */
-    protected fun getSegment(startFrom: S, id: Long): S? {
-        // Go through `next` references and add new segments if needed,
-        // similarly to the `push` in the Michael-Scott queue algorithm.
-        // The only difference is that `CAS failure` means that the
-        // required segment has already been added, so the algorithm just
-        // uses it. This way, only one segment with each id can be in the queue.
-        var cur: S = startFrom
-        while (cur.id < id) {
-            var curNext = cur.next
-            if (curNext == null) {
-                // Add a new segment.
-                val newTail = newSegment(cur.id + 1, cur)
-                curNext = if (cur.casNext(null, newTail)) {
-                    if (cur.removed) {
-                        cur.remove()
-                    }
-                    moveTailForward(newTail)
-                    newTail
-                } else {
-                    cur.next!!
-                }
-            }
-            cur = curNext
-        }
-        if (cur.id != id) return null
-        return cur
-    }
-
-    /**
-     * Invokes [getSegment] and replaces [head] with the result if its [id] is greater.
-     */
-    protected fun getSegmentAndMoveHead(startFrom: S, id: Long): S? {
-        @Suppress("LeakingThis")
-        if (startFrom.id == id) return startFrom
-        val s = getSegment(startFrom, id) ?: return null
-        moveHeadForward(s)
-        return s
-    }
-
-    /**
-     * Updates [head] to the specified segment
-     * if its `id` is greater.
-     */
-    private fun moveHeadForward(new: S) {
-        _head.loop { curHead ->
-            if (curHead.id > new.id) return
-            if (_head.compareAndSet(curHead, new)) {
-                new.prev.value = null
-                return
-            }
-        }
-    }
-
-    /**
-     * Updates [tail] to the specified segment
-     * if its `id` is greater.
-     */
-    private fun moveTailForward(new: S) {
-        _tail.loop { curTail ->
-            if (curTail.id > new.id) return
-            if (_tail.compareAndSet(curTail, new)) return
-        }
-    }
-}
-
-/**
- * Each segment in [SegmentQueue] has a unique id and is created by [SegmentQueue.newSegment].
- * Essentially, this is a node in the Michael-Scott queue algorithm, but with
- * maintaining [prev] pointer for efficient [remove] implementation.
- */
-internal abstract class Segment<S: Segment<S>>(val id: Long, prev: S?) {
-    // Pointer to the next segment, updates similarly to the Michael-Scott queue algorithm.
-    private val _next = atomic<S?>(null)
-    val next: S? get() = _next.value
-    fun casNext(expected: S?, value: S?): Boolean = _next.compareAndSet(expected, value)
-    // Pointer to the previous segment, updates in [remove] function.
-    val prev = atomic<S?>(null)
-
-    /**
-     * Returns `true` if this segment is logically removed from the queue.
-     * The [remove] function should be called right after it becomes logically removed.
-     */
-    abstract val removed: Boolean
-
-    init {
-        this.prev.value = prev
-    }
-
-    /**
-     * Removes this segment physically from the segment queue. The segment should be
-     * logically removed (so [removed] returns `true`) at the point of invocation.
-     */
-    fun remove() {
-        assert { removed } // The segment should be logically removed at first
-        // Read `next` and `prev` pointers.
-        var next = this._next.value ?: return // tail cannot be removed
-        var prev = prev.value ?: return // head cannot be removed
-        // Link `next` and `prev`.
-        prev.moveNextToRight(next)
-        while (prev.removed) {
-            prev = prev.prev.value ?: break
-            prev.moveNextToRight(next)
-        }
-        next.movePrevToLeft(prev)
-        while (next.removed) {
-            next = next.next ?: break
-            next.movePrevToLeft(prev)
-        }
-    }
-
-    /**
-     * Updates [next] pointer to the specified segment if
-     * the [id] of the specified segment is greater.
-     */
-    private fun moveNextToRight(next: S) {
-        while (true) {
-            val curNext = this._next.value as S
-            if (next.id <= curNext.id) return
-            if (this._next.compareAndSet(curNext, next)) return
-        }
-    }
-
-    /**
-     * Updates [prev] pointer to the specified segment if
-     * the [id] of the specified segment is lower.
-     */
-    private fun movePrevToLeft(prev: S) {
-        while (true) {
-            val curPrev = this.prev.value ?: return
-            if (curPrev.id <= prev.id) return
-            if (this.prev.compareAndSet(curPrev, prev)) return
-        }
-    }
-}
diff --git a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt
index aa7ed63..7cdc736 100644
--- a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt
+++ b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt
@@ -8,9 +8,8 @@
 import kotlinx.coroutines.*
 import kotlinx.coroutines.internal.*
 import kotlin.coroutines.*
-import kotlin.jvm.*
 import kotlin.math.*
-import kotlin.native.concurrent.*
+import kotlin.native.concurrent.SharedImmutable
 
 /**
  * A counting semaphore for coroutines that logically maintains a number of available permits.
@@ -84,16 +83,26 @@
     }
 }
 
-private class SemaphoreImpl(
-    private val permits: Int, acquiredPermits: Int
-) : Semaphore, SegmentQueue<SemaphoreSegment>() {
+private class SemaphoreImpl(private val permits: Int, acquiredPermits: Int) : Semaphore {
+
+    // The queue of waiting acquirers is essentially an infinite array based on the list of segments
+    // (see `SemaphoreSegment`); each segment contains a fixed number of slots. To determine a slot for each enqueue
+    // and dequeue operation, we increment the corresponding counter at the beginning of the operation
+    // and use the value before the increment as a slot number. This way, each enqueue-dequeue pair
+    // works with an individual cell.We use the corresponding segment pointer to find the required ones.
+    private val head: AtomicRef<SemaphoreSegment>
+    private val deqIdx = atomic(0L)
+    private val tail: AtomicRef<SemaphoreSegment>
+    private val enqIdx = atomic(0L)
+
     init {
         require(permits > 0) { "Semaphore should have at least 1 permit, but had $permits" }
         require(acquiredPermits in 0..permits) { "The number of acquired permits should be in 0..$permits" }
+        val s = SemaphoreSegment(0, null, 2)
+        head = atomic(s)
+        tail = atomic(s)
     }
 
-    override fun newSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev)
-
     /**
      * This counter indicates a number of available permits if it is non-negative,
      * or the size with minus sign otherwise. Note, that 32-bit counter is enough here
@@ -104,14 +113,6 @@
     private val _availablePermits = atomic(permits - acquiredPermits)
     override val availablePermits: Int get() = max(_availablePermits.value, 0)
 
-    // The queue of waiting acquirers is essentially an infinite array based on `SegmentQueue`;
-    // each segment contains a fixed number of slots. To determine a slot for each enqueue
-    // and dequeue operation, we increment the corresponding counter at the beginning of the operation
-    // and use the value before the increment as a slot number. This way, each enqueue-dequeue pair
-    // works with an individual cell.
-    private val enqIdx = atomic(0L)
-    private val deqIdx = atomic(0L)
-
     override fun tryAcquire(): Boolean {
         _availablePermits.loop { p ->
             if (p <= 0) return false
@@ -136,12 +137,13 @@
         cur + 1
     }
 
-    private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutineReusable<Unit> sc@ { cont ->
-        val last = this.tail
+    private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutineReusable<Unit> sc@{ cont ->
+        val curTail = this.tail.value
         val enqIdx = enqIdx.getAndIncrement()
-        val segment = getSegment(last, enqIdx / SEGMENT_SIZE)
+        val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail,
+                                                          createNewSegment = ::createSegment).run { segment } // cannot be closed
         val i = (enqIdx % SEGMENT_SIZE).toInt()
-        if (segment === null || segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
+        if (segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
             // already resumed
             cont.resume(Unit)
             return@sc
@@ -151,10 +153,17 @@
 
     @Suppress("UNCHECKED_CAST")
     internal fun resumeNextFromQueue() {
-        try_again@while (true) {
-            val first = this.head
+        try_again@ while (true) {
+            val curHead = this.head.value
             val deqIdx = deqIdx.getAndIncrement()
-            val segment = getSegmentAndMoveHead(first, deqIdx / SEGMENT_SIZE) ?: continue@try_again
+            val id = deqIdx / SEGMENT_SIZE
+            val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
+                                                              createNewSegment = ::createSegment).run { segment } // cannot be closed
+            segment.cleanPrev()
+            if (segment.id > id) {
+                this.deqIdx.updateIfLower(segment.id * SEGMENT_SIZE)
+                continue@try_again
+            }
             val i = (deqIdx % SEGMENT_SIZE).toInt()
             val cont = segment.getAndSet(i, RESUMED)
             if (cont === null) return // just resumed
@@ -165,6 +174,10 @@
     }
 }
 
+private inline fun AtomicLong.updateIfLower(value: Long): Unit = loop { cur ->
+    if (cur >= value || compareAndSet(cur, value)) return
+}
+
 private class CancelSemaphoreAcquisitionHandler(
     private val semaphore: SemaphoreImpl,
     private val segment: SemaphoreSegment,
@@ -180,10 +193,11 @@
     override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $index]"
 }
 
-private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment<SemaphoreSegment>(id, prev) {
+private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)
+
+private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) : Segment<SemaphoreSegment>(id, prev, pointers) {
     val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
-    private val cancelledSlots = atomic(0)
-    override val removed get() = cancelledSlots.value == SEGMENT_SIZE
+    override val maxSlots: Int get() = SEGMENT_SIZE
 
     @Suppress("NOTHING_TO_INLINE")
     inline fun get(index: Int): Any? = acquirers[index].value
@@ -200,8 +214,7 @@
         // Try to cancel the slot
         val cancelled = getAndSet(index, CANCELLED) !== RESUMED
         // Remove this segment if needed
-        if (cancelledSlots.incrementAndGet() == SEGMENT_SIZE)
-            remove()
+        onSlotCleaned()
         return cancelled
     }
 
diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt
index 293be7a..3d1305c 100644
--- a/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt
+++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt
@@ -1,9 +1,13 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
 package kotlinx.coroutines.internal
 
-import kotlinx.atomicfu.atomic
+import kotlinx.atomicfu.*
 
 /**
- * This queue implementation is based on [SegmentQueue] for testing purposes and is organized as follows. Essentially,
+ * This queue implementation is based on [SegmentList] for testing purposes and is organized as follows. Essentially,
  * the [SegmentBasedQueue] is represented as an infinite array of segments, each stores one element (see [OneElementSegment]).
  * Both [enqueue] and [dequeue] operations increment the corresponding global index ([enqIdx] for [enqueue] and
  * [deqIdx] for [dequeue]) and work with the indexed by this counter cell. Since both operations increment the indices
@@ -13,60 +17,89 @@
  * the cell with [BROKEN] token and retry the operation, [enqueue] at the same time should restart as well; this way,
  * the queue is obstruction-free.
  */
-internal class SegmentBasedQueue<T> : SegmentQueue<OneElementSegment<T>>() {
-    override fun newSegment(id: Long, prev: OneElementSegment<T>?): OneElementSegment<T> = OneElementSegment(id, prev)
+internal class SegmentBasedQueue<T> {
+    private val head: AtomicRef<OneElementSegment<T>>
+    private val tail: AtomicRef<OneElementSegment<T>>
 
     private val enqIdx = atomic(0L)
     private val deqIdx = atomic(0L)
 
-    // Returns the segments associated with the enqueued element.
-    fun enqueue(element: T): OneElementSegment<T> {
+    init {
+        val s = OneElementSegment<T>(0, null, 2)
+        head = atomic(s)
+        tail = atomic(s)
+    }
+
+    // Returns the segments associated with the enqueued element, or `null` if the queue is closed.
+    fun enqueue(element: T): OneElementSegment<T>? {
         while (true) {
-            var tail = this.tail
+            val curTail = this.tail.value
             val enqIdx = this.enqIdx.getAndIncrement()
-            tail = getSegment(tail, enqIdx) ?: continue
-            if (tail.element.value === BROKEN) continue
-            if (tail.element.compareAndSet(null, element)) return tail
+            val segmentOrClosed = this.tail.findSegmentAndMoveForward(id = enqIdx, startFrom = curTail, createNewSegment = ::createSegment)
+            if (segmentOrClosed.isClosed) return null
+            val s = segmentOrClosed.segment
+            if (s.element.value === BROKEN) continue
+            if (s.element.compareAndSet(null, element)) return s
         }
     }
 
     fun dequeue(): T? {
         while (true) {
             if (this.deqIdx.value >= this.enqIdx.value) return null
-            var firstSegment = this.head
+            val curHead = this.head.value
             val deqIdx = this.deqIdx.getAndIncrement()
-            firstSegment = getSegmentAndMoveHead(firstSegment, deqIdx) ?: continue
-            var el = firstSegment.element.value
+            val segmentOrClosed = this.head.findSegmentAndMoveForward(id = deqIdx, startFrom = curHead, createNewSegment = ::createSegment)
+            if (segmentOrClosed.isClosed) return null
+            val s = segmentOrClosed.segment
+            s.cleanPrev()
+            if (s.id > deqIdx) continue
+            var el = s.element.value
             if (el === null) {
-                if (firstSegment.element.compareAndSet(null, BROKEN)) continue
-                else el = firstSegment.element.value
+                if (s.element.compareAndSet(null, BROKEN)) continue
+                else el = s.element.value
             }
-            if (el === REMOVED) continue
+            if (el === BROKEN) continue
+            @Suppress("UNCHECKED_CAST")
             return el as T
         }
     }
 
+    // `enqueue` should return `null` after the queue is closed
+    fun close(): OneElementSegment<T> {
+        val s = this.tail.value.close()
+        var cur = s
+        while (true) {
+            cur.element.compareAndSet(null, BROKEN)
+            cur = cur.prev ?: break
+        }
+        return s
+    }
+
     val numberOfSegments: Int get() {
-        var s: OneElementSegment<T>? = head
-        var i = 0
-        while (s != null) {
-            s = s.next
+        var cur = head.value
+        var i = 1
+        while (true) {
+            cur = cur.next ?: return i
             i++
         }
-        return i
+    }
+
+    fun checkHeadPrevIsCleaned() {
+        check(head.value.prev === null)
     }
 }
 
-internal class OneElementSegment<T>(id: Long, prev: OneElementSegment<T>?) : Segment<OneElementSegment<T>>(id, prev) {
+private fun <T> createSegment(id: Long, prev: OneElementSegment<T>?) = OneElementSegment(id, prev, 0)
+
+internal class OneElementSegment<T>(id: Long, prev: OneElementSegment<T>?, pointers: Int) : Segment<OneElementSegment<T>>(id, prev, pointers) {
     val element = atomic<Any?>(null)
 
-    override val removed get() = element.value === REMOVED
+    override val maxSlots get() = 1
 
     fun removeSegment() {
-        element.value = REMOVED
-        remove()
+        element.value = BROKEN
+        onSlotCleaned()
     }
 }
 
-private val BROKEN = Symbol("BROKEN")
-private val REMOVED = Symbol("REMOVED")
\ No newline at end of file
+private val BROKEN = Symbol("BROKEN")
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentListTest.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentListTest.kt
new file mode 100644
index 0000000..ff6a346
--- /dev/null
+++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentListTest.kt
@@ -0,0 +1,41 @@
+package kotlinx.coroutines.internal
+
+import kotlinx.atomicfu.*
+import org.junit.Test
+import kotlin.test.*
+
+class SegmentListTest {
+    @Test
+    fun testRemoveTail() {
+        val initialSegment = TestSegment(0, null, 2)
+        val head = AtomicRefHolder(initialSegment)
+        val tail = AtomicRefHolder(initialSegment)
+        val s1 = tail.ref.findSegmentAndMoveForward(1, tail.ref.value, ::createTestSegment).segment
+        assertFalse(s1.removed)
+        tail.ref.value.onSlotCleaned()
+        assertFalse(s1.removed)
+        head.ref.findSegmentAndMoveForward(2, head.ref.value, ::createTestSegment)
+        assertFalse(s1.removed)
+        tail.ref.findSegmentAndMoveForward(2, head.ref.value, ::createTestSegment)
+        assertTrue(s1.removed)
+    }
+
+    @Test
+    fun testClose() {
+        val initialSegment = TestSegment(0, null, 2)
+        val head = AtomicRefHolder(initialSegment)
+        val tail = AtomicRefHolder(initialSegment)
+        tail.ref.findSegmentAndMoveForward(1, tail.ref.value, ::createTestSegment)
+        assertEquals(tail.ref.value, tail.ref.value.close())
+        assertTrue(head.ref.findSegmentAndMoveForward(2, head.ref.value, ::createTestSegment).isClosed)
+    }
+}
+
+private class AtomicRefHolder<T>(initialValue: T) {
+    val ref = atomic(initialValue)
+}
+
+private class TestSegment(id: Long, prev: TestSegment?, pointers: Int) : Segment<TestSegment>(id, prev, pointers) {
+    override val maxSlots: Int get() = 1
+}
+private fun createTestSegment(id: Long, prev: TestSegment?) = TestSegment(id, prev, 0)
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt
index b59a648..fd2d329 100644
--- a/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt
@@ -33,7 +33,7 @@
         val s = q.enqueue(2)
         q.enqueue(3)
         assertEquals(3, q.numberOfSegments)
-        s.removeSegment()
+        s!!.removeSegment()
         assertEquals(2, q.numberOfSegments)
         assertEquals(1, q.dequeue())
         assertEquals(3, q.dequeue())
@@ -47,12 +47,22 @@
         val s = q.enqueue(2)
         assertEquals(1, q.dequeue())
         q.enqueue(3)
-        s.removeSegment()
+        s!!.removeSegment()
         assertEquals(3, q.dequeue())
         assertNull(q.dequeue())
     }
 
     @Test
+    fun testClose() {
+        val q = SegmentBasedQueue<Int>()
+        q.enqueue(1)
+        assertEquals(0, q.close().id)
+        assertEquals(null, q.enqueue(2))
+        assertEquals(1, q.dequeue())
+        assertEquals(null, q.dequeue())
+    }
+
+    @Test
     fun stressTest() {
         val q = SegmentBasedQueue<Int>()
         val expectedQueue = ArrayDeque<Int>()
@@ -64,6 +74,7 @@
                 expectedQueue.add(el)
             } else { // remove
                 assertEquals(expectedQueue.poll(), q.dequeue())
+                q.checkHeadPrevIsCleaned()
             }
         }
     }
@@ -78,7 +89,7 @@
         val N = 100_000 * stressTestMultiplier
         val T = 10
         val q = SegmentBasedQueue<Int>()
-        val segments = (1..N).map { q.enqueue(it) }.toMutableList()
+        val segments = (1..N).map { q.enqueue(it)!! }.toMutableList()
         if (random) segments.shuffle()
         assertEquals(N, q.numberOfSegments)
         val nextSegmentIndex = AtomicInteger()
diff --git a/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt b/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt
index 1bb51a5..89bf8df 100644
--- a/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/linearizability/SegmentQueueLCStressTest.kt
@@ -18,12 +18,17 @@
     private val q = SegmentBasedQueue<Int>()
 
     @Operation
-    fun add(@Param(name = "value") value: Int) {
-        q.enqueue(value)
+    fun enqueue(@Param(name = "value") x: Int): Boolean {
+        return q.enqueue(x) !== null
     }
 
     @Operation
-    fun poll(): Int? = q.dequeue()
+    fun dequeue(): Int? = q.dequeue()
+
+    @Operation
+    fun close() {
+        q.close()
+    }
 
     override fun extractState(): Any {
         val elements = ArrayList<Int>()
@@ -31,8 +36,8 @@
             val x = q.dequeue() ?: break
             elements.add(x)
         }
-
-        return elements
+        val closed = q.enqueue(0) === null
+        return elements to closed
     }
 
     @Test