core: integrate instrumentation-java (Census) tracing (#2938)

Main implementation is in CensusTracingModule.

Also a few fix-ups in the stats implementation CensusStatsModule:

- Change header key name from grpc-census-bin to grpc-tags-bin
- The server no longer fails on header parse errors; it uses the default instead.

Census-based stats and tracing are guarded by the static flags `GrpcUtil.enableCensusStats` and `GrpcUtil.enableCensusTracing`. Both features stay disabled by default until they, and especially their wire formats, have stabilized.
diff --git a/core/src/test/java/io/grpc/internal/CensusModulesTest.java b/core/src/test/java/io/grpc/internal/CensusModulesTest.java
new file mode 100644
index 0000000..23e3a00
--- /dev/null
+++ b/core/src/test/java/io/grpc/internal/CensusModulesTest.java
@@ -0,0 +1,728 @@
+/*
+ * Copyright 2017, Google Inc. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *    * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *    * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *
+ *    * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+package io.grpc.internal;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Matchers.isNull;
+import static org.mockito.Matchers.same;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
+
+import com.google.instrumentation.stats.RpcConstants;
+import com.google.instrumentation.stats.StatsContext;
+import com.google.instrumentation.stats.TagValue;
+import com.google.instrumentation.trace.Annotation;
+import com.google.instrumentation.trace.AttributeValue;
+import com.google.instrumentation.trace.BinaryPropagationHandler;
+import com.google.instrumentation.trace.ContextUtils;
+import com.google.instrumentation.trace.EndSpanOptions;
+import com.google.instrumentation.trace.Link;
+import com.google.instrumentation.trace.NetworkEvent;
+import com.google.instrumentation.trace.Span;
+import com.google.instrumentation.trace.SpanContext;
+import com.google.instrumentation.trace.SpanFactory;
+import com.google.instrumentation.trace.SpanId;
+import com.google.instrumentation.trace.StartSpanOptions;
+import com.google.instrumentation.trace.TraceId;
+import com.google.instrumentation.trace.TraceOptions;
+import com.google.instrumentation.trace.Tracer;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ClientInterceptors;
+import io.grpc.ClientStreamTracer;
+import io.grpc.Context;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerServiceDefinition;
+import io.grpc.ServerStreamTracer;
+import io.grpc.Status;
+import io.grpc.internal.testing.StatsTestUtils;
+import io.grpc.internal.testing.StatsTestUtils.FakeStatsContextFactory;
+import io.grpc.testing.GrpcServerRule;
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.text.ParseException;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicReference;
+import javax.annotation.Nullable;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+/**
+ * Test for {@link CensusStatsModule} and {@link CensusTracingModule}.
+ */
+@RunWith(JUnit4.class)
+public class CensusModulesTest {
+  private static final CallOptions.Key<String> CUSTOM_OPTION =
+      CallOptions.Key.of("option1", "default");
+  private static final CallOptions CALL_OPTIONS =
+      CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue");
+
+  /** InputStream that merely carries a String; MARSHALLER unwraps it without reading bytes. */
+  private static class StringInputStream extends InputStream {
+    final String string;
+
+    StringInputStream(String payload) {
+      string = payload;
+    }
+
+    @Override
+    public int read() {
+      // InProcessTransport passes this stream to MARSHALLER.parse() without reading from it.
+      throw new UnsupportedOperationException("Should not be called");
+    }
+  }
+
+  private static final MethodDescriptor.Marshaller<String> MARSHALLER =
+      new MethodDescriptor.Marshaller<String>() {
+        @Override
+        public InputStream stream(String message) {
+          return new StringInputStream(message);
+        }
+
+        @Override
+        public String parse(InputStream wrapped) {
+          return ((StringInputStream) wrapped).string;
+        }
+      };
+
+  private final MethodDescriptor<String, String> method = MethodDescriptor.create(
+      MethodDescriptor.MethodType.UNKNOWN, "package1.service2/method3",
+      MARSHALLER, MARSHALLER);
+  private final FakeClock fakeClock = new FakeClock();
+  private final FakeStatsContextFactory statsCtxFactory = new FakeStatsContextFactory();
+  // Fixed seed, so the trace and span IDs generated below are the same on every run.
+  private final Random random = new Random(0);
+  private final SpanContext fakeClientSpanContext = SpanContext.create(
+      TraceId.generateRandomId(random), SpanId.generateRandomId(random),
+      TraceOptions.builder().build());
+  private final SpanContext fakeClientParentSpanContext = SpanContext.create(
+      TraceId.generateRandomId(random), SpanId.generateRandomId(random),
+      TraceOptions.builder().build());
+  private final SpanContext fakeServerSpanContext = SpanContext.create(
+      TraceId.generateRandomId(random), SpanId.generateRandomId(random),
+      TraceOptions.builder().build());
+  private final SpanContext fakeServerParentSpanContext = SpanContext.create(
+      TraceId.generateRandomId(random), SpanId.generateRandomId(random),
+      TraceOptions.builder().build());
+  // No-op spans wrapping the contexts above.
+  private final Span fakeClientSpan = new FakeSpan(fakeClientSpanContext);
+  private final Span fakeServerSpan = new FakeSpan(fakeServerSpanContext);
+  private final Span fakeClientParentSpan = new FakeSpan(fakeClientParentSpanContext);
+  private final Span fakeServerParentSpan = new FakeSpan(fakeServerParentSpanContext);
+  // Spies let the tests verify calls to end() on the spans handed out by mockSpanFactory.
+  private final Span spyClientSpan = spy(fakeClientSpan);
+  private final Span spyServerSpan = spy(fakeServerSpan);
+  // Stubbed return value of mockTracingPropagationHandler.toBinaryValue(); see setUp().
+  private final byte[] binarySpanContext = new byte[]{3, 1, 5};
+
+  // Supplies the server registry and channel used by testClientInterceptors().
+  @Rule
+  public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor();
+
+  // Mockito populates the @Mock and @Captor fields when setUp() calls
+  // MockitoAnnotations.initMocks(this).
+  @Mock
+  private AccessibleSpanFactory mockSpanFactory;
+  @Mock
+  private BinaryPropagationHandler mockTracingPropagationHandler;
+  @Mock
+  private ClientCall.Listener<String> mockClientCallListener;
+  @Mock
+  private ServerCall.Listener<String> mockServerCallListener;
+  @Captor
+  private ArgumentCaptor<Status> statusCaptor;
+
+  // Built in setUp() on top of the mocks above.
+  private Tracer tracer;
+  private CensusStatsModule censusStats;
+  private CensusTracingModule censusTracing;
+
+  @Before
+  @SuppressWarnings("unchecked")
+  public void setUp() throws Exception {
+    MockitoAnnotations.initMocks(this);
+    // Every serialized SpanContext is binarySpanContext; every parsed one is the server parent.
+    when(mockTracingPropagationHandler.toBinaryValue(any(SpanContext.class)))
+        .thenReturn(binarySpanContext);
+    when(mockTracingPropagationHandler.fromBinaryValue(any(byte[].class)))
+        .thenReturn(fakeServerParentSpanContext);
+    when(mockSpanFactory.startSpan(any(Span.class), anyString(), any(StartSpanOptions.class)))
+        .thenReturn(spyClientSpan);
+    when(mockSpanFactory.startSpanWithRemoteParent(
+            any(SpanContext.class), anyString(), any(StartSpanOptions.class)))
+        .thenReturn(spyServerSpan);
+    tracer = new Tracer(mockSpanFactory) {};
+    censusStats = new CensusStatsModule(statsCtxFactory, fakeClock.getStopwatchSupplier());
+    censusTracing = new CensusTracingModule(tracer, mockTracingPropagationHandler);
+  }
+
+  /** Fails the test if it left a stats record unpolled. */
+  @After public void wrapUp() {
+    assertNull(statsCtxFactory.pollRecord());
+  }
+
+  /** Runs the client interceptor scenario with nothing attached to the current Context. */
+  @Test public void clientInterceptorNoCustomTag() {
+    testClientInterceptors(false);
+  }
+
+  /** Runs the interceptor scenario with a custom tag and parent Span in the Context. */
+  @Test public void clientInterceptorCustomTag() {
+    testClientInterceptors(true);
+  }
+
+  // Verifies that the Census ClientInterceptors take the StatsContext and Span from the current
+  // Context when creating their ClientCallTracers, and that they intercept
+  // ClientCall.Listener.onClose() to call ClientCallTracer.callEnded().
+  private void testClientInterceptors(boolean nonDefaultContext) {
+    grpcServerRule.getServiceRegistry().addService(
+        ServerServiceDefinition.builder("package1.service2").addMethod(
+            method, new ServerCallHandler<String, String>() {
+                @Override
+                public ServerCall.Listener<String> startCall(
+                    ServerCall<String, String> call, Metadata headers) {
+                  call.sendHeaders(new Metadata());
+                  call.sendMessage("Hello");
+                  call.close(
+                      Status.PERMISSION_DENIED.withDescription("No you don't"), new Metadata());
+                  return mockServerCallListener;
+                }
+              }).build());
+
+    final AtomicReference<CallOptions> capturedCallOptions = new AtomicReference<CallOptions>();
+    ClientInterceptor callOptionsCaptureInterceptor = new ClientInterceptor() {
+        @Override
+        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+            MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+          capturedCallOptions.set(callOptions);
+          return next.newCall(method, callOptions);
+        }
+      };
+    Channel interceptedChannel =
+        ClientInterceptors.intercept(
+            grpcServerRule.getChannel(), callOptionsCaptureInterceptor,
+            censusStats.getClientInterceptor(), censusTracing.getClientInterceptor());
+    ClientCall<String, String> call;
+    if (nonDefaultContext) {
+      Context ctx =
+          Context.ROOT.withValues(
+              CensusStatsModule.STATS_CONTEXT_KEY,
+              statsCtxFactory.getDefault().with(
+                  StatsTestUtils.EXTRA_TAG, TagValue.create("extra value")),
+              ContextUtils.CONTEXT_SPAN_KEY,
+              fakeClientParentSpan);
+      Context origCtx = ctx.attach();
+      try {
+        call = interceptedChannel.newCall(method, CALL_OPTIONS);
+      } finally {
+        ctx.detach(origCtx);
+      }
+    } else {
+      assertNull(CensusStatsModule.STATS_CONTEXT_KEY.get());
+      assertNull(ContextUtils.CONTEXT_SPAN_KEY.get());
+      call = interceptedChannel.newCall(method, CALL_OPTIONS);
+    }
+
+    // Each Census interceptor adds its tracer factory to CallOptions
+    assertEquals("customvalue", capturedCallOptions.get().getOption(CUSTOM_OPTION));
+    assertEquals(2, capturedCallOptions.get().getStreamTracerFactories().size());
+    assertTrue(
+        capturedCallOptions.get().getStreamTracerFactories().get(0)
+        instanceof CensusTracingModule.ClientCallTracer);
+    assertTrue(
+        capturedCallOptions.get().getStreamTracerFactories().get(1)
+        instanceof CensusStatsModule.ClientCallTracer);
+
+    // Start the call; nothing is recorded and the client span stays open until the call ends
+    Metadata headers = new Metadata();
+    call.start(mockClientCallListener, headers);
+    assertNull(statsCtxFactory.pollRecord());
+    if (nonDefaultContext) {
+      verify(mockSpanFactory).startSpan(
+          same(fakeClientParentSpan), eq("Sent.package1.service2.method3"),
+          any(StartSpanOptions.class));
+    } else {
+      verify(mockSpanFactory).startSpan(
+          isNull(Span.class), eq("Sent.package1.service2.method3"), any(StartSpanOptions.class));
+    }
+    verify(spyClientSpan, never()).end(any(EndSpanOptions.class));
+
+    // End the call
+    call.halfClose();
+    call.request(1);
+
+    verify(mockClientCallListener).onClose(statusCaptor.capture(), any(Metadata.class));
+    Status status = statusCaptor.getValue();
+    assertEquals(Status.Code.PERMISSION_DENIED, status.getCode());
+    assertEquals("No you don't", status.getDescription());
+
+    // The intercepting listener calls callEnded() on ClientCallTracer, which records to Census.
+    StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord();
+    assertNotNull(record);
+    TagValue methodTag = record.tags.get(RpcConstants.RPC_CLIENT_METHOD);
+    assertEquals(method.getFullMethodName(), methodTag.toString());
+    TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
+    assertEquals(Status.Code.PERMISSION_DENIED.toString(), statusTag.toString());
+    if (nonDefaultContext) {
+      TagValue extraTag = record.tags.get(StatsTestUtils.EXTRA_TAG);
+      assertEquals("extra value", extraTag.toString());
+    } else {
+      assertNull(record.tags.get(StatsTestUtils.EXTRA_TAG));
+    }
+    verify(spyClientSpan).end(
+        EndSpanOptions.builder()
+            .setStatus(
+                com.google.instrumentation.trace.Status.PERMISSION_DENIED
+                    .withDescription("No you don't"))
+            .build());
+    verify(spyClientSpan, never()).end();
+  }
+
+  @Test
+  public void clientBasicStatsDefaultContext() {
+    CensusStatsModule.ClientCallTracer callTracer =
+        censusStats.newClientCallTracer(statsCtxFactory.getDefault(), method.getFullMethodName());
+    ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(new Metadata());
+
+    // The fake clock advances between events; the steps add up to the round-trip latency.
+    fakeClock.forwardTime(30, MILLISECONDS);
+    streamTracer.outboundHeaders();
+
+    fakeClock.forwardTime(100, MILLISECONDS);
+    streamTracer.outboundWireSize(1028);
+    streamTracer.outboundUncompressedSize(1128);
+
+    fakeClock.forwardTime(16, MILLISECONDS);
+    streamTracer.inboundWireSize(33);
+    streamTracer.inboundUncompressedSize(67);
+    streamTracer.outboundWireSize(99);
+    streamTracer.outboundUncompressedSize(865);
+
+    fakeClock.forwardTime(24, MILLISECONDS);
+    streamTracer.inboundWireSize(154);
+    streamTracer.inboundUncompressedSize(552);
+    streamTracer.streamClosed(Status.OK);
+    callTracer.callEnded(Status.OK);
+
+    StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord();
+    assertNotNull(record);
+    assertNoServerContent(record);
+    TagValue methodTag = record.tags.get(RpcConstants.RPC_CLIENT_METHOD);
+    assertEquals(method.getFullMethodName(), methodTag.toString());
+    TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
+    assertEquals(Status.Code.OK.toString(), statusTag.toString());
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_ERROR_COUNT));
+    assertEquals(1028 + 99, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_REQUEST_BYTES));
+    assertEquals(1128 + 865,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
+    assertEquals(33 + 154, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
+    assertEquals(67 + 552,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
+    assertEquals(30 + 100 + 16 + 24,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY));
+  }
+
+  @Test
+  public void clientBasicTracingDefaultSpan() {
+    CensusTracingModule.ClientCallTracer callTracer =
+        censusTracing.newClientCallTracer(null, method.getFullMethodName());
+    ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(new Metadata());
+    // With no parent Span supplied, the client span is started with a null parent.
+    verify(mockSpanFactory).startSpan(
+        isNull(Span.class), eq("Sent.package1.service2.method3"), any(StartSpanOptions.class));
+    verify(spyClientSpan, never()).end(any(EndSpanOptions.class));
+
+    streamTracer.streamClosed(Status.OK);
+    callTracer.callEnded(Status.OK);
+
+    verify(spyClientSpan).end(
+        EndSpanOptions.builder().setStatus(com.google.instrumentation.trace.Status.OK).build());
+    verifyNoMoreInteractions(mockSpanFactory);
+  }
+
+  @Test
+  public void clientStreamNeverCreatedStillRecordStats() {
+    CensusStatsModule.ClientCallTracer callTracer =
+        censusStats.newClientCallTracer(statsCtxFactory.getDefault(), method.getFullMethodName());
+
+    // No stream tracer is created: the call ends after 3000ms of fake time with no traffic.
+    fakeClock.forwardTime(3000, MILLISECONDS);
+    callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds"));
+
+    StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord();
+    assertNotNull(record);
+    assertNoServerContent(record);
+    TagValue methodTag = record.tags.get(RpcConstants.RPC_CLIENT_METHOD);
+    assertEquals(method.getFullMethodName(), methodTag.toString());
+    TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
+    assertEquals(Status.Code.DEADLINE_EXCEEDED.toString(), statusTag.toString());
+    assertEquals(1, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_ERROR_COUNT));
+    assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_REQUEST_BYTES));
+    assertEquals(0,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
+    assertEquals(0, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
+    assertEquals(0,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
+    assertEquals(3000, record.getMetricAsLongOrFail(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY));
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_SERVER_ELAPSED_TIME));
+  }
+
+  @Test
+  public void clientStreamNeverCreatedStillRecordTracing() {
+    CensusTracingModule.ClientCallTracer callTracer =
+        censusTracing.newClientCallTracer(fakeClientParentSpan, method.getFullMethodName());
+    // The span is started when the call tracer is created, before any stream exists.
+    verify(mockSpanFactory).startSpan(
+        same(fakeClientParentSpan), eq("Sent.package1.service2.method3"),
+        any(StartSpanOptions.class));
+
+    // Ending the call without ever creating a stream must still end the span with the status.
+    callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds"));
+
+    com.google.instrumentation.trace.Status expectedStatus =
+        com.google.instrumentation.trace.Status.DEADLINE_EXCEEDED.withDescription("3 seconds");
+    verify(spyClientSpan).end(EndSpanOptions.builder().setStatus(expectedStatus).build());
+    verify(spyClientSpan, never()).end();
+  }
+
+  @Test
+  public void statsHeadersPropagateTags() {
+    // EXTRA_TAG is propagated by the FakeStatsContextFactory. Note that not all tags are
+    // propagated: the StatsContextFactory decides which tags are to be propagated, and gRPC
+    // facilitates the propagation by putting them in the headers.
+    StatsContext clientCtx = statsCtxFactory.getDefault().with(
+        StatsTestUtils.EXTRA_TAG, TagValue.create("extra-tag-value-897"));
+    CensusStatsModule.ClientCallTracer callTracer =
+        censusStats.newClientCallTracer(clientCtx, method.getFullMethodName());
+    Metadata headers = new Metadata();
+    // This propagates clientCtx to headers
+    callTracer.newClientStreamTracer(headers);
+    assertTrue(headers.containsKey(censusStats.statsHeader));
+
+    ServerStreamTracer serverTracer =
+        censusStats.getServerTracerFactory().newServerStreamTracer(
+            method.getFullMethodName(), headers);
+    // Server tracer deserializes clientCtx from the headers, so that it records stats with the
+    // propagated tags.
+    Context serverContext = serverTracer.filterContext(Context.ROOT);
+    // It also puts clientCtx in the Context seen by the call handler
+    assertEquals(clientCtx, CensusStatsModule.STATS_CONTEXT_KEY.get(serverContext));
+
+    // Verifies that the server tracer records the status together with the propagated tag when
+    // the stream is closed.
+    serverTracer.streamClosed(Status.OK);
+
+    StatsTestUtils.MetricsRecord serverRecord = statsCtxFactory.pollRecord();
+    assertNotNull(serverRecord);
+    assertNoClientContent(serverRecord);
+    TagValue serverMethodTag = serverRecord.tags.get(RpcConstants.RPC_SERVER_METHOD);
+    assertEquals(method.getFullMethodName(), serverMethodTag.toString());
+    TagValue serverStatusTag = serverRecord.tags.get(RpcConstants.RPC_STATUS);
+    assertEquals(Status.Code.OK.toString(), serverStatusTag.toString());
+    assertNull(serverRecord.getMetric(RpcConstants.RPC_SERVER_ERROR_COUNT));
+    TagValue serverPropagatedTag = serverRecord.tags.get(StatsTestUtils.EXTRA_TAG);
+    assertEquals("extra-tag-value-897", serverPropagatedTag.toString());
+
+    // Verifies that the client tracer factory uses clientCtx, which includes the custom tags, to
+    // record stats.
+    callTracer.callEnded(Status.OK);
+
+    StatsTestUtils.MetricsRecord clientRecord = statsCtxFactory.pollRecord();
+    assertNotNull(clientRecord);
+    assertNoServerContent(clientRecord);
+    TagValue clientMethodTag = clientRecord.tags.get(RpcConstants.RPC_CLIENT_METHOD);
+    assertEquals(method.getFullMethodName(), clientMethodTag.toString());
+    TagValue clientStatusTag = clientRecord.tags.get(RpcConstants.RPC_STATUS);
+    assertEquals(Status.Code.OK.toString(), clientStatusTag.toString());
+    assertNull(clientRecord.getMetric(RpcConstants.RPC_CLIENT_ERROR_COUNT));
+    TagValue clientPropagatedTag = clientRecord.tags.get(StatsTestUtils.EXTRA_TAG);
+    assertEquals("extra-tag-value-897", clientPropagatedTag.toString());
+  }
+
+  @Test
+  public void statsHeadersNotPropagateDefaultContext() {
+    Metadata headers = new Metadata();
+    // Creating a stream tracer for the default context must not write a stats header.
+    censusStats.newClientCallTracer(statsCtxFactory.getDefault(), method.getFullMethodName())
+        .newClientStreamTracer(headers);
+    assertFalse(headers.containsKey(censusStats.statsHeader));
+  }
+
+  @Test
+  public void statsHeaderMalformed() {
+    // The one-byte value is expected to be rejected: deserializing it directly must throw.
+    byte[] malformedValue = new byte[]{1};
+    try {
+      statsCtxFactory.deserialize(new ByteArrayInputStream(malformedValue));
+      fail("Should have thrown");
+    } catch (Exception expected) {
+      // Expected
+    }
+
+    // Read back through statsHeader, the malformed bytes yield the default context, not an error.
+    Metadata.Key<byte[]> rawStatsHeader =
+        Metadata.Key.of("grpc-tags-bin", Metadata.BINARY_BYTE_MARSHALLER);
+    Metadata headers = new Metadata();
+    assertNull(headers.get(censusStats.statsHeader));
+    headers.put(rawStatsHeader, malformedValue);
+    assertSame(statsCtxFactory.getDefault(), headers.get(censusStats.statsHeader));
+  }
+
+  @Test
+  public void traceHeadersPropagateSpanContext() throws Exception {
+    Metadata headers = new Metadata();
+    censusTracing.newClientCallTracer(fakeClientParentSpan, method.getFullMethodName())
+        .newClientStreamTracer(headers);
+
+    // Client side: the context of the newly started span is serialized into the tracing header.
+    verify(mockTracingPropagationHandler).toBinaryValue(same(fakeClientSpanContext));
+    verifyNoMoreInteractions(mockTracingPropagationHandler);
+    verify(mockSpanFactory).startSpan(
+        same(fakeClientParentSpan), eq("Sent.package1.service2.method3"),
+        any(StartSpanOptions.class));
+    verifyNoMoreInteractions(mockSpanFactory);
+    assertTrue(headers.containsKey(censusTracing.tracingHeader));
+
+    // Server side: the header is parsed and the resulting SpanContext is the remote parent.
+    ServerStreamTracer serverTracer = censusTracing.getServerTracerFactory()
+        .newServerStreamTracer(method.getFullMethodName(), headers);
+    verify(mockTracingPropagationHandler).fromBinaryValue(same(binarySpanContext));
+    verify(mockSpanFactory).startSpanWithRemoteParent(
+        same(fakeServerParentSpanContext), eq("Recv.package1.service2.method3"),
+        any(StartSpanOptions.class));
+
+    Context filteredContext = serverTracer.filterContext(Context.ROOT);
+    assertSame(spyServerSpan, ContextUtils.CONTEXT_SPAN_KEY.get(filteredContext));
+  }
+
+  @Test
+  public void traceHeaderMalformed() throws Exception {
+    // As comparison, normal header parsing
+    Metadata headers = new Metadata();
+    headers.put(censusTracing.tracingHeader, fakeClientSpanContext);
+    // mockTracingPropagationHandler was stubbed to always return fakeServerParentSpanContext
+    assertSame(fakeServerParentSpanContext, headers.get(censusTracing.tracingHeader));
+
+    // Make BinaryPropagationHandler always throw when parsing the header
+    when(mockTracingPropagationHandler.fromBinaryValue(any(byte[].class)))
+        .thenThrow(new ParseException("Malformed header", 0));
+
+    headers = new Metadata();
+    assertNull(headers.get(censusTracing.tracingHeader));
+    headers.put(censusTracing.tracingHeader, fakeClientSpanContext);
+    assertSame(SpanContext.INVALID, headers.get(censusTracing.tracingHeader));
+    assertNotSame(fakeServerParentSpanContext, SpanContext.INVALID);
+
+    // A null SpanContext is used as the remote parent in this case
+    censusTracing.getServerTracerFactory().newServerStreamTracer(
+        method.getFullMethodName(), headers);
+    verify(mockSpanFactory).startSpanWithRemoteParent(
+        isNull(SpanContext.class), eq("Recv.package1.service2.method3"),
+        any(StartSpanOptions.class));
+  }
+
+  @Test
+  public void serverBasicStatsNoHeaders() {
+    ServerStreamTracer streamTracer = censusStats.getServerTracerFactory()
+        .newServerStreamTracer(method.getFullMethodName(), new Metadata());
+
+    // With empty headers, the filtered Context carries no StatsContext.
+    Context filteredContext = streamTracer.filterContext(Context.ROOT);
+    assertNull(CensusStatsModule.STATS_CONTEXT_KEY.get(filteredContext));
+
+    streamTracer.inboundWireSize(34);
+    streamTracer.inboundUncompressedSize(67);
+
+    fakeClock.forwardTime(100, MILLISECONDS);
+    streamTracer.outboundWireSize(1028);
+    streamTracer.outboundUncompressedSize(1128);
+
+    fakeClock.forwardTime(16, MILLISECONDS);
+    streamTracer.inboundWireSize(154);
+    streamTracer.inboundUncompressedSize(552);
+    streamTracer.outboundWireSize(99);
+    streamTracer.outboundUncompressedSize(865);
+
+    fakeClock.forwardTime(24, MILLISECONDS);
+
+    streamTracer.streamClosed(Status.CANCELLED);
+
+    StatsTestUtils.MetricsRecord record = statsCtxFactory.pollRecord();
+    assertNotNull(record);
+    assertNoClientContent(record);
+    TagValue methodTag = record.tags.get(RpcConstants.RPC_SERVER_METHOD);
+    assertEquals(method.getFullMethodName(), methodTag.toString());
+    TagValue statusTag = record.tags.get(RpcConstants.RPC_STATUS);
+    assertEquals(Status.Code.CANCELLED.toString(), statusTag.toString());
+    assertEquals(1, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_ERROR_COUNT));
+    assertEquals(1028 + 99, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_RESPONSE_BYTES));
+    assertEquals(1128 + 865,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
+    assertEquals(34 + 154, record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_REQUEST_BYTES));
+    assertEquals(67 + 552,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
+    assertEquals(100 + 16 + 24,
+        record.getMetricAsLongOrFail(RpcConstants.RPC_SERVER_SERVER_LATENCY));
+  }
+
+  @Test
+  public void serverBasicTracingNoHeaders() {
+    ServerStreamTracer streamTracer = censusTracing.getServerTracerFactory()
+        .newServerStreamTracer(method.getFullMethodName(), new Metadata());
+    // No tracing header: nothing is parsed and the server span gets a null remote parent.
+    verifyZeroInteractions(mockTracingPropagationHandler);
+    verify(mockSpanFactory).startSpanWithRemoteParent(
+        isNull(SpanContext.class), eq("Recv.package1.service2.method3"),
+        any(StartSpanOptions.class));
+
+    Context filteredContext = streamTracer.filterContext(Context.ROOT);
+    assertSame(spyServerSpan, ContextUtils.CONTEXT_SPAN_KEY.get(filteredContext));
+
+    verify(spyServerSpan, never()).end(any(EndSpanOptions.class));
+    streamTracer.streamClosed(Status.CANCELLED);
+
+    verify(spyServerSpan).end(
+        EndSpanOptions.builder()
+            .setStatus(com.google.instrumentation.trace.Status.CANCELLED).build());
+    verify(spyServerSpan, never()).end();
+  }
+
+  /** Checks CensusTracingModule.convertStatus() against every gRPC Status.Code. */
+  @Test
+  public void convertToTracingStatus() {
+    // Without description: the code name carries over and the description stays null.
+    for (Status.Code code : Status.Code.values()) {
+      Status bare = Status.fromCode(code);
+      com.google.instrumentation.trace.Status converted = CensusTracingModule.convertStatus(bare);
+      assertEquals(code.toString(), converted.getCanonicalCode().toString());
+      assertNull(converted.getDescription());
+    }
+
+    // With description: both the code name and the description carry over.
+    for (Status.Code code : Status.Code.values()) {
+      Status described = Status.fromCode(code).withDescription("This is my description");
+      com.google.instrumentation.trace.Status converted =
+          CensusTracingModule.convertStatus(described);
+      assertEquals(code.toString(), converted.getCanonicalCode().toString());
+      assertEquals(described.getDescription(), converted.getDescription());
+    }
+  }
+
+  private static void assertNoServerContent(StatsTestUtils.MetricsRecord record) {
+    assertNull(record.getMetric(RpcConstants.RPC_SERVER_ERROR_COUNT));
+    assertNull(record.getMetric(RpcConstants.RPC_SERVER_REQUEST_BYTES));
+    assertNull(record.getMetric(RpcConstants.RPC_SERVER_RESPONSE_BYTES));
+    assertNull(record.getMetric(RpcConstants.RPC_SERVER_SERVER_ELAPSED_TIME));
+    assertNull(record.getMetric(RpcConstants.RPC_SERVER_SERVER_LATENCY));
+    assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_REQUEST_BYTES));
+    assertNull(record.getMetric(RpcConstants.RPC_SERVER_UNCOMPRESSED_RESPONSE_BYTES));
+  }
+
+  private static void assertNoClientContent(StatsTestUtils.MetricsRecord record) {
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_ERROR_COUNT));
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_REQUEST_BYTES));
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_RESPONSE_BYTES));
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_ROUNDTRIP_LATENCY));
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_SERVER_ELAPSED_TIME));
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES));
+    assertNull(record.getMetric(RpcConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES));
+  }
+
+  // Redeclares SpanFactory's methods as public so that this test can stub and verify them.
+  private abstract static class AccessibleSpanFactory extends SpanFactory {
+    @Override
+    public abstract Span startSpan(@Nullable Span parent, String name, StartSpanOptions options);
+
+    @Override
+    public abstract Span startSpanWithRemoteParent(
+        @Nullable SpanContext remoteParent, String name, StartSpanOptions options);
+  }
+
+  /**
+   * A Span whose recording methods all do nothing.
+   *
+   * <p>Some instances are wrapped in Mockito spies (spyClientSpan, spyServerSpan) so that tests
+   * can verify calls to end().
+   */
+  private static class FakeSpan extends Span {
+    FakeSpan(SpanContext ctx) {
+      super(ctx, null);
+    }
+
+    @Override
+    public void addAttributes(Map<String, AttributeValue> attributes) {}
+
+    @Override
+    public void addAnnotation(String description, Map<String, AttributeValue> attributes) {}
+
+    @Override
+    public void addAnnotation(Annotation annotation) {}
+
+    @Override
+    public void addNetworkEvent(NetworkEvent networkEvent) {}
+
+    @Override
+    public void addLink(Link link) {}
+
+    @Override
+    public void end(EndSpanOptions options) {}
+  }
+}