blob: de47dfc68c9fdc7d713ad7e162b997e5f216eab0 [file] [log] [blame]
/*
* Copyright 2016-2017 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kotlinx.coroutines.experimental.nio
import kotlinx.coroutines.experimental.join
import kotlinx.coroutines.experimental.launch
import kotlinx.coroutines.experimental.runBlocking
import org.apache.commons.io.FileUtils
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousFileChannel
import java.nio.channels.AsynchronousServerSocketChannel
import java.nio.channels.AsynchronousSocketChannel
import java.nio.file.StandardOpenOption
import org.junit.Assert.*
/**
 * Tests that drive the suspending `aRead` / `aWrite` / `aAccept` / `aConnect` channel calls
 * against real file channels and loopback socket channels.
 *
 * Every channel opened by a test is scoped with `use`, so it is closed whether the test body
 * completes normally or throws (for example on a failed assertion).
 */
class AsyncIOTest {
    /** JUnit rule that supplies the temporary files used by [testFileChannels]. */
    @Rule
    @JvmField
    val tmpDir = TemporaryFolder()

    /**
     * Copies a generated input file into an output file through a 1024-byte buffer using
     * positional `aRead` / `aWrite` calls, then asserts that both files have equal content.
     */
    @Test
    fun testFileChannels() {
        val inputFile = tmpDir.newFile()
        val outputFile = tmpDir.newFile()
        // Input content: the decimal numbers 1..100000 concatenated without separators.
        FileUtils.writeStringToFile(
            inputFile,
            (1..100000).map(Int::toString).joinToString(""))

        AsynchronousFileChannel.open(inputFile.toPath()).use { input ->
            AsynchronousFileChannel.open(
                outputFile.toPath(),
                StandardOpenOption.CREATE, StandardOpenOption.WRITE).use { output ->
                val buf = ByteBuffer.allocate(1024)
                runBlocking {
                    var totalBytesRead = 0L
                    var totalBytesWritten = 0L
                    while (totalBytesRead < input.size()) {
                        // Fill the buffer until it is full or the whole input has been read.
                        while (buf.hasRemaining() && totalBytesRead < input.size()) {
                            // async read
                            totalBytesRead += input.aRead(buf, totalBytesRead)
                        }
                        buf.flip()
                        // Drain the buffer completely; one write may not consume all of it.
                        while (buf.hasRemaining()) {
                            // async write
                            totalBytesWritten += output.aWrite(buf, totalBytesWritten)
                        }
                        buf.clear()
                    }
                }
            }
        }

        // Both channels are closed at this point, before the files are compared.
        assertTrue(FileUtils.contentEquals(inputFile, outputFile))
    }

    /**
     * Runs a server coroutine and a client coroutine over a loopback socket: the client sends
     * "OK", the server asserts it and replies "123", and the client asserts the reply.
     */
    @Test
    fun testNetworkChannels() = runBlocking {
        AsynchronousServerSocketChannel
            .open()
            // Port 0 requests an ephemeral port; the actual port is read back below.
            .bind(InetSocketAddress(0))
            .use { serverChannel ->
                val serverPort = (serverChannel.localAddress as InetSocketAddress).port

                // Server side: accept one connection, check the request, send the reply.
                val c1 = launch(context) {
                    serverChannel.aAccept().use { client ->
                        val buffer = ByteBuffer.allocate(2)
                        client.aRead(buffer)
                        buffer.flip()
                        assertEquals("OK", Charsets.UTF_8.decode(buffer).toString())
                        client.aWrite(Charsets.UTF_8.encode("123"))
                    }
                }

                // Client side: connect, send the request, check the reply.
                val c2 = launch(context) {
                    AsynchronousSocketChannel.open().use { connection ->
                        // async calls
                        connection.aConnect(InetSocketAddress("127.0.0.1", serverPort))
                        connection.aWrite(Charsets.UTF_8.encode("OK"))
                        val buffer = ByteBuffer.allocate(3)
                        // async call
                        connection.aRead(buffer)
                        buffer.flip()
                        assertEquals("123", Charsets.UTF_8.decode(buffer).toString())
                    }
                }

                // The server channel stays open until both coroutines have finished.
                c1.join()
                c2.join()
            }
    }
}