/*
* Copyright 2015, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.internal;
import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import io.grpc.Attributes;
import io.grpc.ClientStreamTracer;
import io.grpc.Codec;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.AbstractClientStream.TransportState;
import io.grpc.internal.MessageFramerTest.ByteWritableBuffer;
import java.io.ByteArrayInputStream;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/**
 * Test for {@link AbstractClientStream}. This class tries to test functionality in
 * AbstractClientStream, but not in any super classes.
 */
@RunWith(JUnit4.class)
public class AbstractClientStreamTest {
/** JUnit rule through which individual tests declare the exception type they expect. */
@Rule public final ExpectedException thrown = ExpectedException.none();
/** Default stats/trace context (the NOOP instance) handed to the streams under test. */
private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP;
/** Mock listener; populated by MockitoAnnotations.initMocks in setUp(). */
@Mock private ClientStreamListener mockListener;
/** Captor for Status arguments; not referenced by any test in this class. */
@Captor private ArgumentCaptor<Status> statusCaptor;
/** Initializes the {@code @Mock} and {@code @Captor} annotated fields before each test. */
@Before
public void setUp() {
MockitoAnnotations.initMocks(this);
}
/**
 * Buffer allocator passed to every stream under test. Each call returns a new
 * {@link ByteWritableBuffer} created with the requested capacity hint.
 */
private final WritableBufferAllocator allocator = new WritableBufferAllocator() {
@Override
public WritableBuffer allocate(int capacityHint) {
return new ByteWritableBuffer(capacityHint);
}
};
/**
 * Cancelling with any non-OK status code must be accepted, whereas cancelling with the OK code
 * must be rejected with an {@link IllegalArgumentException}. A fresh stream is created for every
 * status code.
 */
@Test
public void cancel_doNotAcceptOk() {
  for (Code candidate : Code.values()) {
    ClientStreamListener noopListener = new NoopClientStreamListener();
    AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
    clientStream.start(noopListener);

    if (candidate == Code.OK) {
      try {
        clientStream.cancel(Status.fromCodeValue(candidate.value()));
        fail();
      } catch (IllegalArgumentException expected) {
        // ignore
      }
      continue;
    }

    // Every other code is a legitimate cancellation reason and must not throw.
    clientStream.cancel(Status.fromCodeValue(candidate.value()));
  }
}
/** Passing a {@code null} status to cancel() must raise a {@link NullPointerException}. */
@Test
public void cancel_failsOnNull() {
  ClientStreamListener noopListener = new NoopClientStreamListener();
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  clientStream.start(noopListener);

  thrown.expect(NullPointerException.class);
  clientStream.cancel(null);
}
/**
 * Cancelling the same stream twice must close the listener exactly once. The sink used here
 * reports the cancellation status back through the transport state whenever its cancel() is
 * invoked.
 */
@Test
public void cancel_notifiesOnlyOnce() {
  final BaseTransportState transportState = new BaseTransportState(statsTraceCtx);
  BaseSink reportingSink = new BaseSink() {
    @Override
    public void cancel(Status reason) {
      // Cancel should eventually result in a transportReportStatus on the transport thread
      transportState.transportReportStatus(reason, true /* stop delivery */, new Metadata());
    }
  };
  AbstractClientStream clientStream =
      new BaseAbstractClientStream(allocator, transportState, reportingSink, statsTraceCtx);
  clientStream.start(mockListener);

  clientStream.cancel(Status.DEADLINE_EXCEEDED);
  clientStream.cancel(Status.DEADLINE_EXCEEDED);

  // verify() without a mode checks for exactly one invocation.
  verify(mockListener).closed(any(Status.class), any(Metadata.class));
}
/** start() must reject a {@code null} listener with a {@link NullPointerException}. */
@Test
public void startFailsOnNullListener() {
  AbstractClientStream unstartedStream = new BaseAbstractClientStream(allocator, statsTraceCtx);

  thrown.expect(NullPointerException.class);
  unstartedStream.start(null);
}
/** A second start() on the same stream must raise an {@link IllegalStateException}. */
@Test
public void cantCallStartTwice() {
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  // The first start is legal.
  clientStream.start(mockListener);

  // The repeated start is the call expected to throw.
  thrown.expect(IllegalStateException.class);
  clientStream.start(mockListener);
}
/**
 * Delivering a {@code null} frame to the transport state of a started stream must raise a
 * {@link NullPointerException}.
 */
@Test
public void inboundDataReceived_failsOnNullFrame() {
  ClientStreamListener noopListener = new NoopClientStreamListener();
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  clientStream.start(noopListener);
  TransportState transportState = clientStream.transportState();

  thrown.expect(NullPointerException.class);
  transportState.inboundDataReceived(null);
}
/** Headers handed to the transport state must be forwarded to the listener's headersRead(). */
@Test
public void inboundHeadersReceived_notifiesListener() {
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  clientStream.start(mockListener);
  Metadata receivedHeaders = new Metadata();

  TransportState transportState = clientStream.transportState();
  transportState.inboundHeadersReceived(receivedHeaders);

  verify(mockListener).headersRead(receivedHeaders);
}
/**
 * Once a status has been reported through the transport state, delivering headers afterwards must
 * raise an {@link IllegalStateException}.
 */
@Test
public void inboundHeadersReceived_failsIfStatusReported() {
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  clientStream.start(mockListener);
  TransportState transportState = clientStream.transportState();
  transportState.transportReportStatus(Status.CANCELLED, false, new Metadata());

  thrown.expect(IllegalStateException.class);
  transportState.inboundHeadersReceived(new Metadata());
}
/**
 * Headers that carry the gzip message encoding must be accepted and forwarded to the listener.
 */
@Test
public void inboundHeadersReceived_acceptsGzipEncoding() {
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  clientStream.start(mockListener);

  Metadata gzipHeaders = new Metadata();
  String gzipEncoding = new Codec.Gzip().getMessageEncoding();
  gzipHeaders.put(GrpcUtil.MESSAGE_ENCODING_KEY, gzipEncoding);

  TransportState transportState = clientStream.transportState();
  transportState.inboundHeadersReceived(gzipHeaders);

  verify(mockListener).headersRead(gzipHeaders);
}
/**
 * Headers that carry the identity message encoding must be accepted and forwarded to the
 * listener.
 */
@Test
public void inboundHeadersReceived_acceptsIdentityEncoding() {
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  clientStream.start(mockListener);

  Metadata identityHeaders = new Metadata();
  String identityEncoding = Codec.Identity.NONE.getMessageEncoding();
  identityHeaders.put(GrpcUtil.MESSAGE_ENCODING_KEY, identityEncoding);

  TransportState transportState = clientStream.transportState();
  transportState.inboundHeadersReceived(identityHeaders);

  verify(mockListener).headersRead(identityHeaders);
}
/**
 * A status reported by the transport while a message has only been partially delivered
 * (simulating a stream reset) must still close the listener.
 */
@Test
public void rstStreamClosesStream() {
  AbstractClientStream clientStream = new BaseAbstractClientStream(allocator, statsTraceCtx);
  clientStream.start(mockListener);
  TransportState transportState = clientStream.transportState();

  // The application will call request when waiting for a message, which will in turn call this
  // on the transport thread.
  transportState.requestMessagesFromDeframer(1);

  // Send first byte of 2 byte message
  byte[] partialMessage = {0, 0, 0, 0, 2, 1};
  transportState.deframe(ReadableBuffers.wrap(partialMessage), false);

  // Simulate getting a reset
  Status resetStatus = Status.INTERNAL;
  transportState.transportReportStatus(
      resetStatus, false /* stop delivery */, new Metadata());

  verify(mockListener).closed(any(Status.class), any(Metadata.class));
}
/**
 * Exercises a stream created with {@code useGet == true}.
 *
 * <p>Expectations checked against the mocked sink and the spied tracer:
 * <ul>
 *   <li>writing a message does not trigger {@code writeHeaders};</li>
 *   <li>{@code halfClose()} triggers exactly one {@code writeHeaders} with a non-null payload;</li>
 *   <li>{@code writeFrame} is never called;</li>
 *   <li>the tracer sees one outbound message with wire size 1 and uncompressed size 1, and no
 *       other interaction.</li>
 * </ul>
 */
@Test
public void getRequest() {
  AbstractClientStream.Sink sink = mock(AbstractClientStream.Sink.class);
  ClientStreamTracer tracer = spy(new ClientStreamTracer() {});
  // Named differently from the statsTraceCtx field so that the field is not shadowed.
  StatsTraceContext tracingCtx = new StatsTraceContext(new StreamTracer[] {tracer});
  AbstractClientStream stream = new BaseAbstractClientStream(allocator,
      new BaseTransportState(tracingCtx), sink, tracingCtx, true);
  stream.start(mockListener);
  stream.writeMessage(new ByteArrayInputStream(new byte[1]));
  // writeHeaders will be delayed since we're sending a GET request.
  verify(sink, never()).writeHeaders(any(Metadata.class), any(byte[].class));
  // halfClose will trigger writeHeaders.
  stream.halfClose();
  ArgumentCaptor<byte[]> payloadCaptor = ArgumentCaptor.forClass(byte[].class);
  verify(sink).writeHeaders(any(Metadata.class), payloadCaptor.capture());
  assertTrue(payloadCaptor.getValue() != null);
  // GET requests don't have BODY.
  verify(sink, never())
      .writeFrame(any(WritableBuffer.class), any(Boolean.class), any(Boolean.class));
  // The single one-byte message written above is the only thing the tracer may observe.
  verify(tracer).outboundMessage();
  verify(tracer).outboundWireSize(1);
  verify(tracer).outboundUncompressedSize(1);
  verifyNoMoreInteractions(tracer);
}
/**
 * Minimal concrete {@link AbstractClientStream} for testing. It hands back the transport state
 * and sink supplied at construction time and implements the remaining hooks as no-ops.
 */
private static class BaseAbstractClientStream extends AbstractClientStream {
  private final TransportState transportState;
  private final Sink streamSink;

