Adding default User-Agent for netty and okhttp.
diff --git a/build.gradle b/build.gradle
index 968c5eb..1861a77 100644
--- a/build.gradle
+++ b/build.gradle
@@ -28,11 +28,21 @@
mavenLocal()
}
+
[compileJava, compileTestJava].each() {
it.options.compilerArgs += ["-Xlint:unchecked", "-Xlint:deprecation", "-Xlint:-options"]
it.options.encoding = "UTF-8"
}
+ jar.manifest {
+ attributes('Implementation-Title': name,
+ 'Implementation-Version': version,
+ 'Built-By': System.getProperty('user.name'),
+ 'Built-JDK': System.getProperty('java.version'),
+ 'Source-Compatibility': sourceCompatibility,
+ 'Target-Compatibility': targetCompatibility)
+ }
+
javadoc.options {
encoding = 'UTF-8'
links 'https://docs.oracle.com/javase/8/docs/api/'
diff --git a/core/src/main/java/io/grpc/AbstractChannelBuilder.java b/core/src/main/java/io/grpc/AbstractChannelBuilder.java
index d137e26..4b142cc 100644
--- a/core/src/main/java/io/grpc/AbstractChannelBuilder.java
+++ b/core/src/main/java/io/grpc/AbstractChannelBuilder.java
@@ -71,6 +71,9 @@
@Nullable
private ExecutorService userExecutor;
+ @Nullable
+ private String userAgent;
+
/**
* Provides a custom executor.
*
@@ -87,6 +90,18 @@
}
/**
+ * Provides a custom {@code User-Agent} for the application.
+ *
+ * <p>It's an optional parameter. If provided, the given agent will be prepended by the
+ * grpc {@code User-Agent}.
+ */
+ @SuppressWarnings("unchecked")
+ public final BuilderT userAgent(String userAgent) {
+ this.userAgent = userAgent;
+ return (BuilderT) this;
+ }
+
+ /**
* Builds a channel using the given parameters.
*/
public ChannelImpl build() {
@@ -101,7 +116,7 @@
}
final ChannelEssentials essentials = buildEssentials();
- ChannelImpl channel = new ChannelImpl(essentials.transportFactory, executor);
+ ChannelImpl channel = new ChannelImpl(essentials.transportFactory, executor, userAgent);
channel.setTerminationRunnable(new Runnable() {
@Override
public void run() {
diff --git a/core/src/main/java/io/grpc/ChannelImpl.java b/core/src/main/java/io/grpc/ChannelImpl.java
index 7268955..40cc8ff 100644
--- a/core/src/main/java/io/grpc/ChannelImpl.java
+++ b/core/src/main/java/io/grpc/ChannelImpl.java
@@ -50,6 +50,7 @@
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
+import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
@@ -81,6 +82,8 @@
private final ClientTransportFactory transportFactory;
private final ExecutorService executor;
+ private final String userAgent;
+
/**
* All transports that are not stopped. At the very least {@link #activeTransport} will be
* present, but previously used transports that still have streams or are stopping may also be
@@ -99,9 +102,11 @@
private boolean terminated;
private Runnable terminationRunnable;
- public ChannelImpl(ClientTransportFactory transportFactory, ExecutorService executor) {
+ ChannelImpl(ClientTransportFactory transportFactory, ExecutorService executor,
+ @Nullable String userAgent) {
this.transportFactory = transportFactory;
this.executor = executor;
+ this.userAgent = userAgent;
}
/** Hack to allow executors to auto-shutdown. Not for general use. */
@@ -334,13 +339,18 @@
headers.put(TIMEOUT_KEY, timeoutMicros);
}
+ // Fill out the User-Agent header.
+ headers.removeAll(HttpUtil.USER_AGENT_KEY);
+ if (userAgent != null) {
+ headers.put(HttpUtil.USER_AGENT_KEY, userAgent);
+ }
+
try {
stream = transport.newStream(method, headers, listener);
} catch (IllegalStateException ex) {
// We can race with the transport and end up trying to use a terminated transport.
// TODO(ejona86): Improve the API to remove the possibility of the race.
closeCallPrematurely(listener, Status.fromThrowable(ex));
- return;
}
}
diff --git a/core/src/main/java/io/grpc/Metadata.java b/core/src/main/java/io/grpc/Metadata.java
index 9480fc6..342c0e4 100644
--- a/core/src/main/java/io/grpc/Metadata.java
+++ b/core/src/main/java/io/grpc/Metadata.java
@@ -357,27 +357,27 @@
/**
* Marshaller for metadata values that are serialized into raw binary.
*/
- public static interface BinaryMarshaller<T> {
+ public interface BinaryMarshaller<T> {
/**
* Serialize a metadata value to bytes.
* @param value to serialize
* @return serialized version of value
*/
- public byte[] toBytes(T value);
+ byte[] toBytes(T value);
/**
* Parse a serialized metadata value from bytes.
* @param serialized value of metadata to parse
* @return a parsed instance of type T
*/
- public T parseBytes(byte[] serialized);
+ T parseBytes(byte[] serialized);
}
/**
* Marshaller for metadata values that are serialized into ASCII strings that contain only
* printable characters and space.
*/
- public static interface AsciiMarshaller<T> {
+ public interface AsciiMarshaller<T> {
/**
* Serialize a metadata value to a ASCII string that contains only printable characters and
* space.
@@ -385,14 +385,14 @@
* @param value to serialize
* @return serialized version of value, or null if value cannot be transmitted.
*/
- public String toAsciiString(T value);
+ String toAsciiString(T value);
/**
* Parse a serialized metadata value from an ASCII string.
* @param serialized value of metadata to parse
* @return a parsed instance of type T
*/
- public T parseAsciiString(String serialized);
+ T parseAsciiString(String serialized);
}
/**
diff --git a/core/src/main/java/io/grpc/transport/Http2ClientStream.java b/core/src/main/java/io/grpc/transport/Http2ClientStream.java
index 85577cc..bf4ea43 100644
--- a/core/src/main/java/io/grpc/transport/Http2ClientStream.java
+++ b/core/src/main/java/io/grpc/transport/Http2ClientStream.java
@@ -206,7 +206,7 @@
return null;
}
contentTypeChecked = true;
- String contentType = headers.get(HttpUtil.CONTENT_TYPE);
+ String contentType = headers.get(HttpUtil.CONTENT_TYPE_KEY);
if (TEMP_CHECK_CONTENT_TYPE && !HttpUtil.CONTENT_TYPE_GRPC.equalsIgnoreCase(contentType)) {
// Malformed content-type so report an error
return Status.INTERNAL.withDescription("invalid content-type " + contentType);
@@ -218,7 +218,7 @@
* Inspect the raw metadata and figure out what charset is being used.
*/
private static Charset extractCharset(Metadata headers) {
- String contentType = headers.get(HttpUtil.CONTENT_TYPE);
+ String contentType = headers.get(HttpUtil.CONTENT_TYPE_KEY);
if (contentType != null) {
String[] split = contentType.split("charset=");
try {
diff --git a/core/src/main/java/io/grpc/transport/HttpUtil.java b/core/src/main/java/io/grpc/transport/HttpUtil.java
index 5cf9949..c2222b5 100644
--- a/core/src/main/java/io/grpc/transport/HttpUtil.java
+++ b/core/src/main/java/io/grpc/transport/HttpUtil.java
@@ -36,16 +36,24 @@
import java.net.HttpURLConnection;
+import javax.annotation.Nullable;
+
/**
* Constants for GRPC-over-HTTP (or HTTP/2).
*/
public final class HttpUtil {
+
/**
- * The Content-Type header name. Defined here since it is not explicitly defined by the HTTP/2
- * spec.
+ * {@link Metadata.Key} for the Content-Type request/response header.
*/
- public static final Metadata.Key<String> CONTENT_TYPE =
- Metadata.Key.of("content-type", Metadata.ASCII_STRING_MARSHALLER);
+ public static final Metadata.Key<String> CONTENT_TYPE_KEY =
+ Metadata.Key.of("content-type", Metadata.ASCII_STRING_MARSHALLER);
+
+ /**
+ * {@link Metadata.Key} for the Content-Type request/response header.
+ */
+ public static final Metadata.Key<String> USER_AGENT_KEY =
+ Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER);
/**
* Content-Type used for GRPC-over-HTTP/2.
@@ -58,12 +66,6 @@
public static final String HTTP_METHOD = "POST";
/**
- * The TE header name. Defined here since it is not explicitly defined by the HTTP/2 spec.
- */
- public static final Metadata.Key<String> TE = Metadata.Key.of("te",
- Metadata.ASCII_STRING_MARSHALLER);
-
- /**
* The TE (transport encoding) header for requests over HTTP/2.
*/
public static final String TE_TRAILERS = "trailers";
@@ -136,7 +138,7 @@
private final int code;
private final Status status;
- private Http2Error(int code, Status status) {
+ Http2Error(int code, Status status) {
this.code = code;
this.status = status.augmentDescription("HTTP/2 error code: " + this.name());
}
@@ -191,5 +193,23 @@
}
}
+ /**
+ * Gets the User-Agent string for the gRPC transport.
+ */
+ public static String getGrpcUserAgent(String transportName,
+ @Nullable String applicationUserAgent) {
+ StringBuilder builder = new StringBuilder("grpc-java-").append(transportName);
+ String version = HttpUtil.class.getPackage().getImplementationVersion();
+ if (version != null) {
+ builder.append("/");
+ builder.append(version);
+ }
+ if (applicationUserAgent != null) {
+ builder.append(' ');
+ builder.append(applicationUserAgent);
+ }
+ return builder.toString();
+ }
+
private HttpUtil() {}
}
diff --git a/core/src/test/java/io/grpc/ChannelImplTest.java b/core/src/test/java/io/grpc/ChannelImplTest.java
index 791f8c9..ce70aab 100644
--- a/core/src/test/java/io/grpc/ChannelImplTest.java
+++ b/core/src/test/java/io/grpc/ChannelImplTest.java
@@ -94,7 +94,7 @@
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
- channel = new ChannelImpl(mockTransportFactory, executor);
+ channel = new ChannelImpl(mockTransportFactory, executor, null);
when(mockTransportFactory.newClientTransport()).thenReturn(mockTransport);
}
diff --git a/netty/src/main/java/io/grpc/transport/netty/Utils.java b/netty/src/main/java/io/grpc/transport/netty/Utils.java
index 0a95061..ee22e8e 100644
--- a/netty/src/main/java/io/grpc/transport/netty/Utils.java
+++ b/netty/src/main/java/io/grpc/transport/netty/Utils.java
@@ -31,6 +31,8 @@
package io.grpc.transport.netty;
+import static io.grpc.transport.HttpUtil.CONTENT_TYPE_KEY;
+import static io.grpc.transport.HttpUtil.USER_AGENT_KEY;
import static io.netty.util.CharsetUtil.UTF_8;
import com.google.common.base.Preconditions;
@@ -40,8 +42,7 @@
import io.grpc.SharedResourceHolder.Resource;
import io.grpc.transport.HttpUtil;
import io.grpc.transport.TransportFrameUtil;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.ByteBufAllocator;
+
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.handler.codec.http2.DefaultHttp2Headers;
@@ -50,7 +51,6 @@
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
-import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -66,12 +66,13 @@
public static final ByteString HTTP_METHOD = new ByteString(HttpUtil.HTTP_METHOD.getBytes(UTF_8));
public static final ByteString HTTPS = new ByteString("https".getBytes(UTF_8));
public static final ByteString HTTP = new ByteString("http".getBytes(UTF_8));
- public static final ByteString CONTENT_TYPE_HEADER = new ByteString(HttpUtil.CONTENT_TYPE.name()
+ public static final ByteString CONTENT_TYPE_HEADER = new ByteString(CONTENT_TYPE_KEY.name()
.getBytes(UTF_8));
public static final ByteString CONTENT_TYPE_GRPC = new ByteString(
HttpUtil.CONTENT_TYPE_GRPC.getBytes(UTF_8));
- public static final ByteString TE_HEADER = new ByteString(HttpUtil.TE.name().getBytes(UTF_8));
+ public static final ByteString TE_HEADER = new ByteString("te".getBytes(UTF_8));
public static final ByteString TE_TRAILERS = new ByteString(HttpUtil.TE_TRAILERS.getBytes(UTF_8));
+ public static final ByteString USER_AGENT = new ByteString(USER_AGENT_KEY.name().getBytes(UTF_8));
public static final Resource<EventLoopGroup> DEFAULT_BOSS_EVENT_LOOP_GROUP =
new DefaultEventLoopGroupResource(1, "grpc-default-boss-ELG");
@@ -79,15 +80,6 @@
public static final Resource<EventLoopGroup> DEFAULT_WORKER_EVENT_LOOP_GROUP =
new DefaultEventLoopGroupResource(0, "grpc-default-worker-ELG");
- /**
- * Copies the content of the given {@link ByteBuffer} to a new {@link ByteBuf} instance.
- */
- static ByteBuf toByteBuf(ByteBufAllocator alloc, ByteBuffer source) {
- ByteBuf buf = alloc.buffer(source.remaining());
- buf.writeBytes(source);
- return buf;
- }
-
public static Metadata.Headers convertHeaders(Http2Headers http2Headers) {
Metadata.Headers headers = new Metadata.Headers(convertHeadersToArray(http2Headers));
if (http2Headers.authority() != null) {
@@ -137,6 +129,10 @@
http2Headers.path(new ByteString(headers.getPath().getBytes(UTF_8)));
}
+ // Set the User-Agent header.
+ String userAgent = HttpUtil.getGrpcUserAgent("netty", headers.get(USER_AGENT_KEY));
+ http2Headers.set(USER_AGENT, new ByteString(userAgent.getBytes(UTF_8)));
+
return http2Headers;
}
@@ -165,9 +161,11 @@
Http2Headers http2Headers = new DefaultHttp2Headers();
byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(headers);
for (int i = 0; i < serializedHeaders.length; i += 2) {
- http2Headers.add(new ByteString(serializedHeaders[i], false),
- new ByteString(serializedHeaders[i + 1], false));
+ ByteString name = new ByteString(serializedHeaders[i], false);
+ ByteString value = new ByteString(serializedHeaders[i + 1], false);
+ http2Headers.add(name, value);
}
+
return http2Headers;
}
diff --git a/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java
index 8c62b39..add9516 100644
--- a/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java
+++ b/netty/src/test/java/io/grpc/transport/netty/NettyClientTransportTest.java
@@ -32,7 +32,9 @@
package io.grpc.transport.netty;
import static com.google.common.base.Charsets.UTF_8;
+import static io.grpc.transport.HttpUtil.USER_AGENT_KEY;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import com.google.common.io.ByteStreams;
@@ -48,6 +50,7 @@
import io.grpc.transport.ClientStream;
import io.grpc.transport.ClientStreamListener;
import io.grpc.transport.ClientTransport;
+import io.grpc.transport.HttpUtil;
import io.grpc.transport.ServerListener;
import io.grpc.transport.ServerStream;
import io.grpc.transport.ServerStreamListener;
@@ -74,6 +77,7 @@
import java.io.InputStream;
import java.net.InetSocketAddress;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@@ -115,6 +119,41 @@
group.shutdownGracefully(0, 10, TimeUnit.SECONDS);
}
+ @Test
+ public void headersShouldAddDefaultUserAgent() throws Exception {
+ startServer();
+ NettyClientTransport transport = newTransport(newNegotiator());
+ transport.start(clientTransportListener);
+
+ // Send a single RPC and wait for the response.
+ new Rpc(transport).halfClose().waitForResponse();
+
+ // Verify that the received headers contained the User-Agent.
+ assertEquals(1, serverListener.streamListeners.size());
+
+ Metadata.Headers headers = serverListener.streamListeners.get(0).headers;
+ assertEquals(HttpUtil.getGrpcUserAgent("netty", null), headers.get(USER_AGENT_KEY));
+ }
+
+ @Test
+ public void headersShouldOverrideDefaultUserAgent() throws Exception {
+ startServer();
+ NettyClientTransport transport = newTransport(newNegotiator());
+ transport.start(clientTransportListener);
+
+ // Send a single RPC and wait for the response.
+ String userAgent = "testUserAgent";
+ Metadata.Headers sentHeaders = new Metadata.Headers();
+ sentHeaders.put(USER_AGENT_KEY, userAgent);
+ new Rpc(transport, sentHeaders).halfClose().waitForResponse();
+
+ // Verify that the received headers contained the User-Agent.
+ assertEquals(1, serverListener.streamListeners.size());
+ Metadata.Headers receivedHeaders = serverListener.streamListeners.get(0).headers;
+ assertEquals(HttpUtil.getGrpcUserAgent("netty", userAgent),
+ receivedHeaders.get(USER_AGENT_KEY));
+ }
+
/**
* Verifies that we can create multiple TLS client transports from the same builder.
*/
@@ -265,8 +304,44 @@
}
}
+ private static final class EchoServerStreamListener implements ServerStreamListener {
+ final ServerStream stream;
+ final String method;
+ final Metadata.Headers headers;
+
+ EchoServerStreamListener(ServerStream stream, String method, Metadata.Headers headers) {
+ this.stream = stream;
+ this.method = method;
+ this.headers = headers;
+ stream.request(1);
+ }
+
+ @Override
+ public void messageRead(InputStream message) {
+ // Just echo back the message.
+ stream.writeMessage(message);
+ stream.flush();
+ }
+
+ @Override
+ public void onReady() {
+ }
+
+ @Override
+ public void halfClosed() {
+ // Just close when the client closes.
+ stream.close(Status.OK, new Metadata.Trailers());
+ }
+
+ @Override
+ public void closed(Status status) {
+ }
+ }
+
private static class EchoServerListener implements ServerListener {
final List<NettyServerTransport> transports = new ArrayList<NettyServerTransport>();
+ final List<EchoServerStreamListener> streamListeners =
+ Collections.synchronizedList(new ArrayList<EchoServerStreamListener>());
@Override
public ServerTransportListener transportCreated(final ServerTransport transport) {
@@ -276,30 +351,9 @@
@Override
public ServerStreamListener streamCreated(final ServerStream stream, String method,
Metadata.Headers headers) {
- stream.request(1);
- return new ServerStreamListener() {
-
- @Override
- public void messageRead(InputStream message) {
- // Just echo back the message.
- stream.writeMessage(message);
- stream.flush();
- }
-
- @Override
- public void onReady() {
- }
-
- @Override
- public void halfClosed() {
- // Just close when the client closes.
- stream.close(Status.OK, new Metadata.Trailers());
- }
-
- @Override
- public void closed(Status status) {
- }
- };
+ EchoServerStreamListener listener = new EchoServerStreamListener(stream, method, headers);
+ streamListeners.add(listener);
+ return listener;
}
@Override
diff --git a/okhttp/src/main/java/io/grpc/transport/okhttp/Headers.java b/okhttp/src/main/java/io/grpc/transport/okhttp/Headers.java
index 857fb29..addf596 100644
--- a/okhttp/src/main/java/io/grpc/transport/okhttp/Headers.java
+++ b/okhttp/src/main/java/io/grpc/transport/okhttp/Headers.java
@@ -31,6 +31,9 @@
package io.grpc.transport.okhttp;
+import static io.grpc.transport.HttpUtil.CONTENT_TYPE_KEY;
+import static io.grpc.transport.HttpUtil.USER_AGENT_KEY;
+
import com.google.common.base.Preconditions;
import com.squareup.okhttp.internal.spdy.Header;
@@ -52,8 +55,8 @@
public static final Header SCHEME_HEADER = new Header(Header.TARGET_SCHEME, "https");
public static final Header METHOD_HEADER = new Header(Header.TARGET_METHOD, HttpUtil.HTTP_METHOD);
public static final Header CONTENT_TYPE_HEADER =
- new Header(HttpUtil.CONTENT_TYPE.name(), HttpUtil.CONTENT_TYPE_GRPC);
- public static final Header TE_HEADER = new Header(HttpUtil.TE.name(), HttpUtil.TE_TRAILERS);
+ new Header(CONTENT_TYPE_KEY.name(), HttpUtil.CONTENT_TYPE_GRPC);
+ public static final Header TE_HEADER = new Header("te", HttpUtil.TE_TRAILERS);
/**
* Serializes the given headers and creates a list of OkHttp {@link Header}s to be used when
@@ -76,6 +79,9 @@
String path = headers.getPath() != null ? headers.getPath() : defaultPath;
okhttpHeaders.add(new Header(Header.TARGET_PATH, path));
+ String userAgent = HttpUtil.getGrpcUserAgent("okhttp", headers.get(USER_AGENT_KEY));
+ okhttpHeaders.add(new Header(HttpUtil.USER_AGENT_KEY.name(), userAgent));
+
// All non-pseudo headers must come after pseudo headers.
okhttpHeaders.add(CONTENT_TYPE_HEADER);
okhttpHeaders.add(TE_HEADER);
@@ -84,8 +90,9 @@
byte[][] serializedHeaders = TransportFrameUtil.toHttp2Headers(headers);
for (int i = 0; i < serializedHeaders.length; i += 2) {
ByteString key = ByteString.of(serializedHeaders[i]);
- ByteString value = ByteString.of(serializedHeaders[i + 1]);
- if (isApplicationHeader(key)) {
+ String keyString = key.utf8();
+ if (isApplicationHeader(keyString)) {
+ ByteString value = ByteString.of(serializedHeaders[i + 1]);
okhttpHeaders.add(new Header(key, value));
}
}
@@ -97,10 +104,10 @@
* Returns {@code true} if the given header is an application-provided header. Otherwise, returns
* {@code false} if the header is reserved by GRPC.
*/
- private static boolean isApplicationHeader(ByteString key) {
- String keyString = key.utf8();
+ private static boolean isApplicationHeader(String key) {
// Don't allow HTTP/2 pseudo headers or content-type to be added by the application.
- return (!keyString.startsWith(":")
- && !HttpUtil.CONTENT_TYPE.name().equalsIgnoreCase(keyString));
+ return (!key.startsWith(":")
+ && !CONTENT_TYPE_KEY.name().equalsIgnoreCase(key))
+ && !USER_AGENT_KEY.name().equalsIgnoreCase(key);
}
}
diff --git a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java
index 1b2131b..2b80da1 100644
--- a/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java
+++ b/okhttp/src/test/java/io/grpc/transport/okhttp/OkHttpClientTransportTest.java
@@ -32,6 +32,10 @@
package io.grpc.transport.okhttp;
import static com.google.common.base.Charsets.UTF_8;
+import static io.grpc.transport.okhttp.Headers.CONTENT_TYPE_HEADER;
+import static io.grpc.transport.okhttp.Headers.METHOD_HEADER;
+import static io.grpc.transport.okhttp.Headers.SCHEME_HEADER;
+import static io.grpc.transport.okhttp.Headers.TE_HEADER;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
@@ -68,6 +72,7 @@
import io.grpc.StatusException;
import io.grpc.transport.ClientStreamListener;
import io.grpc.transport.ClientTransport;
+import io.grpc.transport.HttpUtil;
import io.grpc.transport.okhttp.OkHttpClientTransport.ClientFrameHandler;
import okio.Buffer;
@@ -89,6 +94,7 @@
import java.io.InputStreamReader;
import java.net.Socket;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -115,7 +121,6 @@
private ClientTransport.Listener transportListener;
private OkHttpClientTransport clientTransport;
private MockFrameReader frameReader;
- private MockSocket socket;
private Map<Integer, OkHttpClientStream> streams;
private ClientFrameHandler frameHandler;
private ExecutorService executor;
@@ -127,7 +132,7 @@
MockitoAnnotations.initMocks(this);
streams = new HashMap<Integer, OkHttpClientStream>();
frameReader = new MockFrameReader();
- socket = new MockSocket(frameReader);
+ MockSocket socket = new MockSocket(frameReader);
executor = Executors.newCachedThreadPool();
Ticker ticker = new Ticker() {
@Override
@@ -206,7 +211,7 @@
public void receivedHeadersForInvalidStreamShouldKillConnection() throws Exception {
// Empty headers block without correct content type or status
frameHandler.headers(false, false, 3, 0, new ArrayList<Header>(),
- HeadersMode.HTTP_20_HEADERS);
+ HeadersMode.HTTP_20_HEADERS);
verify(frameWriter).goAway(eq(0), eq(ErrorCode.PROTOCOL_ERROR), any(byte[].class));
verify(transportListener).transportShutdown();
verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated();
@@ -271,6 +276,37 @@
}
@Test
+ public void headersShouldAddDefaultUserAgent() throws Exception {
+ MockStreamListener listener = new MockStreamListener();
+ clientTransport.newStream(method, new Metadata.Headers(), listener);
+ Header userAgentHeader = new Header(HttpUtil.USER_AGENT_KEY.name(),
+ HttpUtil.getGrpcUserAgent("okhttp", null));
+ List<Header> expectedHeaders = Arrays.asList(SCHEME_HEADER, METHOD_HEADER,
+ new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"),
+ new Header(Header.TARGET_PATH, "/fakemethod"),
+ userAgentHeader, CONTENT_TYPE_HEADER, TE_HEADER);
+ verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders));
+ streams.get(3).cancel();
+ }
+
+ @Test
+ public void headersShouldOverrideDefaultUserAgent() throws Exception {
+ MockStreamListener listener = new MockStreamListener();
+ String userAgent = "fakeUserAgent";
+ Metadata.Headers metadata = new Metadata.Headers();
+ metadata.put(HttpUtil.USER_AGENT_KEY, userAgent);
+ clientTransport.newStream(method, metadata, listener);
+ List<Header> expectedHeaders = Arrays.asList(SCHEME_HEADER, METHOD_HEADER,
+ new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"),
+ new Header(Header.TARGET_PATH, "/fakemethod"),
+ new Header(HttpUtil.USER_AGENT_KEY.name(),
+ HttpUtil.getGrpcUserAgent("okhttp", userAgent)),
+ CONTENT_TYPE_HEADER, TE_HEADER);
+ verify(frameWriter).synStream(eq(false), eq(false), eq(3), eq(0), eq(expectedHeaders));
+ streams.get(3).cancel();
+ }
+
+ @Test
public void writeMessage() throws Exception {
final String message = "Hello Server";
MockStreamListener listener = new MockStreamListener();
@@ -1013,7 +1049,7 @@
private List<Header> grpcResponseHeaders() {
return ImmutableList.<Header>builder()
.add(new Header(":status", "200"))
- .add(Headers.CONTENT_TYPE_HEADER)
+ .add(CONTENT_TYPE_HEADER)
.build();
}