Add debounce with selector and kotlin.time (#2336)

Co-authored-by: Miguel Kano <miguel.g.kano@gmail.com>
Co-authored-by: Vsevolod Tolstopyatov <qwwdfsad@gmail.com>
diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
index c3eddb9..06f4396 100644
--- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
+++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
@@ -942,7 +942,9 @@
 	public static final fun count (Lkotlinx/coroutines/flow/Flow;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
 	public static final fun count (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
 	public static final fun debounce (Lkotlinx/coroutines/flow/Flow;J)Lkotlinx/coroutines/flow/Flow;
+	public static final fun debounce (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/flow/Flow;
 	public static final fun debounce-8GFy2Ro (Lkotlinx/coroutines/flow/Flow;D)Lkotlinx/coroutines/flow/Flow;
+	public static final fun debounceDuration (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/flow/Flow;
 	public static final fun delayEach (Lkotlinx/coroutines/flow/Flow;J)Lkotlinx/coroutines/flow/Flow;
 	public static final fun delayFlow (Lkotlinx/coroutines/flow/Flow;J)Lkotlinx/coroutines/flow/Flow;
 	public static final fun distinctUntilChanged (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow;
diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt b/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt
index aa55fea..c95b4be 100644
--- a/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt
+++ b/kotlinx-coroutines-core/common/src/flow/operators/Delay.kt
@@ -64,39 +64,61 @@
  */
 @FlowPreview
 public fun <T> Flow<T>.debounce(timeoutMillis: Long): Flow<T> {
-    require(timeoutMillis > 0) { "Debounce timeout should be positive" }
-    return scopedFlow { downstream ->
-        // Actually Any, KT-30796
-        val values = produce<Any?>(capacity = Channel.CONFLATED) {
-            collect { value -> send(value ?: NULL) }
-        }
-        var lastValue: Any? = null
-        while (lastValue !== DONE) {
-            select<Unit> {
-                // Should be receiveOrClosed when boxing issues are fixed
-                values.onReceiveOrNull {
-                    if (it == null) {
-                        if (lastValue != null) downstream.emit(NULL.unbox(lastValue))
-                        lastValue = DONE
-                    } else {
-                        lastValue = it
-                    }
-                }
-
-                lastValue?.let { value ->
-                    // set timeout when lastValue != null
-                    onTimeout(timeoutMillis) {
-                        lastValue = null // Consume the value
-                        downstream.emit(NULL.unbox(value))
-                    }
-                }
-            }
-        }
-    }
+    require(timeoutMillis >= 0L) { "Debounce timeout should not be negative" }
+    if (timeoutMillis == 0L) return this
+    return debounceInternal { timeoutMillis }
 }
 
 /**
  * Returns a flow that mirrors the original flow, but filters out values
+ * that are followed by the newer values within the given [timeout][timeoutMillis].
+ * The latest value is always emitted.
+ *
+ * A variation of [debounce] that allows specifying the timeout value dynamically.
+ *
+ * Example:
+ *
+ * ```kotlin
+ * flow {
+ *     emit(1)
+ *     delay(90)
+ *     emit(2)
+ *     delay(90)
+ *     emit(3)
+ *     delay(1010)
+ *     emit(4)
+ *     delay(1010)
+ *     emit(5)
+ * }.debounce {
+ *     if (it == 1) {
+ *         0L
+ *     } else {
+ *         1000L
+ *     }
+ * }
+ * ```
+ * <!--- KNIT example-delay-02.kt -->
+ *
+ * produces the following emissions
+ *
+ * ```text
+ * 1, 3, 4, 5
+ * ```
+ * <!--- TEST -->
+ *
+ * Note that the resulting flow does not emit anything as long as the original flow emits
+ * items faster than every [timeoutMillis] milliseconds.
+ *
+ * @param timeoutMillis [T] is the emitted value and the return value is timeout in milliseconds.
+ */
+@FlowPreview
+@OptIn(kotlin.experimental.ExperimentalTypeInference::class)
+@OverloadResolutionByLambdaReturnType
+public fun <T> Flow<T>.debounce(timeoutMillis: (T) -> Long): Flow<T> =
+    debounceInternal(timeoutMillis)
+
+/**
+ * Returns a flow that mirrors the original flow, but filters out values
  * that are followed by the newer values within the given [timeout].
  * The latest value is always emitted.
  *
@@ -129,7 +151,104 @@
  */
 @ExperimentalTime
 @FlowPreview
-public fun <T> Flow<T>.debounce(timeout: Duration): Flow<T> = debounce(timeout.toDelayMillis())
+public fun <T> Flow<T>.debounce(timeout: Duration): Flow<T> =
+    debounce(timeout.toDelayMillis())
+
+/**
+ * Returns a flow that mirrors the original flow, but filters out values
+ * that are followed by the newer values within the given [timeout].
+ * The latest value is always emitted.
+ *
+ * A variation of [debounce] that allows specifying the timeout value dynamically.
+ *
+ * Example:
+ *
+ * ```kotlin
+ * flow {
+ *     emit(1)
+ *     delay(90.milliseconds)
+ *     emit(2)
+ *     delay(90.milliseconds)
+ *     emit(3)
+ *     delay(1010.milliseconds)
+ *     emit(4)
+ *     delay(1010.milliseconds)
+ *     emit(5)
+ * }.debounce {
+ *     if (it == 1) {
+ *         0.milliseconds
+ *     } else {
+ *         1000.milliseconds
+ *     }
+ * }
+ * ```
+ * <!--- KNIT example-delay-duration-02.kt -->
+ *
+ * produces the following emissions
+ *
+ * ```text
+ * 1, 3, 4, 5
+ * ```
+ * <!--- TEST -->
+ *
+ * Note that the resulting flow does not emit anything as long as the original flow emits
+ * items faster than every [timeout] unit.
+ *
+ * @param timeout [T] is the emitted value and the return value is timeout in [Duration].
+ */
+@ExperimentalTime
+@FlowPreview
+@JvmName("debounceDuration")
+@OptIn(kotlin.experimental.ExperimentalTypeInference::class)
+@OverloadResolutionByLambdaReturnType
+public fun <T> Flow<T>.debounce(timeout: (T) -> Duration): Flow<T> =
+    debounceInternal { emittedItem ->
+        timeout(emittedItem).toDelayMillis()
+    }
+
+private fun <T> Flow<T>.debounceInternal(timeoutMillisSelector: (T) -> Long) : Flow<T> =
+    scopedFlow { downstream ->
+        // Produce the values using the default (rendezvous) channel
+        // Note: the actual type is Any, KT-30796
+        val values = produce<Any?> {
+            collect { value -> send(value ?: NULL) }
+        }
+        // Now consume the values
+        var lastValue: Any? = null
+        while (lastValue !== DONE) {
+            var timeoutMillis = 0L // will be always computed when lastValue != null
+            // Compute timeout for this value
+            if (lastValue != null) {
+                timeoutMillis = timeoutMillisSelector(NULL.unbox(lastValue))
+                require(timeoutMillis >= 0L) { "Debounce timeout should not be negative" }
+                if (timeoutMillis == 0L) {
+                    downstream.emit(NULL.unbox(lastValue))
+                    lastValue = null // Consume the value
+                }
+            }
+            // assert invariant: lastValue != null implies timeoutMillis > 0
+            assert { lastValue == null || timeoutMillis > 0 }
+            // wait for the next value with timeout
+            select<Unit> {
+                // Set timeout when lastValue exists and is not consumed yet
+                if (lastValue != null) {
+                    onTimeout(timeoutMillis) {
+                        downstream.emit(NULL.unbox(lastValue))
+                        lastValue = null // Consume the value
+                    }
+                }
+                // Should be receiveOrClosed when boxing issues are fixed
+                values.onReceiveOrNull { value ->
+                    if (value == null) {
+                        if (lastValue != null) downstream.emit(NULL.unbox(lastValue))
+                        lastValue = DONE
+                    } else {
+                        lastValue = value
+                    }
+                }
+            }
+        }
+    }
 
 /**
  * Returns a flow that emits only the latest value emitted by the original flow during the given sampling [period][periodMillis].
@@ -144,7 +263,7 @@
  *     }
  * }.sample(200)
  * ```
- * <!--- KNIT example-delay-02.kt -->
+ * <!--- KNIT example-delay-03.kt -->
  *
  * produces the following emissions
  *
@@ -152,7 +271,7 @@
  * 1, 3, 5, 7, 9
  * ```
  * <!--- TEST -->
- * 
+ *
  * Note that the latest element is not emitted if it does not fit into the sampling window.
  */
 @FlowPreview
@@ -215,7 +334,7 @@
  *     }
  * }.sample(200.milliseconds)
  * ```
- * <!--- KNIT example-delay-duration-02.kt -->
+ * <!--- KNIT example-delay-duration-03.kt -->
  *
  * produces the following emissions
  *
diff --git a/kotlinx-coroutines-core/common/test/flow/operators/DebounceTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/DebounceTest.kt
index 4065671..ce75e59 100644
--- a/kotlinx-coroutines-core/common/test/flow/operators/DebounceTest.kt
+++ b/kotlinx-coroutines-core/common/test/flow/operators/DebounceTest.kt
@@ -11,7 +11,7 @@
 
 class DebounceTest : TestBase() {
     @Test
-    public fun testBasic() = withVirtualTime {
+    fun testBasic() = withVirtualTime {
         expect(1)
         val flow = flow {
             expect(3)
@@ -159,7 +159,7 @@
             expect(2)
             throw TestException()
         }.flowOn(NamedDispatchers("source")).debounce(Long.MAX_VALUE).map {
-                expectUnreached()
+            expectUnreached()
         }
         assertFailsWith<TestException>(flow)
         finish(3)
@@ -175,7 +175,6 @@
             expect(2)
             yield()
             throw TestException()
-            it
         }
 
         assertFailsWith<TestException>(flow)
@@ -193,7 +192,6 @@
             expect(2)
             yield()
             throw TestException()
-            it
         }
 
         assertFailsWith<TestException>(flow)
@@ -202,7 +200,7 @@
 
     @ExperimentalTime
     @Test
-    public fun testDurationBasic() = withVirtualTime {
+    fun testDurationBasic() = withVirtualTime {
         expect(1)
         val flow = flow {
             expect(3)
@@ -223,4 +221,102 @@
         assertEquals(listOf("A", "D", "E"), result)
         finish(5)
     }
+
+    @ExperimentalTime
+    @Test
+    fun testDebounceSelectorBasic() = withVirtualTime {
+        expect(1)
+        val flow = flow {
+            expect(3)
+            emit(1)
+            delay(90)
+            emit(2)
+            delay(90)
+            emit(3)
+            delay(1010)
+            emit(4)
+            delay(1010)
+            emit(5)
+            expect(4)
+        }
+
+        expect(2)
+        val result = flow.debounce {
+            if (it == 1) {
+                0
+            } else {
+                1000
+            }
+        }.toList()
+
+        assertEquals(listOf(1, 3, 4, 5), result)
+        finish(5)
+    }
+
+    @Test
+    fun testZeroDebounceTime() = withVirtualTime {
+        expect(1)
+        val flow = flow {
+            expect(3)
+            emit("A")
+            emit("B")
+            emit("C")
+            expect(4)
+        }
+
+        expect(2)
+        val result = flow.debounce(0).toList()
+
+        assertEquals(listOf("A", "B", "C"), result)
+        finish(5)
+    }
+
+    @ExperimentalTime
+    @Test
+    fun testZeroDebounceTimeSelector() = withVirtualTime {
+        expect(1)
+        val flow = flow {
+            expect(3)
+            emit("A")
+            emit("B")
+            expect(4)
+        }
+
+        expect(2)
+        val result = flow.debounce { 0 }.toList()
+
+        assertEquals(listOf("A", "B"), result)
+        finish(5)
+    }
+
+    @ExperimentalTime
+    @Test
+    fun testDebounceDurationSelectorBasic() = withVirtualTime {
+        expect(1)
+        val flow = flow {
+            expect(3)
+            emit("A")
+            delay(1500.milliseconds)
+            emit("B")
+            delay(500.milliseconds)
+            emit("C")
+            delay(250.milliseconds)
+            emit("D")
+            delay(2000.milliseconds)
+            emit("E")
+            expect(4)
+        }
+
+        expect(2)
+        val result = flow.debounce {
+            if (it == "C") {
+                0.milliseconds
+            } else {
+                1000.milliseconds
+            }
+        }.toList()
+
+        assertEquals(listOf("A", "C", "D", "E"), result)
+        finish(5)
+    }
 }
diff --git a/kotlinx-coroutines-core/jvm/test/examples/example-delay-02.kt b/kotlinx-coroutines-core/jvm/test/examples/example-delay-02.kt
index 1b6b12f..f74422e 100644
--- a/kotlinx-coroutines-core/jvm/test/examples/example-delay-02.kt
+++ b/kotlinx-coroutines-core/jvm/test/examples/example-delay-02.kt
@@ -11,9 +11,20 @@
 fun main() = runBlocking {
 
 flow {
-    repeat(10) {
-        emit(it)
-        delay(110)
+    emit(1)
+    delay(90)
+    emit(2)
+    delay(90)
+    emit(3)
+    delay(1010)
+    emit(4)
+    delay(1010)
+    emit(5)
+}.debounce {
+    if (it == 1) {
+        0L
+    } else {
+        1000L
     }
-}.sample(200)
+}
 .toList().joinToString().let { println(it) } }
diff --git a/kotlinx-coroutines-core/jvm/test/examples/example-delay-03.kt b/kotlinx-coroutines-core/jvm/test/examples/example-delay-03.kt
new file mode 100644
index 0000000..edaea74
--- /dev/null
+++ b/kotlinx-coroutines-core/jvm/test/examples/example-delay-03.kt
@@ -0,0 +1,19 @@
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+// This file was automatically generated from Delay.kt by Knit tool. Do not edit.
+package kotlinx.coroutines.examples.exampleDelay03
+
+import kotlinx.coroutines.*
+import kotlinx.coroutines.flow.*
+
+fun main() = runBlocking {
+
+flow {
+    repeat(10) {
+        emit(it)
+        delay(110)
+    }
+}.sample(200)
+.toList().joinToString().let { println(it) } }
diff --git a/kotlinx-coroutines-core/jvm/test/examples/example-delay-duration-02.kt b/kotlinx-coroutines-core/jvm/test/examples/example-delay-duration-02.kt
index e43dfd1..10ba88a 100644
--- a/kotlinx-coroutines-core/jvm/test/examples/example-delay-duration-02.kt
+++ b/kotlinx-coroutines-core/jvm/test/examples/example-delay-duration-02.kt
@@ -13,9 +13,20 @@
 fun main() = runBlocking {
 
 flow {
-    repeat(10) {
-        emit(it)
-        delay(110.milliseconds)
+    emit(1)
+    delay(90.milliseconds)
+    emit(2)
+    delay(90.milliseconds)
+    emit(3)
+    delay(1010.milliseconds)
+    emit(4)
+    delay(1010.milliseconds)
+    emit(5)
+}.debounce {
+    if (it == 1) {
+        0.milliseconds
+    } else {
+        1000.milliseconds
     }
-}.sample(200.milliseconds)
+}
 .toList().joinToString().let { println(it) } }
diff --git a/kotlinx-coroutines-core/jvm/test/examples/example-delay-duration-03.kt b/kotlinx-coroutines-core/jvm/test/examples/example-delay-duration-03.kt
new file mode 100644
index 0000000..5fa980a
--- /dev/null
+++ b/kotlinx-coroutines-core/jvm/test/examples/example-delay-duration-03.kt
@@ -0,0 +1,21 @@
+@file:OptIn(ExperimentalTime::class)
+/*
+ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+// This file was automatically generated from Delay.kt by Knit tool. Do not edit.
+package kotlinx.coroutines.examples.exampleDelayDuration03
+
+import kotlin.time.*
+import kotlinx.coroutines.*
+import kotlinx.coroutines.flow.*
+
+fun main() = runBlocking {
+
+flow {
+    repeat(10) {
+        emit(it)
+        delay(110.milliseconds)
+    }
+}.sample(200.milliseconds)
+.toList().joinToString().let { println(it) } }
diff --git a/kotlinx-coroutines-core/jvm/test/examples/test/FlowDelayTest.kt b/kotlinx-coroutines-core/jvm/test/examples/test/FlowDelayTest.kt
index 226d31c..99e72eb 100644
--- a/kotlinx-coroutines-core/jvm/test/examples/test/FlowDelayTest.kt
+++ b/kotlinx-coroutines-core/jvm/test/examples/test/FlowDelayTest.kt
@@ -17,6 +17,13 @@
     }
 
     @Test
+    fun testExampleDelay02() {
+        test("ExampleDelay02") { kotlinx.coroutines.examples.exampleDelay02.main() }.verifyLines(
+            "1, 3, 4, 5"
+        )
+    }
+
+    @Test
     fun testExampleDelayDuration01() {
         test("ExampleDelayDuration01") { kotlinx.coroutines.examples.exampleDelayDuration01.main() }.verifyLines(
             "3, 4, 5"
@@ -24,15 +31,22 @@
     }
 
     @Test
-    fun testExampleDelay02() {
-        test("ExampleDelay02") { kotlinx.coroutines.examples.exampleDelay02.main() }.verifyLines(
+    fun testExampleDelayDuration02() {
+        test("ExampleDelayDuration02") { kotlinx.coroutines.examples.exampleDelayDuration02.main() }.verifyLines(
+            "1, 3, 4, 5"
+        )
+    }
+
+    @Test
+    fun testExampleDelay03() {
+        test("ExampleDelay03") { kotlinx.coroutines.examples.exampleDelay03.main() }.verifyLines(
             "1, 3, 5, 7, 9"
         )
     }
 
     @Test
-    fun testExampleDelayDuration02() {
-        test("ExampleDelayDuration02") { kotlinx.coroutines.examples.exampleDelayDuration02.main() }.verifyLines(
+    fun testExampleDelayDuration03() {
+        test("ExampleDelayDuration03") { kotlinx.coroutines.examples.exampleDelayDuration03.main() }.verifyLines(
             "1, 3, 5, 7, 9"
         )
     }