  /** Creates a non-GET stream backed by a new BaseTransportState and a no-op BaseSink. */
  public BaseAbstractClientStream(WritableBufferAllocator allocator,
      StatsTraceContext statsTraceCtx) {
    this(allocator, new BaseTransportState(statsTraceCtx), new BaseSink(), statsTraceCtx, false);
  }

  /** Creates a non-GET stream that uses the given transport state and sink. */
  public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state,
      Sink sink, StatsTraceContext statsTraceCtx) {
    this(allocator, state, sink, statsTraceCtx, false);
  }

  /**
   * Creates a stream that uses the given transport state and sink; {@code useGet} is passed
   * through to the superclass constructor together with empty headers.
   */
  public BaseAbstractClientStream(WritableBufferAllocator allocator, TransportState state,
      Sink sink, StatsTraceContext statsTraceCtx, boolean useGet) {
    super(allocator, statsTraceCtx, new Metadata(), useGet);
    this.transportState = state;
    this.streamSink = sink;
  }

  @Override
  public TransportState transportState() {
    return transportState;
  }

  @Override
  protected Sink abstractClientStreamSink() {
    return streamSink;
  }

  @Override
  public Attributes getAttributes() {
    return Attributes.EMPTY;
  }

  @Override
  public void setAuthority(String authority) {
    // Intentionally a no-op.
  }

