blob: 08d2b8d156f5392e94db9e0e86bab5c063a78cd7 [file] [log] [blame]
/*
* Copyright 2015, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.benchmarks.qps;
import static grpc.testing.Qpstest.SimpleRequest;
import static grpc.testing.Qpstest.SimpleResponse;
import static grpc.testing.TestServiceGrpc.TestServiceStub;
import static io.grpc.testing.integration.Util.loadCert;
import static java.lang.Math.max;
import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
import grpc.testing.Qpstest.PayloadType;
import grpc.testing.TestServiceGrpc;
import io.grpc.Channel;
import io.grpc.ChannelImpl;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.grpc.transport.netty.NegotiationType;
import io.grpc.transport.netty.NettyChannelBuilder;
import io.grpc.transport.okhttp.OkHttpChannelBuilder;
import io.netty.handler.ssl.SslContext;
import org.HdrHistogram.Histogram;
import org.HdrHistogram.HistogramIterationValue;
import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.IllegalFormatException;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
/**
* Runs lots of RPCs against a QPS Server to test for throughput and latency.
* It's a Java clone of the C++ version at
* https://github.com/grpc/grpc/blob/master/test/cpp/qps/client.cc
*/
public class QpsClient {
// Class logger. Not referenced by any other code visible in this file.
private static final Logger log = Logger.getLogger(QpsClient.class.getName());
// Upper bound handed to every Histogram created in this class. Recorded values are
// System.nanoTime() differences, so this allows latencies between 1 ns and 1 min
// (60 billion ns).
private static final long HISTOGRAM_MAX_VALUE = 60000000000L;
// Second constructor argument of every Histogram created in this class
// (presumably HdrHistogram's number of significant value digits - confirm against its API).
private static final int HISTOGRAM_PRECISION = 3;
// Number of channels opened by run(); set by --client_channels, never less than 1.
private int clientChannels = 4;
// Number of call chains started by run(); set by --concurrent_calls, never less than 1.
private int concurrentCalls = 4;
// Response size requested from the server and checked against every reply;
// set by --payload_size, never less than 0.
private int payloadSize = 1;
// Server host name or address; set by --server_host.
private String serverHost = "127.0.0.1";
// Server port; parseArgs() fails unless --server_port was given.
private int serverPort;
// When true, newChannel() builds OkHttp channels instead of Netty ones; set by --okhttp.
private boolean okhttp;
// When true, Netty channels negotiate TLS; newChannel() rejects it in combination with okhttp.
private boolean enableTls;
// When true (and TLS is enabled), newChannel() trusts the bundled "ca.pem" test certificate.
private boolean useTestCa;
// Length of the measured run, in seconds; set by --duration.
private int duration = 60;
// Length of the warm-up that precedes the measured run, in seconds; set by --warmup_duration.
private int warmupDuration = 10;
/**
 * Program entry point: creates a client and hands it the command-line arguments.
 *
 * @param args command-line arguments, see {@code printUsage()} for the accepted options
 * @throws Exception whatever {@code run} throws
 */
public static void main(String... args) throws Exception {
  QpsClient client = new QpsClient();
  client.run(args);
}
/**
 * Equivalent of "main", but non-static.
 *
 * <p>Parses the arguments, opens {@code clientChannels} channels, runs a warm-up on the first
 * channel, then starts {@code concurrentCalls} call chains spread over the channels
 * ({@code i % clientChannels}) and waits for all of them. Finally the per-chain histograms are
 * merged and the statistics are printed.
 *
 * <p>The channels are shut down in a {@code finally} block, so they are released even when
 * channel creation, the warm-up or one of the call chains fails.
 *
 * @param args command-line arguments; if they cannot be parsed (or {@code --help} was given)
 *     the method returns without doing any work
 * @throws Exception if a channel cannot be created or a call chain fails or is cancelled
 */
public void run(String[] args) throws Exception {
  if (!parseArgs(args)) {
    return;
  }
  SimpleRequest req = SimpleRequest.newBuilder()
      .setResponseType(PayloadType.COMPRESSABLE)
      .setResponseSize(payloadSize)
      .build();
  List<Channel> channels = new ArrayList<Channel>(clientChannels);
  try {
    for (int i = 0; i < clientChannels; i++) {
      channels.add(newChannel());
    }
    warmup(req, channels.get(0));
    final long startTime = System.nanoTime();
    final long endTime = startTime + TimeUnit.SECONDS.toNanos(duration);
    // Initiate the concurrent calls
    List<Future<Histogram>> futures = new ArrayList<Future<Histogram>>(concurrentCalls);
    for (int i = 0; i < concurrentCalls; i++) {
      Channel channel = channels.get(i % clientChannels);
      futures.add(doRpcs(channel, req, endTime));
    }
    // Wait for completion
    List<Histogram> histograms = new ArrayList<Histogram>(futures.size());
    for (Future<Histogram> future : futures) {
      histograms.add(future.get());
    }
    long elapsedTime = System.nanoTime() - startTime;
    printStats(merge(histograms), elapsedTime);
  } finally {
    // Release every channel that was opened, whether or not the benchmark succeeded.
    shutdown(channels);
  }
}
/**
 * Shuts down every channel in the list, in list order.
 *
 * @param channels channels to shut down; each one is cast to {@code ChannelImpl}
 */
private void shutdown(List<Channel> channels) {
  for (int i = 0; i < channels.size(); i++) {
    ChannelImpl impl = (ChannelImpl) channels.get(i);
    impl.shutdown();
  }
}
/**
 * Runs a single call chain on the given channel for {@code warmupDuration} seconds and blocks
 * until it has finished. The histogram it produces is discarded.
 *
 * @param req request sent on every warm-up call
 * @param ch channel used for the warm-up
 * @throws Exception if waiting for the call chain fails or the chain was cancelled
 */
private void warmup(SimpleRequest req, Channel ch) throws Exception {
  long warmupNanos = TimeUnit.SECONDS.toNanos(warmupDuration);
  long deadline = System.nanoTime() + warmupNanos;
  Future<Histogram> warmupRun = doRpcs(ch, req, deadline);
  warmupRun.get();
}
/**
 * Builds a channel to {@code serverHost:serverPort} using the configured transport.
 *
 * <p>With {@code okhttp} set an OkHttp channel is returned; combining it with TLS is rejected.
 * Otherwise a Netty channel is built, using TLS or plaintext negotiation depending on
 * {@code enableTls}. When both {@code enableTls} and {@code useTestCa} are set, the address is
 * relabelled with the host name of the test certificate and "ca.pem" is used as trust root.
 *
 * @return the newly built channel
 * @throws IOException if the host cannot be resolved or the certificate cannot be loaded
 * @throws IllegalStateException if both okhttp and TLS were requested
 */
private Channel newChannel() throws IOException {
  if (okhttp && enableTls) {
    throw new IllegalStateException("TLS unsupported with okhttp");
  }
  if (okhttp) {
    return OkHttpChannelBuilder.forAddress(serverHost, serverPort).build();
  }

  InetAddress address = InetAddress.getByName(serverHost);
  NegotiationType negotiationType;
  if (enableTls) {
    negotiationType = NegotiationType.TLS;
  } else {
    negotiationType = NegotiationType.PLAINTEXT;
  }

  SslContext context = null;
  boolean trustTestCa = enableTls && useTestCa;
  if (trustTestCa) {
    // Force the hostname to match the cert the server uses.
    address = InetAddress.getByAddress("foo.test.google.fr", address.getAddress());
    File cert = loadCert("ca.pem");
    context = SslContext.newClientContext(cert);
  }

  InetSocketAddress socketAddress = new InetSocketAddress(address, serverPort);
  return NettyChannelBuilder.forAddress(socketAddress)
      .negotiationType(negotiationType)
      .sslContext(context)
      .build();
}
/**
 * Parses the command-line arguments into the fields of this client.
 *
 * <p>Every argument must have the form {@code --key} or {@code --key=value}. Numeric options
 * are clamped to their minimum ({@code client_channels} and {@code concurrent_calls} to 1,
 * {@code payload_size} to 0). Unrecognized keys are reported on stderr and otherwise ignored.
 *
 * <p>The boolean switches ({@code enable_tls}, {@code use_testca}, {@code okhttp}) are enabled
 * by their presence; an explicit {@code =false} now disables them instead of being silently
 * treated as "enabled".
 *
 * @param args raw command-line arguments
 * @return {@code true} if the benchmark should run; {@code false} if help was requested, an
 *     argument was malformed or {@code --server_port} is missing (usage is printed in all of
 *     these cases)
 */
private boolean parseArgs(String[] args) {
  try {
    boolean hasServerPort = false;
    for (String arg : args) {
      if (!arg.startsWith("--")) {
        System.err.println("All arguments must start with '--': " + arg);
        printUsage();
        return false;
      }
      // Split "key=value" once; a bare "--key" yields an empty value.
      String[] pair = arg.substring(2).split("=", 2);
      String key = pair[0];
      String value = "";
      if (pair.length == 2) {
        value = pair[1];
      }
      if ("help".equals(key)) {
        printUsage();
        return false;
      } else if ("server_port".equals(key)) {
        serverPort = Integer.parseInt(value);
        hasServerPort = true;
      } else if ("server_host".equals(key)) {
        serverHost = value;
      } else if ("client_channels".equals(key)) {
        clientChannels = max(Integer.parseInt(value), 1);
      } else if ("concurrent_calls".equals(key)) {
        concurrentCalls = max(Integer.parseInt(value), 1);
      } else if ("payload_size".equals(key)) {
        payloadSize = max(Integer.parseInt(value), 0);
      } else if ("enable_tls".equals(key)) {
        enableTls = parseFlag(value);
      } else if ("use_testca".equals(key)) {
        useTestCa = parseFlag(value);
      } else if ("okhttp".equals(key)) {
        okhttp = parseFlag(value);
      } else if ("duration".equals(key)) {
        duration = parseDuration(value);
      } else if ("warmup_duration".equals(key)) {
        warmupDuration = parseDuration(value);
      } else {
        System.err.println("Unrecognized argument '" + key + "'.");
      }
    }
    if (!hasServerPort) {
      System.err.println("'--server_port' was not specified.");
      printUsage();
      return false;
    }
  } catch (Exception e) {
    // Malformed numbers and durations end up here.
    e.printStackTrace();
    printUsage();
    return false;
  }
  return true;
}

/**
 * Interprets the optional value of a boolean switch: only an explicit "false" (in any case)
 * turns the switch off; a missing value or anything else keeps it on.
 */
private static boolean parseFlag(String value) {
  return !"false".equalsIgnoreCase(value);
}
/**
 * Converts a duration such as {@code "30s"} or {@code "2m"} into seconds.
 *
 * @param value a non-negative integer immediately followed by the unit {@code s} (seconds) or
 *     {@code m} (minutes)
 * @return the duration in seconds
 * @throws IllegalArgumentException if the value is missing or too short, the number cannot be
 *     parsed ({@link NumberFormatException}), the number is negative, the unit is unknown, or
 *     the result does not fit into an {@code int}
 */
private int parseDuration(String value) {
  if (value == null || value.length() < 2) {
    throw new IllegalArgumentException("value must be a number followed by a unit.");
  }
  int unitIndex = value.length() - 1;
  char unit = value.charAt(unitIndex);
  int amount = Integer.parseInt(value.substring(0, unitIndex));
  if (amount < 0) {
    throw new IllegalArgumentException("Duration must not be negative: " + value);
  }
  // Compute in long so that a large minute count cannot wrap around silently.
  long seconds;
  if (unit == 's') {
    seconds = amount;
  } else if (unit == 'm') {
    seconds = amount * 60L;
  } else {
    throw new IllegalArgumentException("Unknown unit " + unit);
  }
  if (seconds > Integer.MAX_VALUE) {
    throw new IllegalArgumentException("Duration is too large: " + value);
  }
  return (int) seconds;
}
/**
 * Prints the list of supported command-line options to stdout. The default values shown are
 * read from a freshly constructed client, so they do not reflect arguments parsed so far.
 */
private void printUsage() {
  QpsClient defaults = new QpsClient();
  StringBuilder usage = new StringBuilder();
  usage.append("Usage: [ARGS...]");
  usage.append("\n");
  usage.append("\n --server_port=INT Port of the server. Required. No default.");
  usage.append("\n --server_host=STR Hostname of the server. Default ")
      .append(defaults.serverHost);
  usage.append("\n --client_channels=INT Number of client channels. Default ")
      .append(defaults.clientChannels);
  usage.append("\n --concurrent_calls=INT Number of concurrent calls. Default ")
      .append(defaults.concurrentCalls);
  usage.append("\n --payload_size=INT Payload size in bytes. Default ")
      .append(defaults.payloadSize);
  usage.append("\n --enable_tls Enable TLS. Default disabled.");
  usage.append("\n --use_testca Use the provided test certificate for TLS.");
  usage.append("\n --okhttp Use OkHttp as the transport. Default netty");
  usage.append("\n --duration=TIME Duration of the benchmark in either seconds or minutes.");
  usage.append("\n For N seconds duration specify Ns and for minutes Nm. ");
  usage.append("\n Default ").append(defaults.duration).append("s.");
  usage.append("\n --warmup_duration=TIME How long to run the warmup.");
  usage.append("\n Default ").append(defaults.warmupDuration).append("s.");
  System.out.println(usage.toString());
}
/**
 * Starts a chain of unary calls on the given channel: each time a call completes its latency
 * is recorded and, as long as {@code endTime} has not passed, the next call is issued with the
 * same observer. At least one call is always made.
 *
 * <p>Latencies are clamped to {@code HISTOGRAM_MAX_VALUE} before being recorded. Without the
 * clamp a call slower than that limit would make {@code recordValue} throw inside
 * {@code onCompleted}, so neither the next call nor {@code future.done()} would be reached and
 * a caller blocked in {@code get()} would wait forever.
 *
 * @param channel channel the calls are made on
 * @param request request sent on every call
 * @param endTime {@link System#nanoTime()} value after which no further call is started
 * @return a future that yields the latency histogram once the chain has finished, or is
 *     cancelled when a call reports an error
 */
private Future<Histogram> doRpcs(Channel channel,
                                 final SimpleRequest request,
                                 final long endTime) {
  final TestServiceStub stub = TestServiceGrpc.newStub(channel);
  final Histogram histogram = new Histogram(HISTOGRAM_MAX_VALUE, HISTOGRAM_PRECISION);
  final HistogramFuture future = new HistogramFuture(histogram);
  stub.unaryCall(request, new StreamObserver<SimpleResponse>() {
    // Start time of the call currently in flight.
    long lastCall = System.nanoTime();

    @Override
    public void onValue(SimpleResponse value) {
      // Sanity-check that the server answered with the payload that was asked for.
      PayloadType type = value.getPayload().getType();
      int actualSize = value.getPayload().getBody().size();
      if (!PayloadType.COMPRESSABLE.equals(type)) {
        throw new RuntimeException("type was '" + type + "', expected '"
            + PayloadType.COMPRESSABLE + "'.");
      }
      if (payloadSize != actualSize) {
        throw new RuntimeException("size was '" + actualSize + "', expected '"
            + payloadSize + "'");
      }
    }

    @Override
    public void onError(Throwable t) {
      Status status = Status.fromThrowable(t);
      System.err.println("onError called: " + status);
      future.cancel(true);
    }

    @Override
    public void onCompleted() {
      long now = System.nanoTime();
      // Clamp so that an unusually slow call cannot exceed the histogram's range.
      histogram.recordValue(Math.min(now - lastCall, HISTOGRAM_MAX_VALUE));
      lastCall = now;
      if (endTime > now) {
        stub.unaryCall(request, this);
      } else {
        future.done();
      }
    }
  });
  return future;
}
/**
 * Combines the per-call-chain histograms into a single histogram.
 *
 * <p>Uses the library's own {@code add} operation instead of walking every bucket and
 * re-recording it, which also avoids recording zero-count entries for empty buckets.
 *
 * @param histograms histograms to combine; all are expected to have been created with
 *     {@code HISTOGRAM_MAX_VALUE} and {@code HISTOGRAM_PRECISION}, as done in this class
 * @return a new histogram holding the counts of all inputs; empty if the list is empty
 */
private Histogram merge(List<Histogram> histograms) {
  Histogram merged = new Histogram(HISTOGRAM_MAX_VALUE, HISTOGRAM_PRECISION);
  for (Histogram histogram : histograms) {
    merged.add(histogram);
  }
  return merged;
}
/**
 * Prints two comma-separated lines to stdout: a header row and a row of values. The columns
 * are the configured concurrent calls, channels and payload size, the latency at a fixed set
 * of percentiles, and the number of recorded calls per second.
 *
 * @param histogram merged latency histogram
 * @param elapsedTime duration of the measured run in nanoseconds
 */
private void printStats(Histogram histogram, long elapsedTime) {
  final double[] reported = {50, 90, 95, 99, 99.9, 99.99};

  StringBuilder titleRow = new StringBuilder("Concurrent Calls, Channels, Payload Size, ");
  StringBuilder dataRow = new StringBuilder(
      String.format("%d, %d, %d, ", concurrentCalls, clientChannels, payloadSize));

  // One column per percentile.
  for (int i = 0; i < reported.length; i++) {
    double p = reported[i];
    titleRow.append(p).append("%ile, ");
    dataRow.append(histogram.getValueAtPercentile(p)).append(", ");
  }

  // Calls per second: total count scaled by the number of nanoseconds in a second.
  long qps = (histogram.getTotalCount() * 1000000000L) / elapsedTime;
  titleRow.append("QPS");
  dataRow.append(qps);

  System.out.println(titleRow.toString());
  System.out.println(dataRow.toString());
}
/**
 * A minimal {@link Future} around a histogram that is being filled by a call chain. It is
 * completed by {@link #done()} or cancelled by {@link #cancel(boolean)}; all state is guarded
 * by the instance monitor, which is also used to wake up waiting callers.
 */
private static class HistogramFuture implements Future<Histogram> {
  private final Histogram histogram;
  private boolean canceled;
  private boolean done;

