blob: 190efe64ef99939bf6ea2038dd9ee584e07eda5f [file] [log] [blame]
Anton Spaansed17bc12018-03-20 18:28:53 -04001/*
2 * Copyright 2016-2018 JetBrains s.r.o.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package kotlinx.coroutines.experimental.test
18
19import kotlinx.coroutines.experimental.*
20import kotlinx.coroutines.experimental.internal.*
21import java.util.concurrent.TimeUnit
22import kotlin.coroutines.experimental.*
23
/**
 * This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
 * code, especially tests, that deal with delays and timeouts in Coroutines.
 *
 * Provide an instance of this TestCoroutineContext when calling the *non-blocking* [launch] or [async]
 * and then advance time or trigger the actions to make the coroutines execute as soon as possible.
 *
 * This works much like the *TestScheduler* in RxJava2, which speeds up tests that deal
 * with non-blocking Rx chains that contain delays, timeouts, intervals and such.
 *
 * This dispatcher can also handle *blocking* coroutines that are started by [runBlocking].
 * This dispatcher's virtual time will be automatically advanced based on the delayed actions
 * within the Coroutine(s).
 *
 * The context consists of exactly two elements: a [CoroutineDispatcher] (which also implements [Delay])
 * driven by the virtual clock, and a [CoroutineExceptionHandler] that records uncaught exceptions
 * (see [exceptions]).
 *
 * Virtual time is kept internally in nanoseconds. It never moves backwards, and scheduling
 * arithmetic saturates at [Long.MAX_VALUE] instead of overflowing.
 *
 * @param name A user-readable name for debugging purposes.
 */
class TestCoroutineContext(private val name: String? = null) : CoroutineContext {
    // Exceptions reported to [ctxHandler]; exposed read-only through [exceptions].
    private val uncaughtExceptions = mutableListOf<Throwable>()

    // The ContinuationInterceptor element of this context.
    private val ctxDispatcher = Dispatcher()

    // The CoroutineExceptionHandler element of this context: it records exceptions instead of rethrowing them.
    private val ctxHandler = CoroutineExceptionHandler { _, exception ->
        uncaughtExceptions += exception
    }

    // The ordered queue for the runnable tasks (ordered by virtual time, then by posting order).
    private val queue = ThreadSafeHeap<TimedRunnable>()

    // The per-scheduler global order counter; breaks ties between tasks due at the same time.
    private var counter = 0L

    // Storing time in nanoseconds internally.
    private var time = 0L

    /**
     * Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
     *
     * This is a live read-only view of the recorded exceptions. It reflects later additions and is
     * emptied whenever one of the `assert*` methods succeeds.
     */
    public val exceptions: List<Throwable> get() = uncaughtExceptions

    // -- CoroutineContext implementation

    public override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
        operation(operation(initial, ctxDispatcher), ctxHandler)

    @Suppress("UNCHECKED_CAST")
    public override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = when {
        key === ContinuationInterceptor -> ctxDispatcher as E
        key === CoroutineExceptionHandler -> ctxHandler as E
        else -> null
    }

    public override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext = when {
        key === ContinuationInterceptor -> ctxHandler
        key === CoroutineExceptionHandler -> ctxDispatcher
        else -> this
    }

    /**
     * Returns the current virtual clock-time as it is known to this CoroutineContext.
     *
     * @param unit The [TimeUnit] in which the clock-time must be returned.
     * @return The virtual clock-time
     */
    public fun now(unit: TimeUnit = TimeUnit.MILLISECONDS): Long =
        unit.convert(time, TimeUnit.NANOSECONDS)