  @Override
  public void setMaxInboundMessageSize(int maxSize) {
    // Intentionally a no-op.
  }

  @Override
  public void setMaxOutboundMessageSize(int maxSize) {
    // Intentionally a no-op.
  }
}
/**
 * {@link AbstractClientStream.Sink} whose operations all do nothing. Tests subclass it to override
 * only the callbacks they care about.
 */
private static class BaseSink implements AbstractClientStream.Sink {
  @Override
  public void request(int numMessages) {
    // Intentionally a no-op.
  }

  @Override
  public void writeHeaders(Metadata headers, byte[] payload) {
    // Intentionally a no-op.
  }

  @Override
  public void writeFrame(WritableBuffer frame, boolean endOfStream, boolean flush) {
    // Intentionally a no-op.
  }

  @Override
  public void cancel(Status reason) {
    // Intentionally a no-op.
  }
}
/**
 * Transport state for testing, constructed with {@code DEFAULT_MAX_MESSAGE_SIZE}. Its
 * {@code deframeFailed} and {@code bytesRead} callbacks do nothing.
 */
private static class BaseTransportState extends AbstractClientStream.TransportState {
  public BaseTransportState(StatsTraceContext statsTraceCtx) {
    super(DEFAULT_MAX_MESSAGE_SIZE, statsTraceCtx);
  }

  @Override
  public void bytesRead(int processedBytes) {
    // Intentionally a no-op.
  }

  @Override
  protected void deframeFailed(Throwable cause) {
    // Intentionally a no-op.
  }
}
}