Fixed removal of arbitrary nodes from ThreadSafeHeap,
previously heap invariant could have been violated because of non-first removal.
diff --git a/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt b/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt
index 3877168..3b239f7 100644
--- a/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt
+++ b/core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt
@@ -84,14 +84,12 @@
         size--
         if (index < size) {
             swap(index, size)
-            var i = index
-            while (true) {
-                var j = 2 * i + 1
-                if (j >= size) break
-                if (j + 1 < size && a[j + 1]!! < a[j]!!) j++
-                if (a[i]!! <= a[j]!!) break
-                swap(i, j)
-                i = j
+            val j = (index - 1) / 2
+            if (index > 0 && a[index]!! < a[j]!!) {
+                swap(index, j)
+                siftUpFrom(j)
+            } else {
+                siftDownFrom(index)
             }
         }
         val result = a[size]!!
@@ -106,14 +104,27 @@
         var i = size++
         a[i] = node
         node.index = i
-        while (i > 0) {
-            val j = (i - 1) / 2
-            if (a[j]!! <= a[i]!!) break
-            swap(i, j)
-            i = j
-        }
+        siftUpFrom(i)
     }
 
+    private tailrec fun siftUpFrom(i: Int) {
+        if (i <= 0) return
+        val a = a!!
+        val j = (i - 1) / 2
+        if (a[j]!! <= a[i]!!) return
+        swap(i, j)
+        siftUpFrom(j)
+    }
+
+    private tailrec fun siftDownFrom(i: Int) {
+        var j = 2 * i + 1
+        if (j >= size) return
+        val a = a!!
+        if (j + 1 < size && a[j + 1]!! < a[j]!!) j++
+        if (a[i]!! <= a[j]!!) return
+        swap(i, j)
+        siftDownFrom(j)
+    }
 
     @Suppress("UNCHECKED_CAST")
     private fun realloc(): Array<T?> {
diff --git a/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt
index 6355acb..de9f60d 100644
--- a/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt
+++ b/core/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt
@@ -16,10 +16,11 @@
 
 package kotlinx.coroutines.experimental.internal
 
+import kotlinx.coroutines.experimental.*
 import kotlin.test.*
 import java.util.*
 
-class ThreadSafeHeapTest {
+class ThreadSafeHeapTest : TestBase() {
     class Node(val value: Int) : ThreadSafeHeapNode, Comparable<Node> {
         override var index = -1
         override fun compareTo(other: Node): Int = value.compareTo(other.value)
@@ -62,7 +63,7 @@
 
     @Test
     fun testRandomSort() {
-        val n = 1000
+        val n = 1000 * stressTestMultiplier
         val r = Random(1)
         val h = ThreadSafeHeap<Node>()
         val a = IntArray(n) { r.nextInt() }
@@ -71,4 +72,36 @@
         repeat(n) { assertEquals(Node(a[it]), h.removeFirstOrNull()) }
         assertEquals(null, h.peek())
     }
+
+    @Test
+    fun testRandomRemove() {
+        val n = 1000 * stressTestMultiplier
+        check(n % 2 == 0) { "Must be even" }
+        val r = Random(1)
+        val h = ThreadSafeHeap<Node>()
+        val set = TreeSet<Node>()
+        repeat(n) {
+            val node = Node(r.nextInt())
+            h.addLast(node)
+            assertTrue(set.add(node))
+        }
+        while (!h.isEmpty) {
+            // pick random node to remove
+            val rndNode: Node
+            while (true) {
+                val tail = set.tailSet(Node(r.nextInt()))
+                if (!tail.isEmpty()) {
+                    rndNode = tail.first()
+                    break
+                }
+            }
+            assertTrue(set.remove(rndNode))
+            assertTrue(h.remove(rndNode))
+            // remove head and validate
+            val headNode = h.removeFirstOrNull()!! // must not be null!!!
+            assertTrue(headNode === set.first(), "Expected ${set.first()}, but found $headNode, remaining size ${h.size}")
+            assertTrue(set.remove(headNode))
+            assertEquals(set.size, h.size)
+        }
+    }
 }
\ No newline at end of file