    /**
     * Moves the CoroutineContext's virtual clock forward by a specified amount of time.
     *
     * The returned delay-time can be larger than the specified delay-time if the code
     * under test contains *blocking* Coroutines.
     *
     * @param delayTime The amount of time to move the CoroutineContext's clock forward.
     * @param unit The [TimeUnit] in which [delayTime] and the return value is expressed.
     * @return The amount of delay-time that this CoroutinesContext's clock has been forwarded.
     */
    public fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS): Long {
        val oldTime = time
        // Saturate so that a huge delayTime cannot wrap the target into the past.
        advanceTimeTo(addSaturated(oldTime, unit.toNanos(delayTime)), TimeUnit.NANOSECONDS)
        return unit.convert(time - oldTime, TimeUnit.NANOSECONDS)
    }

    /**
     * Moves the CoroutineContext's clock-time to a particular moment in time.
     *
     * All actions due at or before [targetTime] are triggered first. A [targetTime] in the past
     * triggers due actions but never moves the clock backwards.
     *
     * @param targetTime The point in time to which to move the CoroutineContext's clock.
     * @param unit The [TimeUnit] in which [targetTime] is expressed.
     */
    public fun advanceTimeTo(targetTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
        val nanoTime = unit.toNanos(targetTime)
        triggerActions(nanoTime)
        if (nanoTime > time) time = nanoTime
    }

    /**
     * Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
     * before this CoroutineContext's present virtual clock-time.
     */
    public fun triggerActions(): Unit = triggerActions(time)

    /**
     * Cancels all not yet triggered actions. Be careful calling this, since it can seriously
     * mess with your coroutines work. This method should usually be called on tear-down of a
     * unit test.
     */
    public fun cancelAllActions() {
        // An 'is-empty' test is required to avoid a NullPointerException in the 'clear()' method
        if (!queue.isEmpty) queue.clear()
    }

    /**
     * This method does nothing if there is exactly one unhandled exception and it satisfies the given predicate.
     * Otherwise it throws an [AssertionError] with the given message.
     *
     * (on success this method clears the list of unhandled exceptions; on failure the list is left untouched)
     *
     * @param message Message of the [AssertionError]. Defaults to an empty String.
     * @param predicate The predicate that must be satisfied.
     */
    public fun assertUnhandledException(message: String = "", predicate: (Throwable) -> Boolean) {
        if (uncaughtExceptions.size != 1 || !predicate(uncaughtExceptions[0])) throw AssertionError(message)
        uncaughtExceptions.clear()
    }

    /**
     * This method does nothing if there are no unhandled exceptions or all of them satisfy the given predicate.
     * Otherwise it throws an [AssertionError] with the given message.
     *
     * (on success this method clears the list of unhandled exceptions; on failure the list is left untouched)
     *
     * @param message Message of the [AssertionError]. Defaults to an empty String.
     * @param predicate The predicate that must be satisfied.
     */
    public fun assertAllUnhandledExceptions(message: String = "", predicate: (Throwable) -> Boolean) {
        if (!uncaughtExceptions.all(predicate)) throw AssertionError(message)
        uncaughtExceptions.clear()
    }

    /**
     * This method does nothing if one or more unhandled exceptions satisfy the given predicate.
     * Otherwise it throws an [AssertionError] with the given message.
     *
     * (on success this method clears the list of unhandled exceptions; on failure the list is left untouched)
     *
     * @param message Message of the [AssertionError]. Defaults to an empty String.
     * @param predicate The predicate that must be satisfied.
     */
    public fun assertAnyUnhandledException(message: String = "", predicate: (Throwable) -> Boolean) {
        if (!uncaughtExceptions.any(predicate)) throw AssertionError(message)
        uncaughtExceptions.clear()
    }

    /**
     * This method does nothing if the list of unhandled exceptions satisfies the given predicate.
     * Otherwise it throws an [AssertionError] with the given message.
     *
     * (on success this method clears the list of unhandled exceptions; on failure the list is left untouched)
     *
     * @param message Message of the [AssertionError]. Defaults to an empty String.
     * @param predicate The predicate that must be satisfied.
     */
    public fun assertExceptions(message: String = "", predicate: (List<Throwable>) -> Boolean) {
        if (!predicate(uncaughtExceptions)) throw AssertionError(message)
        uncaughtExceptions.clear()
    }

    // Queues block for immediate execution (time 0 marks "run at the current virtual time").
    private fun post(block: Runnable) =
        queue.addLast(TimedRunnable(block, counter++))

    // Queues block to run delayNanos after the current virtual time. Negative delays are treated as 0,
    // and the due time saturates instead of overflowing.
    private fun postDelayed(block: Runnable, delayNanos: Long): TimedRunnable =
        TimedRunnable(block, counter++, addSaturated(time, delayNanos.coerceAtLeast(0L)))
            .also { queue.addLast(it) }

    // Adds two nanosecond values, clamping to Long.MIN_VALUE/Long.MAX_VALUE instead of wrapping around
    // (e.g. `delay(Long.MAX_VALUE)` must not become a due time in the past).
    private fun addSaturated(a: Long, b: Long): Long {
        val sum = a + b
        // Overflow happened iff both operands share a sign that the sum does not have.
        return if (((a xor sum) and (b xor sum)) < 0L) {
            if (a < 0L) Long.MIN_VALUE else Long.MAX_VALUE
        } else {
            sum
        }
    }

    private fun processNextEvent(): Long {
        val current = queue.peek()
        if (current != null) {
            /** Automatically advance time for [EventLoop]-callbacks */
            triggerActions(current.time)
        }
        return if (queue.isEmpty) Long.MAX_VALUE else 0L
    }

    private fun triggerActions(targetTime: Long) {
        while (true) {
            val current = queue.removeFirstIf { it.time <= targetTime } ?: break
            // Immediate tasks (time 0) run at the current virtual time; delayed tasks move the clock
            // forward to their due time. The clock never moves backwards.
            if (current.time > time) time = current.time
            current.run()
        }
    }

    public override fun toString(): String = name ?: "TestCoroutineContext@$hexAddress"

    private inner class Dispatcher : CoroutineDispatcher(), Delay, EventLoop {
        override fun dispatch(context: CoroutineContext, block: Runnable) = post(block)

        override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) {
            // Keep full nanosecond precision; converting to millis would truncate sub-millisecond delays to 0.
            postDelayed(Runnable {
                with(continuation) { resumeUndispatched(Unit) }
            }, unit.toNanos(time))
        }

        override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
            val node = postDelayed(block, unit.toNanos(time))
            return object : DisposableHandle {
                override fun dispose() {
                    queue.remove(node)
                }
            }
        }

        override fun processNextEvent() = this@TestCoroutineContext.processNextEvent()

        public override fun toString(): String = "Dispatcher(${this@TestCoroutineContext})"
    }
}
242
/**
 * A task queued in a [TestCoroutineContext]'s heap.
 *
 * Tasks are ordered by their virtual [time] (nanoseconds) and, for equal times, by [count]. [count] is the
 * posting order, so tasks due at the same moment run in FIFO order. A [time] of `0` marks a task that
 * was dispatched for immediate execution.
 *
 * @param run The wrapped action; [run] delegates to it.
 * @param count The posting-order sequence number used as a tie-breaker.
 * @param time The virtual time, in nanoseconds, at which the task becomes due.
 */
