EventLoop scheduled tasks impl is rewritten
diff --git a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/EventLoop.kt b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/EventLoop.kt
index afbae4f..2213877 100644
--- a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/EventLoop.kt
+++ b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/EventLoop.kt
@@ -18,7 +18,9 @@
import kotlinx.coroutines.experimental.internal.LockFreeLinkedListHead
import kotlinx.coroutines.experimental.internal.LockFreeLinkedListNode
-import java.util.concurrent.ConcurrentSkipListMap
+import kotlinx.coroutines.experimental.internal.ThreadSafeHeap
+import kotlinx.coroutines.experimental.internal.ThreadSafeHeapNode
+import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.locks.LockSupport
@@ -67,7 +69,7 @@
private val thread: Thread
) : CoroutineDispatcher(), EventLoop, Delay {
private val queue = LockFreeLinkedListHead()
- private val delayed = ConcurrentSkipListMap<DelayedTask, DelayedTask>()
+ private val delayed = ThreadSafeHeap<DelayedTask>()
private val nextSequence = AtomicLong()
private var parentJob: Job? = null
@@ -78,9 +80,10 @@
override fun dispatch(context: CoroutineContext, block: Runnable) {
if (scheduleQueued(QueuedRunnableTask(block))) {
+ // todo: we should unpark only when this task became first in the queue
unpark()
} else {
- block.run()
+ block.run() // otherwise run it right here (as if Unconfined)
}
}
@@ -89,42 +92,56 @@
// todo: we should unpark only when this delayed task became first in the queue
unpark()
} else {
- scheduledExecutor.schedule(ResumeRunnable(continuation), time, unit)
+ scheduledExecutor.schedule(ResumeRunnable(continuation), time, unit) // otherwise reschedule to other time pool
}
}
- override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle =
- DelayedRunnableTask(time, unit, block).also { scheduleDelayed(it) }
+ override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
+ val delayedTask = DelayedRunnableTask(time, unit, block)
+ if (scheduleDelayed(delayedTask)) {
+ // todo: we should unpark only when this delayed task became first in the queue
+ unpark()
+ return delayedTask
+ }
+ return DisposableFutureHandle(scheduledExecutor.schedule(block, time, unit))
+ }
override fun processNextEvent(): Long {
if (Thread.currentThread() !== thread) return Long.MAX_VALUE
// queue all delayed tasks that are due to be executed
- while (true) {
- val delayedTask = delayed.firstEntry()?.key ?: break
+ if (!delayed.isEmpty) {
val now = System.nanoTime()
- if (delayedTask.nanoTime - now > 0) break
- if (!scheduleQueued(delayedTask)) break
- delayed.remove(delayedTask)
+ while (true) {
+ val delayedTask = delayed.removeFirstIf { it.timeToExecute(now) } ?: break
+ queue.addLast(delayedTask)
+ }
}
// then process one event from queue
(queue.removeFirstOrNull() as? QueuedTask)?.let { queuedTask ->
- queuedTask()
+ queuedTask.run()
}
if (!queue.isEmpty) return 0
- val nextDelayedTask = delayed.firstEntry()?.key ?: return Long.MAX_VALUE
- return nextDelayedTask.nanoTime - System.nanoTime()
+ val nextDelayedTask = delayed.peek() ?: return Long.MAX_VALUE
+ return (nextDelayedTask.nanoTime - System.nanoTime()).coerceAtLeast(0)
}
+ private val isActive: Boolean get() = parentJob?.isCompleted != true
+
fun shutdown() {
+ assert(!isActive)
// complete processing of all queued tasks
while (true) {
val queuedTask = (queue.removeFirstOrNull() ?: break) as QueuedTask
- queuedTask()
+ queuedTask.run()
}
- // cancel all delayed tasks
+ // reschedule or execute delayed tasks
while (true) {
- val delayedTask = delayed.pollFirstEntry()?.key ?: break
- delayedTask.cancel()
+ val delayedTask = delayed.removeFirst() ?: break
+ val now = System.nanoTime()
+ if (delayedTask.timeToExecute(now))
+ delayedTask.run()
+ else
+ delayedTask.rescheduleOnShutdown(now)
}
}
@@ -133,14 +150,15 @@
queue.addLast(queuedTask)
return true
}
- return queue.addLastIf(queuedTask, { !parentJob!!.isCompleted })
+ return queue.addLastIf(queuedTask) { isActive }
}
private fun scheduleDelayed(delayedTask: DelayedTask): Boolean {
- delayed.put(delayedTask, delayedTask)
- if (parentJob?.isActive != false) return true
- delayedTask.dispose()
- return false
+ if (parentJob == null) {
+ delayed.addLast(delayedTask)
+ return true
+ }
+ return delayed.addLastIf(delayedTask) { isActive }
}
private fun unpark() {
@@ -148,20 +166,22 @@
LockSupport.unpark(thread)
}
- private abstract class QueuedTask : LockFreeLinkedListNode(), () -> Unit
+ private abstract class QueuedTask : LockFreeLinkedListNode(), Runnable
private class QueuedRunnableTask(
private val block: Runnable
) : QueuedTask() {
- override fun invoke() { block.run() }
+ override fun run() { block.run() }
override fun toString(): String = block.toString()
}
private abstract inner class DelayedTask(
time: Long, timeUnit: TimeUnit
- ) : QueuedTask(), Comparable<DelayedTask>, DisposableHandle {
+ ) : QueuedTask(), Comparable<DelayedTask>, DisposableHandle, ThreadSafeHeapNode {
+ override var index: Int = -1
@JvmField val nanoTime: Long = System.nanoTime() + timeUnit.toNanos(time)
@JvmField val sequence: Long = nextSequence.getAndIncrement()
+ private var scheduledAfterShutdown: Future<*>? = null
override fun compareTo(other: DelayedTask): Int {
val dTime = nanoTime - other.nanoTime
@@ -171,12 +191,21 @@
return if (dSequence > 0) 1 else if (dSequence < 0) -1 else 0
}
- override final fun dispose() {
- delayed.remove(this)
- cancel()
+ fun timeToExecute(now: Long): Boolean = now - nanoTime >= 0L
+
+ fun rescheduleOnShutdown(now: Long) = synchronized(delayed) {
+ if (delayed.remove(this)) {
+ assert (scheduledAfterShutdown == null)
+ scheduledAfterShutdown = scheduledExecutor.schedule(this, nanoTime - now, TimeUnit.NANOSECONDS)
+ }
}
- open fun cancel() {}
+ override final fun dispose() = synchronized(delayed) {
+ if (!delayed.remove(this)) {
+ scheduledAfterShutdown?.cancel(false)
+ scheduledAfterShutdown = null
+ }
+ }
override fun toString(): String = "Delayed[nanos=$nanoTime,seq=$sequence]"
}
@@ -185,21 +214,16 @@
time: Long, timeUnit: TimeUnit,
private val cont: CancellableContinuation<Unit>
) : DelayedTask(time, timeUnit) {
- override fun invoke() {
+ override fun run() {
with(cont) { resumeUndispatched(Unit) }
}
- override fun cancel() {
- if (!cont.isActive) return
- val remaining = nanoTime - System.nanoTime()
- scheduledExecutor.schedule(ResumeRunnable(cont), remaining, TimeUnit.NANOSECONDS)
- }
}
private inner class DelayedRunnableTask(
time: Long, timeUnit: TimeUnit,
private val block: Runnable
) : DelayedTask(time, timeUnit) {
- override fun invoke() { block.run() }
+ override fun run() { block.run() }
override fun toString(): String = super.toString() + block.toString()
}
}
diff --git a/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt
new file mode 100644
index 0000000..8d369a0
--- /dev/null
+++ b/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.internal
+
+/**
+ * @suppress **This is unstable API and it is subject to change.**
+ */
+public interface ThreadSafeHeapNode {
+ public var index: Int
+}
+
+/**
+ * Synchronized binary heap.
+ *
+ * @suppress **This is unstable API and it is subject to change.**
+ */
+public class ThreadSafeHeap<T> where T: ThreadSafeHeapNode, T: Comparable<T> {
+ private var a: Array<T?>? = null
+
+ @JvmField @PublishedApi @Volatile
+ internal var size = 0
+
+ public val isEmpty: Boolean get() = size == 0
+
+ public fun peek(): T? = synchronized(this) { firstImpl() }
+
+ public fun removeFirst(): T? = synchronized(this) {
+ if (size > 0) {
+ removeAtImpl(0)
+ } else
+ null
+ }
+
+ public inline fun removeFirstIf(predicate: (T) -> Boolean): T? = synchronized(this) {
+ val first = firstImpl() ?: return@synchronized null
+ if (predicate(first)) {
+ removeAtImpl(0)
+ } else
+ null
+ }
+
+ public fun addLast(node: T) = synchronized(this) {
+ addImpl(node)
+ }
+
+ public fun addLastIf(node: T, cond: () -> Boolean): Boolean = synchronized(this) {
+ if (cond()) {
+ addImpl(node)
+ true
+ } else
+ false
+ }
+
+ public fun remove(node: T): Boolean = synchronized(this) {
+ if (node.index < 0) {
+ false
+ } else {
+ removeAtImpl(node.index)
+ true
+ }
+ }
+
+ @PublishedApi
+ internal fun firstImpl(): T? = a?.get(0)
+
+ @PublishedApi
+ internal fun removeAtImpl(index: Int): T {
+ check(size > 0)
+ val a = this.a!!
+ size--
+ if (index < size) {
+ swap(index, size)
+ var i = index
+ while (true) {
+ var j = 2 * i + 1
+ if (j >= size) break
+ if (j + 1 < size && a[j + 1]!! < a[j]!!) j++
+ if (a[i]!! <= a[j]!!) break
+ swap(i, j)
+ i = j
+ }
+ }
+ val result = a[size]!!
+ result.index = -1
+ a[size] = null
+ return result
+ }
+
+ @PublishedApi
+ internal fun addImpl(node: T) {
+ val a = realloc()
+ var i = size++
+ a[i] = node
+ node.index = i
+ while (i > 0) {
+ val j = (i - 1) / 2
+ if (a[j]!! <= a[i]!!) break
+ swap(i, j)
+ i = j
+ }
+ }
+
+
+ @Suppress("UNCHECKED_CAST")
+ private fun realloc(): Array<T?> {
+ val a = this.a
+ return when {
+ a == null -> (arrayOfNulls<ThreadSafeHeapNode>(4) as Array<T?>).also { this.a = it }
+ size >= a.size -> a.copyOf(size * 2).also { this.a = it }
+ else -> a
+ }
+ }
+
+ private fun swap(i: Int, j: Int) {
+ val a = a!!
+ val ni = a[j]!!
+ val nj = a[i]!!
+ a[i] = ni
+ a[j] = nj
+ ni.index = i
+ nj.index = j
+ }
+}
\ No newline at end of file
diff --git a/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt
new file mode 100644
index 0000000..91257e1
--- /dev/null
+++ b/kotlinx-coroutines-core/src/test/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeapTest.kt
@@ -0,0 +1,77 @@
+/*
+ * Copyright 2016-2017 JetBrains s.r.o.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kotlinx.coroutines.experimental.internal
+
+import org.hamcrest.core.IsEqual
+import org.hamcrest.core.IsNull
+import org.junit.Assert.assertThat
+import org.junit.Test
+import java.util.*
+
+class ThreadSafeHeapTest {
+ class Node(val value: Int) : ThreadSafeHeapNode, Comparable<Node> {
+ override var index = -1
+ override fun compareTo(other: Node): Int = Integer.compare(value, other.value)
+ override fun equals(other: Any?): Boolean = other is Node && other.value == value
+ override fun hashCode(): Int = value
+ override fun toString(): String = "$value"
+ }
+
+ @Test
+ fun testBasic() {
+ val h = ThreadSafeHeap<Node>()
+ assertThat(h.peek(), IsNull())
+ val n1 = Node(1)
+ h.addLast(n1)
+ assertThat(h.peek(), IsEqual(n1))
+ val n2 = Node(2)
+ h.addLast(n2)
+ assertThat(h.peek(), IsEqual(n1))
+ val n3 = Node(3)
+ h.addLast(n3)
+ assertThat(h.peek(), IsEqual(n1))
+ val n4 = Node(4)
+ h.addLast(n4)
+ assertThat(h.peek(), IsEqual(n1))
+ val n5 = Node(5)
+ h.addLast(n5)
+ assertThat(h.peek(), IsEqual(n1))
+ assertThat(h.removeFirst(), IsEqual(n1))
+ assertThat(n1.index, IsEqual(-1))
+ assertThat(h.peek(), IsEqual(n2))
+ h.remove(n2)
+ assertThat(h.peek(), IsEqual(n3))
+ h.remove(n4)
+ assertThat(h.peek(), IsEqual(n3))
+ h.remove(n3)
+ assertThat(h.peek(), IsEqual(n5))
+ h.remove(n5)
+ assertThat(h.peek(), IsNull())
+ }
+
+ @Test
+ fun testRandomSort() {
+ val n = 1000
+ val r = Random(1)
+ val h = ThreadSafeHeap<Node>()
+ val a = IntArray(n) { r.nextInt() }
+ repeat(n) { h.addLast(Node(a[it])) }
+ a.sort()
+ repeat(n) { assertThat(h.removeFirst(), IsEqual(Node(a[it]))) }
+ assertThat(h.peek(), IsNull())
+ }
+}
\ No newline at end of file