blob: 6f4f057cadb8ad05684b9831cd1e070029db94b2 [file] [log] [blame]
/*
* Copyright 2014, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package io.grpc.transport.okhttp;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.squareup.okhttp.internal.spdy.ErrorCode;
import com.squareup.okhttp.internal.spdy.FrameReader;
import com.squareup.okhttp.internal.spdy.Header;
import com.squareup.okhttp.internal.spdy.HeadersMode;
import com.squareup.okhttp.internal.spdy.Http20Draft16;
import com.squareup.okhttp.internal.spdy.Settings;
import com.squareup.okhttp.internal.spdy.Variant;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.transport.ClientStreamListener;
import io.grpc.transport.ClientTransport;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.ByteString;
import okio.Okio;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.net.ssl.SSLSocketFactory;
/**
* A okhttp-based {@link ClientTransport} implementation.
*/
public class OkHttpClientTransport implements ClientTransport {
/** The default initial window size in HTTP/2 is 64 KiB for the stream and connection. */
@VisibleForTesting
static final int DEFAULT_INITIAL_WINDOW_SIZE = 64 * 1024;
private static final Map<ErrorCode, Status> ERROR_CODE_TO_STATUS;
private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName());
static {
Map<ErrorCode, Status> errorToStatus = new HashMap<ErrorCode, Status>();
errorToStatus.put(ErrorCode.NO_ERROR, Status.OK);
errorToStatus.put(ErrorCode.PROTOCOL_ERROR,
Status.INTERNAL.withDescription("Protocol error"));
errorToStatus.put(ErrorCode.INVALID_STREAM,
Status.INTERNAL.withDescription("Invalid stream"));
errorToStatus.put(ErrorCode.UNSUPPORTED_VERSION,
Status.INTERNAL.withDescription("Unsupported version"));
errorToStatus.put(ErrorCode.STREAM_IN_USE,
Status.INTERNAL.withDescription("Stream in use"));
errorToStatus.put(ErrorCode.STREAM_ALREADY_CLOSED,
Status.INTERNAL.withDescription("Stream already closed"));
errorToStatus.put(ErrorCode.INTERNAL_ERROR,
Status.INTERNAL.withDescription("Internal error"));
errorToStatus.put(ErrorCode.FLOW_CONTROL_ERROR,
Status.INTERNAL.withDescription("Flow control error"));
errorToStatus.put(ErrorCode.STREAM_CLOSED,
Status.INTERNAL.withDescription("Stream closed"));
errorToStatus.put(ErrorCode.FRAME_TOO_LARGE,
Status.INTERNAL.withDescription("Frame too large"));
errorToStatus.put(ErrorCode.REFUSED_STREAM,
Status.INTERNAL.withDescription("Refused stream"));
errorToStatus.put(ErrorCode.CANCEL, Status.CANCELLED.withDescription("Cancelled"));
errorToStatus.put(ErrorCode.COMPRESSION_ERROR,
Status.INTERNAL.withDescription("Compression error"));
errorToStatus.put(ErrorCode.INVALID_CREDENTIALS,
Status.PERMISSION_DENIED.withDescription("Invalid credentials"));
ERROR_CODE_TO_STATUS = Collections.unmodifiableMap(errorToStatus);
}
private final InetSocketAddress address;
private final String authorityHost;
private final String defaultAuthority;
private Listener listener;
private FrameReader frameReader;
private AsyncFrameWriter frameWriter;
private OutboundFlowController outboundFlow;
private final Object lock = new Object();
@GuardedBy("lock")
private int nextStreamId;
private final Map<Integer, OkHttpClientStream> streams =
Collections.synchronizedMap(new HashMap<Integer, OkHttpClientStream>());
private final Executor executor;
private int connectionUnacknowledgedBytesRead;
private ClientFrameHandler clientFrameHandler;
// The status used to finish all active streams when the transport is closed.
@GuardedBy("lock")
private boolean goAway;
@GuardedBy("lock")
private Status goAwayStatus;
@GuardedBy("lock")
private boolean stopped;
private SSLSocketFactory sslSocketFactory;
OkHttpClientTransport(InetSocketAddress address, String authorityHost, Executor executor,
SSLSocketFactory sslSocketFactory) {
this.address = Preconditions.checkNotNull(address);
this.authorityHost = authorityHost;
defaultAuthority = authorityHost + ":" + address.getPort();
this.executor = Preconditions.checkNotNull(executor);
// Client initiated streams are odd, server initiated ones are even. Server should not need to
// use it. We start clients at 3 to avoid conflicting with HTTP negotiation.
nextStreamId = 3;
this.sslSocketFactory = sslSocketFactory;
}
/**
* Create a transport connected to a fake peer for test.
*/
@VisibleForTesting
OkHttpClientTransport(Executor executor, FrameReader frameReader, AsyncFrameWriter frameWriter,
int nextStreamId) {
address = null;
authorityHost = null;
defaultAuthority = "notarealauthority:80";
this.executor = Preconditions.checkNotNull(executor);
this.frameReader = Preconditions.checkNotNull(frameReader);
this.frameWriter = Preconditions.checkNotNull(frameWriter);
this.outboundFlow = new OutboundFlowController(this, frameWriter);
this.nextStreamId = nextStreamId;
}
@Override
public OkHttpClientStream newStream(MethodDescriptor<?, ?> method,
Metadata.Headers headers,
ClientStreamListener listener) {
Preconditions.checkNotNull(method, "method");
Preconditions.checkNotNull(headers, "headers");
Preconditions.checkNotNull(listener, "listener");
OkHttpClientStream clientStream =
OkHttpClientStream.newStream(listener, frameWriter, this, outboundFlow);
String defaultPath = "/" + method.getName();
List<Header> requestHeaders =
Headers.createRequestHeaders(headers, defaultPath, defaultAuthority);
synchronized (lock) {
if (goAway) {
throw new IllegalStateException("Transport not running", goAwayStatus.asRuntimeException());
}
assignStreamId(clientStream);
frameWriter.synStream(false, false, clientStream.id(), 0, requestHeaders);
}
return clientStream;
}
@Override
public void start(Listener listener) {
this.listener = Preconditions.checkNotNull(listener, "listener");
// We set host to null for test.
if (address != null) {
BufferedSource source;
BufferedSink sink;
try {
Socket socket = new Socket(address.getAddress(), address.getPort());
if (sslSocketFactory != null) {
// We assume the sslSocketFactory will verify the server hostname.
socket = sslSocketFactory.createSocket(socket, authorityHost, address.getPort(), true);
}
source = Okio.buffer(Okio.source(socket));
sink = Okio.buffer(Okio.sink(socket));
} catch (IOException e) {
throw new RuntimeException(e);
}
Variant variant = new Http20Draft16();
frameReader = variant.newReader(source, true);
frameWriter = new AsyncFrameWriter(variant.newWriter(sink, true), this, executor);
outboundFlow = new OutboundFlowController(this, frameWriter);
frameWriter.connectionPreface();
Settings settings = new Settings();
frameWriter.settings(settings);
}
clientFrameHandler = new ClientFrameHandler();
executor.execute(clientFrameHandler);
}
@Override
public void shutdown() {
boolean normalClose;
synchronized (lock) {
normalClose = !goAway;
}
if (normalClose) {
// Send GOAWAY with lastGoodStreamId of 0, since we don't expect any server-initiated streams.
// The GOAWAY is part of graceful shutdown.
if (frameWriter != null) {
frameWriter.goAway(0, ErrorCode.NO_ERROR, new byte[0]);
}
onGoAway(Integer.MAX_VALUE, Status.INTERNAL.withDescription("Transport stopped"));
}
stopIfNecessary();
}
@VisibleForTesting
ClientFrameHandler getHandler() {
return clientFrameHandler;
}
Map<Integer, OkHttpClientStream> getStreams() {
return streams;
}
/**
* Finish all active streams due to a failure, then close the transport.
*/
void abort(Throwable failureCause) {
log.log(Level.SEVERE, "Transport failed", failureCause);
onGoAway(0, Status.fromThrowable(failureCause));
}
private void onGoAway(int lastKnownStreamId, Status status) {
boolean notifyShutdown;
ArrayList<OkHttpClientStream> goAwayStreams = new ArrayList<OkHttpClientStream>();
synchronized (lock) {
notifyShutdown = !goAway;
goAway = true;
goAwayStatus = status;
synchronized (streams) {
Iterator<Map.Entry<Integer, OkHttpClientStream>> it = streams.entrySet().iterator();
while (it.hasNext()) {
Map.Entry<Integer, OkHttpClientStream> entry = it.next();
if (entry.getKey() > lastKnownStreamId) {
goAwayStreams.add(entry.getValue());
it.remove();
}
}
}
}
if (notifyShutdown) {
listener.transportShutdown();
}
for (OkHttpClientStream stream : goAwayStreams) {
stream.transportReportStatus(status, false, new Metadata.Trailers());
}
stopIfNecessary();
}
/**
* Called when a stream is closed.
*
* <p> Return false if the stream has already finished.
*/
boolean finishStream(int streamId, @Nullable Status status) {
OkHttpClientStream stream;
stream = streams.remove(streamId);
if (stream != null) {
if (status != null) {
boolean isCancelled = status.getCode() == Code.CANCELLED;
stream.transportReportStatus(status, isCancelled, new Metadata.Trailers());
}
return true;
}
return false;
}
/**
* When the transport is in goAway states, we should stop it once all active streams finish.
*/
void stopIfNecessary() {
boolean shouldStop;
synchronized (lock) {
shouldStop = (goAway && streams.size() == 0);
if (shouldStop) {
if (stopped) {
// We've already stopped, don't stop again.
shouldStop = false;
}
stopped = true;
}
}
if (shouldStop) {
// Wait for the frame writer to close.
if (frameWriter != null) {
frameWriter.close();
try {
frameReader.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
listener.transportTerminated();
}
}
/**
* Returns a Grpc status corresponding to the given ErrorCode.
*/
@VisibleForTesting
static Status toGrpcStatus(ErrorCode code) {
return ERROR_CODE_TO_STATUS.get(code);
}
/**
* Runnable which reads frames and dispatches them to in flight calls.
*/
@VisibleForTesting
class ClientFrameHandler implements FrameReader.Handler, Runnable {
ClientFrameHandler() {}
@Override
public void run() {
String threadName = Thread.currentThread().getName();
Thread.currentThread().setName("OkHttpClientTransport");
try {
// Read until the underlying socket closes.
while (frameReader.nextFrame(this)) {
}
} catch (IOException ioe) {
abort(ioe);
} finally {
// Restore the original thread name.
Thread.currentThread().setName(threadName);
}
}
/**
* Handle a HTTP2 DATA frame.
*/
@Override
public void data(boolean inFinished, int streamId, BufferedSource in, int length)
throws IOException {
final OkHttpClientStream stream;
stream = streams.get(streamId);
if (stream == null) {
frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
return;
}
// Wait until the frame is complete.
in.require(length);
Buffer buf = new Buffer();
buf.write(in.buffer(), length);
stream.transportDataReceived(buf, inFinished);
// connection window update
connectionUnacknowledgedBytesRead += length;
if (connectionUnacknowledgedBytesRead >= DEFAULT_INITIAL_WINDOW_SIZE / 2) {
frameWriter.windowUpdate(0, connectionUnacknowledgedBytesRead);
connectionUnacknowledgedBytesRead = 0;
}
}
/**
* Handle HTTP2 HEADER and CONTINUATION frames.
*/
@Override
public void headers(boolean outFinished,
boolean inFinished,
int streamId,
int associatedStreamId,
List<Header> headerBlock,
HeadersMode headersMode) {
OkHttpClientStream stream;
stream = streams.get(streamId);
if (stream == null) {
frameWriter.rstStream(streamId, ErrorCode.INVALID_STREAM);
return;
}
stream.transportHeadersReceived(headerBlock, inFinished);
}
@Override
public void rstStream(int streamId, ErrorCode errorCode) {
if (finishStream(streamId, toGrpcStatus(errorCode))) {
stopIfNecessary();
}
}
@Override
public void settings(boolean clearPrevious, Settings settings) {
// not impl
try {
frameWriter.ackSettings(settings);
} catch (IOException e) {
abort(e);
}
}
@Override
public void ping(boolean ack, int payload1, int payload2) {
if (!ack) {
frameWriter.ping(true, payload1, payload2);
}
}
@Override
public void ackSettings() {
// Do nothing currently.
}
@Override
public void goAway(int lastGoodStreamId, ErrorCode errorCode, ByteString debugData) {
onGoAway(lastGoodStreamId, Status.UNAVAILABLE.withDescription("Go away"));
}
@Override
public void pushPromise(int streamId, int promisedStreamId, List<Header> requestHeaders)
throws IOException {
// We don't accept server initiated stream.
frameWriter.rstStream(streamId, ErrorCode.PROTOCOL_ERROR);
}
@Override
public void windowUpdate(int streamId, long delta) {
outboundFlow.windowUpdate(streamId, (int) delta);
}
@Override
public void priority(int streamId, int streamDependency, int weight, boolean exclusive) {
// Ignore priority change.
// TODO(madongfly): log
}
@Override
public void alternateService(int streamId, String origin, ByteString protocol, String host,
int port, long maxAge) {
// TODO(madongfly): Deal with alternateService propagation
}
}
@GuardedBy("lock")
private void assignStreamId(OkHttpClientStream stream) {
Preconditions.checkState(stream.id() == null, "StreamId already assigned");
stream.id(nextStreamId);
streams.put(stream.id(), stream);
if (nextStreamId >= Integer.MAX_VALUE - 2) {
onGoAway(Integer.MAX_VALUE, Status.INTERNAL.withDescription("Stream ids exhausted"));
} else {
nextStreamId += 2;
}
}
}