private class TimedRunnable(
    private val run: Runnable,
    private val count: Long = 0,
    @JvmField internal val time: Long = 0
) : Comparable<TimedRunnable>, Runnable, ThreadSafeHeapNode {
    // Slot index required by ThreadSafeHeapNode (presumably kept up to date by ThreadSafeHeap).
    override var index: Int = 0

    // Explicit override; a `Runnable by run` delegation would be redundant next to it.
    override fun run() = run.run()

    override fun compareTo(other: TimedRunnable): Int =
        if (time == other.time) count.compareTo(other.count) else time.compareTo(other.time)

    override fun toString() = "TimedRunnable(time=$time, run=$run)"
}
260
/**
 * Executes a block of code in which a unit-test can be written using the provided [TestCoroutineContext]. The provided
 * [TestCoroutineContext] is available in the [testBody] as the `this` receiver.
 *
 * The [testBody] is executed and an [AssertionError] is thrown if the list of unhandled exceptions is not empty and
 * contains any exception that is not a [CancellationException]. The first such exception is attached as the
 * [AssertionError]'s cause, so its stack trace is preserved in the test report.
 *
 * If the [testBody] successfully executes one of the [TestCoroutineContext.assertAllUnhandledExceptions],
 * [TestCoroutineContext.assertAnyUnhandledException], [TestCoroutineContext.assertUnhandledException] or
 * [TestCoroutineContext.assertExceptions], the list of unhandled exceptions will have been cleared and this method will
 * not throw an [AssertionError].
 *
 * @param testContext The provided [TestCoroutineContext]. If not specified, a default [TestCoroutineContext] will be
 * provided instead.
 * @param testBody The code of the unit-test.
 */
public fun withTestContext(testContext: TestCoroutineContext = TestCoroutineContext(), testBody: TestCoroutineContext.() -> Unit) {
    with(testContext) {
        testBody()

        // Cancellations are expected during tear-down; anything else is a test failure.
        val unexpected = exceptions.firstOrNull { it !is CancellationException }
        if (unexpected != null) {
            // initCause keeps the original stack trace without relying on the Java 7+ (String, Throwable) constructor.
            throw AssertionError("Coroutine encountered unhandled exceptions:\n$exceptions").apply {
                initCause(unexpected)
            }
        }
    }
}