  HistogramFuture(Histogram histogram) {
    Preconditions.checkNotNull(histogram, "histogram");
    this.histogram = histogram;
  }

  /**
   * Marks the future as cancelled and wakes up waiters. This only changes the state of the
   * future; {@code mayInterruptIfRunning} is ignored.
   *
   * @return {@code true} if the future was neither done nor cancelled before
   */
  @Override
  public synchronized boolean cancel(boolean mayInterruptIfRunning) {
    if (!done && !canceled) {
      canceled = true;
      notifyAll();
      return true;
    }
    return false;
  }

  @Override
  public synchronized boolean isCancelled() {
    return canceled;
  }

  @Override
  public synchronized boolean isDone() {
    return done || canceled;
  }

  /**
   * Blocks until the future is completed or cancelled.
   *
   * @return the histogram
   * @throws CancellationException if the future was cancelled
   * @throws InterruptedException if the waiting thread is interrupted
   */
  @Override
  public synchronized Histogram get() throws InterruptedException, ExecutionException {
    while (!done && !canceled) {
      wait();
    }
    if (canceled) {
      throw new CancellationException();
    }
    return histogram;
  }

  /**
   * Blocks until the future is completed or cancelled, or the timeout elapses.
   *
   * @param timeout maximum time to wait; zero or negative means "do not wait"
   * @param unit unit of {@code timeout}
   * @return the histogram
   * @throws CancellationException if the future was cancelled
   * @throws InterruptedException if the waiting thread is interrupted
   * @throws TimeoutException if the future is still pending when the timeout elapses
   */
  @Override
  public synchronized Histogram get(long timeout, TimeUnit unit) throws InterruptedException,
      ExecutionException,
      TimeoutException {
    long remainingNanos = unit.toNanos(timeout);
    final long deadline = System.nanoTime() + remainingNanos;
    while (!done && !canceled) {
      if (remainingNanos <= 0) {
        throw new TimeoutException();
      }
      // Re-check the state after every wake-up; wake-ups may be spurious.
      TimeUnit.NANOSECONDS.timedWait(this, remainingNanos);
      remainingNanos = deadline - System.nanoTime();
    }
    if (canceled) {
      throw new CancellationException();
    }
    return histogram;
  }

  /** Marks the future as completed and wakes up waiters. */
  private synchronized void done() {
    done = true;
    notifyAll();
  }
}
}