Multi-part atomic remove operation support for LockFreeLinkedList
diff --git a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/LockFreeLinkedList.kt b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/LockFreeLinkedList.kt
index f6f14e4..10f95ee 100644
--- a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/LockFreeLinkedList.kt
+++ b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/LockFreeLinkedList.kt
@@ -16,7 +16,6 @@
 
 package kotlinx.coroutines.experimental.internal
 
-import java.util.concurrent.atomic.AtomicIntegerFieldUpdater
 import java.util.concurrent.atomic.AtomicReferenceFieldUpdater
 
 private typealias Node = LockFreeLinkedListNode
@@ -30,24 +29,45 @@
 @PublishedApi
 internal const val FAILURE = 2
 
+@PublishedApi
+internal val CONDITION_FALSE: Any = Symbol("CONDITION_FALSE") // failure of a conditional add whose condition was false
+
+@PublishedApi
+internal val ALREADY_REMOVED: Any = Symbol("ALREADY_REMOVED") // failure of describeRemove: the node is already removed
+
+@PublishedApi
+internal val LIST_EMPTY: Any = Symbol("LIST_EMPTY") // failure of describeRemoveFirst: there is no node to remove
+
+private val REMOVE_PREPARED: Any = Symbol("REMOVE_PREPARED") // PrepareOp.perform result: node rejected, remove & retry
+
+/**
+ * @suppress **This is unstable API and it is subject to change.**
+ */
+public typealias RemoveFirstDesc<T> = LockFreeLinkedListNode.RemoveFirstDesc<T>
+
 /**
  * Doubly-linked concurrent list node with remove support.
  * Based on paper
  * ["Lock-Free and Practical Doubly Linked List-Based Deques Using Single-Word Compare-and-Swap"](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.140.4693&rep=rep1&type=pdf)
  * by Sundell and Tsigas.
- * The instance of this class serves both as list head/tail sentinel and as the list item.
- * Sentinel node should be never removed.
+ *
+ * Important notes:
+ * * The instance of this class serves both as list head/tail sentinel and as the list item.
+ *   Sentinel nodes should never be removed.
+ * * There are no operations to add items to the left side of the list, only to the end (right side), because we cannot
+ *   efficiently linearize them with atomic multi-step head-removal operations. In short,
+ *   support for the [describeRemoveFirst] operation precludes the ability to add items at the beginning.
  *
  * @suppress **This is unstable API and it is subject to change.**
  */
 @Suppress("LeakingThis")
 public open class LockFreeLinkedListNode {
     @Volatile
-    private var _next: Any = this // DoubleLinkedNode | Removed | CondAdd
+    private var _next: Any = this // Node | Removed | OpDescriptor
     @Volatile
-    private var prev: Any = this // DoubleLinkedNode | Removed
+    private var _prev: Any = this // Node | Removed
     @Volatile
-    private var removedRef: Removed? = null // lazily cached removed ref to this
+    private var _removedRef: Removed? = null // lazily cached removed ref to this
 
     private companion object {
         @JvmStatic
@@ -55,114 +75,70 @@
                 AtomicReferenceFieldUpdater.newUpdater(Node::class.java, Any::class.java, "_next")
         @JvmStatic
         val PREV: AtomicReferenceFieldUpdater<Node, Any> =
-                AtomicReferenceFieldUpdater.newUpdater(Node::class.java, Any::class.java, "prev")
+                AtomicReferenceFieldUpdater.newUpdater(Node::class.java, Any::class.java, "_prev")
         @JvmStatic
         val REMOVED_REF: AtomicReferenceFieldUpdater<Node, Removed?> =
-            AtomicReferenceFieldUpdater.newUpdater(Node::class.java, Removed::class.java, "removedRef")
-    }
-
-    private class Removed(val ref: Node) {
-        override fun toString(): String = "Removed[$ref]"
+            AtomicReferenceFieldUpdater.newUpdater(Node::class.java, Removed::class.java, "_removedRef")
     }
 
     private fun removed(): Removed =
-        removedRef ?: Removed(this).also { REMOVED_REF.lazySet(this, it) }
+        _removedRef ?: Removed(this).also { REMOVED_REF.lazySet(this, it) }
 
     @PublishedApi
-    internal abstract class CondAdd(val newNode: Node) {
+    internal abstract class CondAddOp(val newNode: Node) : AtomicOp() {
         lateinit var oldNext: Node
-        @Volatile
-        private var consensus: Int = UNDECIDED // status of operation
 
-        abstract fun isCondition(): Boolean
-
-        private companion object {
-            @JvmStatic
-            val CONSENSUS: AtomicIntegerFieldUpdater<CondAdd> =
-                    AtomicIntegerFieldUpdater.newUpdater(CondAdd::class.java, "consensus")
-        }
-
-        // returns either SUCCESS or FAILURE
-        fun completeAdd(node: Node): Int {
-            // make decision on status
-            var consensus: Int
-            while (true) {
-                consensus = this.consensus
-                if (consensus != UNDECIDED) break
-                val proposal = if (isCondition()) SUCCESS else FAILURE
-                if (CONSENSUS.compareAndSet(this, UNDECIDED, proposal)) {
-                    consensus = proposal
-                    break
-                }
-            }
-            val success = consensus == SUCCESS
-            if (NEXT.compareAndSet(node, this, if (success) newNode else oldNext)) {
+        override fun complete(affected: Any?, failure: Any?) { // replaces this descriptor in affected._next
+            affected as Node // type assertion (smart-casts affected to Node below)
+            val success = failure == null
+            val update = if (success) newNode else oldNext // link the new node, or restore the old next on failure
+            if (NEXT.compareAndSet(affected, this, update)) {
                 // only the thread the makes this update actually finishes add operation
                 if (success) newNode.finishAdd(oldNext)
             }
-            return consensus
         }
     }
 
     @PublishedApi
-    internal inline fun makeCondAdd(node: Node, crossinline condition: () -> Boolean): CondAdd = object : CondAdd(node) {
-        override fun isCondition(): Boolean = condition()
-    }
+    internal inline fun makeCondAddOp(node: Node, crossinline condition: () -> Boolean): CondAddOp =
+        object : CondAddOp(node) {
+            override fun prepare(): Any? = if (condition()) null else CONDITION_FALSE // null means "no failure"
+        }
 
-    public val isRemoved: Boolean get() = _next is Removed
+    public val isRemoved: Boolean get() = next is Removed // reads through next, so pending operations are helped first
 
-    private val isFresh: Boolean get() = _next === this && prev === this
-
-    @PublishedApi
-    internal val next: Any get() {
-        while (true) { // helper loop on _next
+    // LINEARIZABLE. Returns Node | Removed, first helping any OpDescriptor that is pending in _next.
+    public val next: Any get() {
+        while (true) { // operation helper loop on _next
             val next = this._next
-            if (next !is CondAdd) return next
-            next.completeAdd(this)
+            if (next !is OpDescriptor) return next
+            next.perform(this)
         }
     }
 
-    public fun next(): Node = next.unwrap()
-
-    public fun prev(): Node {
+    // LINEARIZABLE. Note: use it on sentinel (never removed) node only; helps unfinished inserts until prev.next === this
+    public val prev: Node get() {
         while (true) {
-            prevHelper()?.let { return it.unwrap() }
+            val prev = this._prev as Node // this sentinel node is never removed
+            if (prev.next === this) return prev
+            helpInsert(prev)
         }
     }
 
-    // ------ addFirstXXX ------
+    // ------ addOneIfEmpty ------
 
-    /**
-     * Adds first item to this list.
-     */
-    public fun addFirst(node: Node) {
-        while (true) { // lock-free loop on next
-            val next = this.next as Node // this sentinel node is never removed
-            if (addNext(node, next)) return
-        }
-    }
-
-    /**
-     * Adds first item to this list atomically if the [condition] is true.
-     */
-    public inline fun addFirstIf(node: Node, crossinline condition: () -> Boolean): Boolean {
-        val condAdd = makeCondAdd(node, condition)
-        while (true) { // lock-free loop on next
-            val next = this.next as Node // this sentinel node is never removed
-            when (tryCondAddNext(node, next, condAdd)) {
-                SUCCESS -> return true
-                FAILURE -> return false
-            }
-        }
-    }
-
-    public fun addFirstIfEmpty(node: Node): Boolean {
+    public fun addOneIfEmpty(node: Node): Boolean { // adds node only if this list is empty; returns whether it was added
         PREV.lazySet(node, this)
         NEXT.lazySet(node, this)
-        if (!NEXT.compareAndSet(this, this, node)) return false // this is not an empty list!
-        // added successfully (linearized add) -- fixup the list
-        node.finishAdd(this)
-        return true
+        while (true) { // lock-free loop: retry while the list still looks empty but the CAS lost a race
+            val next = next
+            if (next !== this) return false // this is not an empty list!
+            if (NEXT.compareAndSet(this, this, node)) {
+                // added successfully (linearized add) -- fixup the list
+                node.finishAdd(this)
+                return true
+            }
+        }
     }
 
     // ------ addLastXXX ------
@@ -172,7 +148,7 @@
      */
     public fun addLast(node: Node) {
         while (true) { // lock-free loop on prev.next
-            val prev = prevHelper() ?: continue
+            val prev = prev // shadows the linearizable prev property (helps unfinished inserts first)
             if (prev.addNext(node, this)) return
         }
     }
@@ -181,9 +157,9 @@
      * Adds last item to this list atomically if the [condition] is true.
      */
     public inline fun addLastIf(node: Node, crossinline condition: () -> Boolean): Boolean {
-        val condAdd = makeCondAdd(node, condition)
+        val condAdd = makeCondAddOp(node, condition)
         while (true) { // lock-free loop on prev.next
-            val prev = prevHelper() ?: continue
+            val prev = prev // shadows the linearizable prev property (helps unfinished inserts first)
             when (prev.tryCondAddNext(node, this, condAdd)) {
                 SUCCESS -> return true
                 FAILURE -> return false
@@ -193,7 +169,7 @@
 
     public inline fun addLastIfPrev(node: Node, predicate: (Node) -> Boolean): Boolean {
         while (true) { // lock-free loop on prev.next
-            val prev = prevHelper() ?: continue
+            val prev = prev
             if (!predicate(prev)) return false
             if (prev.addNext(node, this)) return true
         }
@@ -204,9 +180,9 @@
             predicate: (Node) -> Boolean, // prev node predicate
             crossinline condition: () -> Boolean // atomically checked condition
     ): Boolean {
-        val condAdd = makeCondAdd(node, condition)
+        val condAdd = makeCondAddOp(node, condition)
         while (true) { // lock-free loop on prev.next
-            val prev = prevHelper() ?: continue
+            val prev = prev
             if (!predicate(prev)) return false
             when (prev.tryCondAddNext(node, this, condAdd)) {
                 SUCCESS -> return true
@@ -215,14 +191,6 @@
         }
     }
 
-    @PublishedApi
-    internal fun prevHelper(): Node? {
-        val prev = this.prev as Node // this sentinel node is never removed
-        if (prev.next === this) return prev
-        helpInsert(prev)
-        return null
-    }
-
     // ------ addXXX util ------
 
     @PublishedApi
@@ -237,13 +205,13 @@
 
     // returns UNDECIDED, SUCCESS or FAILURE
     @PublishedApi
-    internal fun tryCondAddNext(node: Node, next: Node, condAdd: CondAdd): Int {
+    internal fun tryCondAddNext(node: Node, next: Node, condAdd: CondAddOp): Int {
         PREV.lazySet(node, this)
         NEXT.lazySet(node, next)
         condAdd.oldNext = next
         if (!NEXT.compareAndSet(this, next, condAdd)) return UNDECIDED
         // added operation successfully (linearized) -- complete it & fixup the list
-        return condAdd.completeAdd(this)
+        return if (condAdd.perform(this) == null) SUCCESS else FAILURE // NOTE(review): assumes perform returns null on success
     }
 
     // ------ removeXXX ------
@@ -255,40 +223,165 @@
         while (true) { // lock-free loop on next
             val next = this.next
             if (next is Removed) return false // was already removed -- don't try to help (original thread will take care)
+            check(next !== this) // sanity check -- can be true for sentinel nodes only, but they are never removed
             if (NEXT.compareAndSet(this, next, (next as Node).removed())) {
                 // was removed successfully (linearized remove) -- fixup the list
-                helpDelete()
-                next.helpInsert(prev.unwrap())
+                finishRemove(next)
                 return true
             }
         }
     }
 
-    public fun removeFirstOrNull(): Node? {
-        while (true) { // try to linearize
-            val first = next()
-            if (first == this) return null
-            if (first.remove()) return first
+    public open fun describeRemove(): AtomicDesc? {
+        if (isRemoved) return null // fast path if was already removed
+        return object : AbstractAtomicDesc() { // on success this node's next becomes Removed
+            override val affectedNode: Node? get() = this@LockFreeLinkedListNode
+            override var originalNext: Node? = null
+            override fun failure(affected: Node, next: Any): Any? =
+                if (next is Removed) ALREADY_REMOVED else null
+            override fun onPrepare(affected: Node, next: Node): Boolean {
+                originalNext = next
+                return true
+            }
+            override fun updatedNext(next: Node) = next.removed()
+            override fun finishOnSuccess(affected: Node, next: Node) = finishRemove(next)
         }
     }
 
+    public fun removeFirstOrNull(): Node? {
+        while (true) { // try to linearize
+            val first = next as Node
+            if (first === this) return null
+            if (first.remove()) return first
+            first.helpDelete() // must help delete, or lose lock-freedom
+        }
+    }
+
+    public fun describeRemoveFirst(): RemoveFirstDesc<Node> = RemoveFirstDesc(this) // removes the current first node
+
     public inline fun <reified T> removeFirstIfIsInstanceOf(): T? {
         while (true) { // try to linearize
-            val first = next()
-            if (first == this) return null
+            val first = next as Node
+            if (first === this) return null
             if (first !is T) return null
             if (first.remove()) return first
+            first.helpDelete() // must help delete, or lose lock-freedom
         }
     }
 
     // just peek at item when predicate is true
     public inline fun <reified T> removeFirstIfIsInstanceOfOrPeekIf(predicate: (T) -> Boolean): T? {
         while (true) { // try to linearize
-            val first = next()
-            if (first == this) return null
+            val first = next as Node
+            if (first === this) return null
             if (first !is T) return null
             if (predicate(first)) return first // just peek when predicate is true
             if (first.remove()) return first
+            first.helpDelete() // must help delete, or lose lock-freedom
+        }
+    }
+
+    // ------ multi-word atomic operations helpers ------
+
+    public open class RemoveFirstDesc<T>(val queue: Node) : AbstractAtomicDesc() {
+        @Suppress("UNCHECKED_CAST")
+        public val result: T get() = affectedNode!! as T // the prepared (removed) node; fails if none was prepared
+
+        final override fun takeAffectedNode(): Node = queue.next as Node // current first node (queue itself when empty)
+        final override var affectedNode: Node? = null
+        final override var originalNext: Node? = null
+
+        // check node predicates here, must signal failure if affected is not of type T
+        protected override fun failure(affected: Node, next: Any): Any? =
+                if (affected === queue) LIST_EMPTY else null
+        // validate the resulting node (return false if it should be deleted)
+        protected open fun validatePrepared(node: T): Boolean = true // false means remove node & retry
+
+        final override fun retry(affected: Node, next: Any): Boolean {
+            if (next !is Removed) return false
+            affected.helpDelete() // must help delete, or lose lock-freedom
+            return true
+        }
+        @Suppress("UNCHECKED_CAST")
+        final override fun onPrepare(affected: Node, next: Node): Boolean {
+            check(affected !is LockFreeLinkedListHead)
+            if (!validatePrepared(affected as T)) return false
+            affectedNode = affected
+            originalNext = next
+            return true
+        }
+        final override fun updatedNext(next: Node): Any = next.removed()
+        final override fun finishOnSuccess(affected: Node, next: Node) = affected.finishRemove(next)
+    }
+
+    public abstract class AbstractAtomicDesc : AtomicDesc() { // CASes affected._next: originalNext -> updatedNext(...)
+        protected abstract val affectedNode: Node?
+        protected abstract val originalNext: Node?
+        protected open fun takeAffectedNode(): Node = affectedNode!!
+        protected open fun failure(affected: Node, next: Any): Any? = null // next: Node | Removed
+        protected open fun retry(affected: Node, next: Any): Boolean = false // next: Node | Removed
+        protected abstract fun onPrepare(affected: Node, next: Node): Boolean // false means: remove node & retry
+        protected abstract fun updatedNext(next: Node): Any
+        protected abstract fun finishOnSuccess(affected: Node, next: Node)
+
+        // This is Harris's RDCSS (Restricted Double-Compare Single Swap) operation
+        // It inserts the "op" descriptor only when "op" status is still undecided (rolls back otherwise)
+        private class PrepareOp(
+            val next: Node,
+            val op: AtomicOp,
+            val desc: AbstractAtomicDesc
+        ) : OpDescriptor() {
+            override fun perform(affected: Any?): Any? {
+                affected as Node // type assertion
+                if (!desc.onPrepare(affected, next)) { // rejected: remove node here, so any helper can finish it
+                    if (NEXT.compareAndSet(affected, this, next.removed())) affected.helpDelete()
+                    return REMOVE_PREPARED // the preparing thread retries with the next candidate node
+                }
+                check(desc.affectedNode === affected)
+                check(desc.originalNext === next)
+                val update: Any = if (op.isDecided) next else op // restore if decision was already reached
+                NEXT.compareAndSet(affected, this, update)
+                return null // ok
+            }
+        }
+
+        final override fun prepare(op: AtomicOp): Any? { // null: prepared or op already decided; else a failure
+            while (true) { // lock free loop on next
+                val affected = takeAffectedNode()
+                // read its original next pointer first
+                val next = affected._next
+                // then see if already reached consensus on overall operation
+                if (op.isDecided) return null // already decided -- go to next desc
+                if (next === op) return null // already in process of operation -- all is good
+                if (next is OpDescriptor) {
+                    // some other operation is in process -- help it
+                    next.perform(affected)
+                    continue // and retry
+                }
+                // next: Node | Removed
+                val failure = failure(affected, next)
+                if (failure != null) return failure // signal failure
+                if (retry(affected, next)) continue // retry operation
+                val prepareOp = PrepareOp(next as Node, op, this)
+                if (NEXT.compareAndSet(affected, next, prepareOp)) {
+                    // prepared -- complete preparations
+                    val prepFail = prepareOp.perform(affected) ?: return null // prepared successfully
+                    check(prepFail === REMOVE_PREPARED) // the only way to fail; perform already removed the node -- retry
+                }
+            }
+        }
+
+        final override fun complete(op: AtomicOp, failure: Any?) { // commits (success) or rolls back the prepared next
+            val success = failure == null
+            val affectedNode = affectedNode
+            if (affectedNode == null) { // no node was prepared -- nothing to commit or undo
+                check(!success)
+                return
+            }
+            val update = if (success) updatedNext(originalNext!!) else originalNext // only once a node is known
+            if (NEXT.compareAndSet(affectedNode, op, update)) {
+                if (success) finishOnSuccess(affectedNode, originalNext!!)
+            }
         }
     }
 
@@ -296,7 +389,7 @@
 
     private fun finishAdd(next: Node) {
         while (true) {
-            val nextPrev = next.prev
+            val nextPrev = next._prev // raw field read: Node | Removed
             if (nextPrev is Removed || this.next !== next) return // next was removed, remover fixes up links
             if (PREV.compareAndSet(next, nextPrev, this)) {
                 if (this.next is Removed) {
@@ -308,16 +401,22 @@
         }
     }
 
+    private fun finishRemove(next: Node) { // called after this node's _next was CASed to Removed: fix links around it
+        helpDelete()
+        next.helpInsert(_prev.unwrap())
+    }
+
     private fun markPrev(): Node {
         while (true) { // lock-free loop on prev
-            val prev = this.prev
+            val prev = this._prev
             if (prev is Removed) return prev.ref
             if (PREV.compareAndSet(this, prev, (prev as Node).removed())) return prev
         }
     }
 
     // fixes next links to the left of this node
-    private fun helpDelete() {
+    @PublishedApi // called from the public inline removeFirstXXX functions
+    internal fun helpDelete() {
         var last: Node? = null // will set to the node left of prev when found
         var prev: Node = markPrev()
         var next: Node = (this._next as Removed).ref
@@ -338,7 +437,7 @@
                     prev = last
                     last = null
                 } else {
-                    prev = prev.prev.unwrap()
+                    prev = prev._prev.unwrap()
                 }
                 continue
             }
@@ -368,11 +467,11 @@
                     prev = last
                     last = null
                 } else {
-                    prev = prev.prev.unwrap()
+                    prev = prev._prev.unwrap()
                 }
                 continue
             }
-            val oldPrev = this.prev
+            val oldPrev = this._prev
             if (oldPrev is Removed) return // this node was removed, too -- its remover will take care
             if (prevNext !== this) {
                 // need to fixup next
@@ -382,50 +481,56 @@
             }
             if (oldPrev === prev) return // it is already linked as needed
             if (PREV.compareAndSet(this, oldPrev, prev)) {
-                if (prev.prev !is Removed) return // finish only if prev was not concurrently removed
+                if (prev._prev !is Removed) return // finish only if prev was not concurrently removed
             }
         }
     }
 
-    private fun Any.unwrap(): Node = if (this is Removed) ref else this as Node
-
     internal fun validateNode(prev: Node, next: Node) {
-        check(prev === this.prev)
-        check(next === this.next)
+        check(prev === this._prev)
+        check(next === this._next)
     }
 }
 
+private class Removed(val ref: Node) { // marker stored in _next/_prev of a removed node; ref is the original link
+    override fun toString(): String = "Removed[$ref]"
+}
+
+@PublishedApi
+internal fun Any.unwrap(): Node = if (this is Removed) ref else this as Node // Removed -> its ref; else cast to Node
+
 /**
  * Head (sentinel) item of the linked list that is never removed.
  *
  * @suppress **This is unstable API and it is subject to change.**
  */
 public open class LockFreeLinkedListHead : LockFreeLinkedListNode() {
-    public val isEmpty: Boolean get() = next() == this
+    public val isEmpty: Boolean get() = next === this // reads through next, helping pending operations
 
     /**
      * Iterates over all elements in this list of a specified type.
      */
     public inline fun <reified T : Node> forEach(block: (T) -> Unit) {
-        var cur: Node = next()
+        var cur: Node = next as Node
         while (cur != this) {
             if (cur is T) block(cur)
-            cur = cur.next.unwrap()
+            cur = cur.next.unwrap() // if cur was removed, continue from its original next
         }
     }
 
     // just a defensive programming -- makes sure that list head sentinel is never removed
     public final override fun remove() = throw UnsupportedOperationException()
+    public final override fun describeRemove(): AtomicDesc? = throw UnsupportedOperationException() // same defense
 
     internal fun validate() {
         var prev: Node = this
-        var cur: Node = next()
+        var cur: Node = next as Node
         while (cur != this) {
-            val next = cur.next()
+            val next = cur.next.unwrap()
             cur.validateNode(prev, next)
             prev = cur
             cur = next
         }
-        validateNode(prev, next())
+        validateNode(prev, next as Node)
     }
 }