IO: improve lookAhead functionality
Also rename the old lookAhead to consumeEachBufferRange
diff --git a/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt
index 8f59da9..2838f9e 100644
--- a/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt
+++ b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt
@@ -16,7 +16,7 @@
override val autoFlush: Boolean,
private val pool: ObjectPool<ReadWriteBufferState.Initial> = BufferObjectPool,
private val reservedSize: Int = RESERVED_SIZE
-) : ByteChannel {
+) : ByteChannel, LookAheadSuspendSession {
// internal constructor for reading of byte buffers
constructor(content: ByteBuffer) : this(false, BufferObjectNoPool, 0) {
state = ReadWriteBufferState.Initial(content.slice(), 0).apply {
@@ -941,12 +941,85 @@
* Never invokes [visitor] with empty buffer unless [last] = true. Invokes visitor with last = true at most once
* even if there are remaining bytes and visitor returned true.
*/
- override suspend fun lookAhead(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean) {
- if (lookAheadFast(false, visitor)) return
- lookAheadSuspend(visitor)
+    override suspend fun consumeEachBufferRange(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean) {
+        // Run the non-suspending fast path first; only when it reports false fall back to the suspending loop.
+        if (!consumeEachBufferRangeFast(false, visitor)) consumeEachBufferRangeSuspend(visitor)
}
- private inline fun lookAheadFast(last: Boolean, visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean): Boolean {
+    @Suppress("UNCHECKED_CAST") // R may itself be nullable, so the result is cast rather than asserted with !!
+    override fun <R> lookAhead(visitor: LookAheadSession.() -> R): R {
+        if (state === ReadWriteBufferState.Terminated) {
+            return visitor(TerminatedLookAhead)
+        }
+        var result: R? = null
+        val rc = reading {
+            result = visitor(this@ByteBufferChannel)
+            true
+        }
+        // NOTE(review): the block always yields true, so rc == false presumably means `reading` never ran it.
+        if (!rc) {
+            return visitor(TerminatedLookAhead)
+        }
+        // A visitor may legitimately return null when R is nullable; `result!!` would turn that into an NPE.
+        return result as R
+    }
+
+    @Suppress("UNCHECKED_CAST") // R may itself be nullable, so the result is cast rather than asserted with !!
+    suspend override fun <R> lookAheadSuspend(visitor: suspend LookAheadSuspendSession.() -> R): R {
+        if (state === ReadWriteBufferState.Terminated) {
+            return visitor(TerminatedLookAhead)
+        }
+
+        var result: R? = null
+        val rc = reading {
+            result = visitor(this@ByteBufferChannel)
+            true
+        }
+        if (!rc) {
+            if (closed != null || state === ReadWriteBufferState.Terminated) return visitor(TerminatedLookAhead)
+            try {
+                result = visitor(this)
+            } finally {
+                // awaitAtLeast may have set up a read state: release it even when the visitor throws
+                if (!state.idle) { restoreStateAfterRead(); tryTerminate() }
+            }
+        }
+        return result as R // not `!!`: null is a valid visitor result when R is nullable
+    }
+
+    override fun consumed(n: Int) {
+        require(n >= 0) { "n shouldn't be negative: $n" }
+
+        // Read `state` once so the capacity claim and the read-buffer update act on the same state object.
+        val s = state
+        check(s.capacity.tryReadExact(n)) { "Unable to consume $n bytes: not enough available bytes" }
+        s.readBuffer.bytesRead(s.capacity, n)
+    }
+
+    suspend override fun awaitAtLeast(n: Int) {
+        val awaited = readSuspend(n)
+        // request() yields null while the state is idle, so enter a read state once readSuspend reports success.
+        if (awaited && state.idle) setupStateForRead()
+    }
+
+    override fun request(skip: Int, atLeast: Int): ByteBuffer? {
+        val s = state
+        val available = s.capacity.availableForRead
+        val rp = readPosition
+
+        // Too few readable bytes to cover the skipped prefix plus the requested length.
+        if (available < atLeast + skip) return null
+        // A view is only handed out while the channel is in a Reading/ReadingWriting state.
+        if (s.idle || !(s is ReadWriteBufferState.Reading || s is ReadWriteBufferState.ReadingWriting)) return null
+
+        val buffer = s.readBuffer
+        val start = buffer.carryIndex(rp + skip)
+        buffer.prepareBuffer(readByteOrder, start, available - skip)
+        // The prepared view may be shorter than requested (e.g. buffer fragmentation); reject it below atLeast.
+        return if (buffer.remaining() >= atLeast) buffer else null
+    }
+
+ private inline fun consumeEachBufferRangeFast(last: Boolean, visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean): Boolean {
if (state === ReadWriteBufferState.Terminated && !last) return false
val rc = reading {
@@ -974,11 +1047,11 @@
return rc
}
- private suspend fun lookAheadSuspend(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean): Boolean {
+ private suspend fun consumeEachBufferRangeSuspend(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean): Boolean {
var last = false
do {
- if (lookAheadFast(last, visitor)) return true
+ if (consumeEachBufferRangeFast(last, visitor)) return true
if (last) return false
if (!readSuspend(1)) {
last = true
@@ -1006,7 +1079,7 @@
var unicodeStarted = false
var eol = false
- lookAheadFast(false) { buffer, last ->
+ consumeEachBufferRangeFast(false) { buffer, last ->
var forceConsume = false
val rejected = !buffer.decodeASCII { ch ->
@@ -1058,7 +1131,7 @@
var consumed1 = 0
var eol = false
- lookAheadFast(false) { buffer, last ->
+ consumeEachBufferRangeFast(false) { buffer, last ->
var forceConsume = false
val rc = buffer.decodeUTF8 { ch ->
@@ -1108,7 +1181,7 @@
var eol = false
var wrap = 0
- lookAheadSuspend { buffer, last ->
+ consumeEachBufferRangeSuspend { buffer, last ->
var forceConsume = false
val rc = buffer.decodeUTF8 { ch ->
@@ -1298,6 +1371,17 @@
private val Closed = updater(ByteBufferChannel::closed)
}
+    /** Stand-in session given to look-ahead visitors when the channel is terminated or no read state is available. */
+    private object TerminatedLookAhead : LookAheadSuspendSession {
+        override fun consumed(n: Int) {
+            check(n <= 0) { "Unable to mark $n bytes consumed for already terminated channel" }
+        }
+        // There is never a buffer to hand out.
+        override fun request(skip: Int, atLeast: Int) = null
+        // Returns immediately: this stub never waits.
+        suspend override fun awaitAtLeast(n: Int) = Unit
+    }
+
private class ClosedElement(val cause: Throwable?) {
val sendException: Throwable
get() = cause ?: ClosedWriteChannelException("The channel was closed")
diff --git a/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteReadChannel.kt b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteReadChannel.kt
index 0759bae..299bd36 100644
--- a/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteReadChannel.kt
+++ b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteReadChannel.kt
@@ -22,6 +22,8 @@
*/
public val isClosedForRead: Boolean
+ public val isClosedForWrite: Boolean
+
/**
* Byte order that is used for multi-byte read operations
* (such as [readShort], [readInt], [readLong], [readFloat], and [readDouble]).
@@ -92,7 +94,13 @@
*/
suspend fun readFloat(): Float
- suspend fun lookAhead(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean)
+    /**
+     * For every available bytes range invokes [visitor] function until it returns false or the end of stream is encountered
+     */
+    suspend fun consumeEachBufferRange(visitor: (buffer: ByteBuffer, last: Boolean) -> Boolean)
+    /** Runs [visitor] against a [LookAheadSession] over this channel; bytes are marked read via [LookAheadSession.consumed]. */
+    fun <R> lookAhead(visitor: LookAheadSession.() -> R): R
+    suspend fun <R> lookAheadSuspend(visitor: suspend LookAheadSuspendSession.() -> R): R // like lookAhead, but the visitor may suspend
/**
* Reads a line of UTF-8 characters to the specified [out] buffer up to [limit] characters.
diff --git a/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/LookAheadSession.kt b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/LookAheadSession.kt
new file mode 100644
index 0000000..045c096
--- /dev/null
+++ b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/LookAheadSession.kt
@@ -0,0 +1,41 @@
+package kotlinx.coroutines.experimental.io
+import java.nio.ByteBuffer
+interface LookAheadSession {
+    /**
+     * Marks [n] bytes as consumed (read) so the corresponding range becomes available for writing again
+     */
+    fun consumed(n: Int)
+
+    /**
+     * Requests a byte buffer for the range that starts [skip] bytes after the read position and is at least [atLeast] bytes long
+     * @return byte buffer for the requested range or null if it is impossible to provide such a buffer
+     *
+     * There are the following reasons for this function to return `null`:
+     * - not enough bytes available yet (should be at least `skip + atLeast` bytes available)
+     * - due to buffer fragmentation it is impossible to represent the requested range as a single byte buffer
+     * - end of stream encountered and all bytes were consumed
+     * - channel has been closed with an exception so buffer has been recycled
+     */
+    fun request(skip: Int, atLeast: Int): ByteBuffer?
+}
+
+interface LookAheadSuspendSession : LookAheadSession {
+    /**
+     * Suspends until at least [n] bytes become available or the end of stream is encountered (possibly due to an exceptional close)
+     */
+    suspend fun awaitAtLeast(n: Int)
+}
+
+/**
+ * Feeds each range returned by `request(0, 1)` to [visitor] and marks it fully consumed, until none is available or [visitor] returns false.
+ */
+inline fun LookAheadSession.consumeEachRemaining(visitor: (ByteBuffer) -> Boolean) {
+    while (true) {
+        val buffer = request(0, 1) ?: break
+        val size = buffer.remaining() // measured before the visitor can move the buffer position
+        val proceed = visitor(buffer)
+        consumed(size)
+        if (!proceed) break
+    }
+}
+
diff --git a/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/internal/Utils.kt b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/internal/Utils.kt
index a5be826..6bc92c5 100644
--- a/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/internal/Utils.kt
+++ b/core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/internal/Utils.kt
@@ -1,10 +1,8 @@
package kotlinx.coroutines.experimental.io.internal
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicIntegerFieldUpdater
-import java.util.concurrent.atomic.AtomicLongFieldUpdater
-import java.util.concurrent.atomic.AtomicReferenceFieldUpdater
-import kotlin.reflect.KProperty1
+import java.nio.*
+import java.util.concurrent.atomic.*
+import kotlin.reflect.*
internal fun ByteBuffer.isEmpty() = !hasRemaining()
@@ -24,3 +22,64 @@
try { System.getProperty("kotlinx.coroutines.io.$name") }
catch (e: SecurityException) { null }
?.toIntOrNull() ?: default
+
+
+/** Index (relative to position()) where [sub] starts in this buffer, counting a prefix of [sub] cut off by limit(); -1 if absent. */
+@Suppress("LoopToCallChain")
+internal fun ByteBuffer.indexOfPartial(sub: ByteBuffer): Int {
+    val subPosition = sub.position()
+    val subSize = sub.remaining()
+    if (subSize == 0) return 0 // empty pattern matches at the start, like String.indexOf(""); sub[subPosition] would throw
+    val first = sub[subPosition]
+    val limit = limit()
+    outer@for (idx in position() until limit) {
+        if (get(idx) == first) {
+            for (j in 1 until subSize) {
+                if (idx + j == limit) break // ran off the end of this buffer: report the partial match
+                if (get(idx + j) != sub.get(subPosition + j)) continue@outer
+            }
+            return idx - position()
+        }
+    }
+    return -1
+}
+
+/** Compares the overlapping bytes of this buffer and [prefix] (after [prefixSkip]); false when there is no overlap. */
+@Suppress("LoopToCallChain")
+internal fun ByteBuffer.startsWith(prefix: ByteBuffer, prefixSkip: Int = 0): Boolean {
+    val start = position()
+    val prefixStart = prefix.position() + prefixSkip
+    val overlap = minOf(remaining(), prefix.remaining() - prefixSkip)
+
+    // Nothing to compare (empty buffer, or the skip covers the whole prefix): not a match.
+    if (overlap <= 0) return false
+
+    return (0 until overlap).all { offset ->
+        get(start + offset) == prefix.get(prefixStart + offset)
+    }
+}
+
+/** Copies min(remaining(), src.remaining(), [n]) bytes from [src] into this buffer and returns that count (never negative). */
+internal fun ByteBuffer.putAtMost(src: ByteBuffer, n: Int = src.remaining()): Int {
+    val rem = remaining()
+    val srcRem = src.remaining()
+
+    return when {
+        srcRem <= rem && srcRem <= n -> {
+            put(src) // everything fits: a single bulk transfer
+            srcRem
+        }
+        else -> {
+            // A negative n (e.g. computed by putLimited) means "copy nothing", not a negative count.
+            val size = minOf(rem, srcRem, n).coerceAtLeast(0)
+            repeat(size) { put(src.get()) }
+            size
+        }
+    }
+}
+
+// NOTE(review): the budget subtracts src.position() although the default [limit] is this buffer's limit — confirm which buffer [limit] refers to.
+internal fun ByteBuffer.putLimited(src: ByteBuffer, limit: Int = limit()): Int =
+    putAtMost(src, limit - src.position())
+/** Wraps `[offset, offset + length)` of this array; [length] defaults to the rest of the array after [offset]. */
+internal fun ByteArray.asByteBuffer(offset: Int = 0, length: Int = size - offset): ByteBuffer = ByteBuffer.wrap(this, offset, length)
\ No newline at end of file