Retry DNS resolution when there is an error.
diff --git a/core/src/main/java/io/grpc/DnsNameResolver.java b/core/src/main/java/io/grpc/DnsNameResolver.java
index ce787eb..ae5fa04 100644
--- a/core/src/main/java/io/grpc/DnsNameResolver.java
+++ b/core/src/main/java/io/grpc/DnsNameResolver.java
@@ -31,16 +31,21 @@
package io.grpc;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-import io.grpc.internal.GrpcUtil;
import io.grpc.internal.SharedResourceHolder;
+import io.grpc.internal.SharedResourceHolder.Resource;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.URI;
+import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
@@ -50,21 +55,33 @@
*
* @see DnsNameResolverFactory
*/
-final class DnsNameResolver extends NameResolver {
+class DnsNameResolver extends NameResolver {
private final String authority;
private final String host;
private final int port;
+ private final Resource<ScheduledExecutorService> timerServiceResource;
+ private final Resource<ExecutorService> executorResource;
+ @GuardedBy("this")
+ private boolean shutdown;
+ @GuardedBy("this")
+ private ScheduledExecutorService timerService;
@GuardedBy("this")
private ExecutorService executor;
@GuardedBy("this")
+ private ScheduledFuture<?> resolutionTask;
+ @GuardedBy("this")
private boolean resolving;
@GuardedBy("this")
private Listener listener;
- DnsNameResolver(@Nullable String nsAuthority, String name, Attributes params) {
+ DnsNameResolver(@Nullable String nsAuthority, String name, Attributes params,
+ Resource<ScheduledExecutorService> timerServiceResource,
+ Resource<ExecutorService> executorResource) {
// TODO: if a DNS server is provided as nsAuthority, use it.
// https://www.captechconsulting.com/blogs/accessing-the-dusty-corners-of-dns-with-java
+ this.timerServiceResource = timerServiceResource;
+ this.executorResource = executorResource;
// Must prepend a "//" to the name when constructing a URI, otherwise it will be treated as an
// opaque URI, thus the authority and host of the resulted URI would be null.
URI nameUri = URI.create("//" + name);
@@ -85,39 +102,55 @@
}
@Override
- public String getServiceAuthority() {
+ public final String getServiceAuthority() {
return authority;
}
@Override
- public synchronized void start(Listener listener) {
- Preconditions.checkState(executor == null, "already started");
- executor = SharedResourceHolder.get(GrpcUtil.SHARED_CHANNEL_EXECUTOR);
- this.listener = listener;
+ public final synchronized void start(Listener listener) {
+ Preconditions.checkState(this.listener == null, "already started");
+ timerService = SharedResourceHolder.get(timerServiceResource);
+ executor = SharedResourceHolder.get(executorResource);
+ this.listener = Preconditions.checkNotNull(listener, "listener");
resolve();
}
@Override
- public synchronized void refresh() {
- Preconditions.checkState(executor != null, "not started");
+ public final synchronized void refresh() {
+ Preconditions.checkState(listener != null, "not started");
resolve();
}
- @GuardedBy("this")
- private void resolve() {
- if (resolving) {
- return;
- }
- resolving = true;
- final Listener savedListener = Preconditions.checkNotNull(listener);
- executor.execute(new Runnable() {
+ private final Runnable resolutionRunnable = new Runnable() {
@Override
public void run() {
InetAddress[] inetAddrs;
+ Listener savedListener;
+ synchronized (DnsNameResolver.this) {
+ // If this task is started by refresh(), there might already be a scheduled task.
+ if (resolutionTask != null) {
+ resolutionTask.cancel(false);
+ resolutionTask = null;
+ }
+ if (shutdown) {
+ return;
+ }
+ savedListener = listener;
+ resolving = true;
+ }
try {
try {
- inetAddrs = InetAddress.getAllByName(host);
- } catch (Exception e) {
+ inetAddrs = getAllByName(host);
+ } catch (UnknownHostException e) {
+ synchronized (DnsNameResolver.this) {
+ if (shutdown) {
+ return;
+ }
+ // Because timerService is the single-threaded GrpcUtil.TIMER_SERVICE in production,
+ // we need to delegate the blocking work to the executor
+ resolutionTask = timerService.schedule(resolutionRunnableOnExecutor,
+ 1, TimeUnit.MINUTES);
+ }
savedListener.onError(Status.UNAVAILABLE.withCause(e));
return;
}
@@ -135,17 +168,51 @@
}
}
}
- });
+ };
+
+ private final Runnable resolutionRunnableOnExecutor = new Runnable() {
+ @Override
+ public void run() {
+ synchronized (DnsNameResolver.this) {
+ if (!shutdown) {
+ executor.execute(resolutionRunnable);
+ }
+ }
+ }
+ };
+
+ // To be mocked out in tests
+ @VisibleForTesting
+ InetAddress[] getAllByName(String host) throws UnknownHostException {
+ return InetAddress.getAllByName(host);
+ }
+
+ @GuardedBy("this")
+ private void resolve() {
+ if (resolving || shutdown) {
+ return;
+ }
+ executor.execute(resolutionRunnable);
}
@Override
- public synchronized void shutdown() {
+ public final synchronized void shutdown() {
+ if (shutdown) {
+ return;
+ }
+ shutdown = true;
+ if (resolutionTask != null) {
+ resolutionTask.cancel(false);
+ }
+ if (timerService != null) {
+ timerService = SharedResourceHolder.release(timerServiceResource, timerService);
+ }
if (executor != null) {
- executor = SharedResourceHolder.release(GrpcUtil.SHARED_CHANNEL_EXECUTOR, executor);
+ executor = SharedResourceHolder.release(executorResource, executor);
}
}
- int getPort() {
+ final int getPort() {
return port;
}
}
diff --git a/core/src/main/java/io/grpc/DnsNameResolverFactory.java b/core/src/main/java/io/grpc/DnsNameResolverFactory.java
index 39c2238..6ee236d 100644
--- a/core/src/main/java/io/grpc/DnsNameResolverFactory.java
+++ b/core/src/main/java/io/grpc/DnsNameResolverFactory.java
@@ -33,6 +33,8 @@
import com.google.common.base.Preconditions;
+import io.grpc.internal.GrpcUtil;
+
import java.net.URI;
/**
@@ -63,7 +65,8 @@
Preconditions.checkArgument(targetPath.startsWith("/"),
"the path component (%s) of the target (%s) must start with '/'", targetPath, targetUri);
String name = targetPath.substring(1);
- return new DnsNameResolver(targetUri.getAuthority(), name, params);
+ return new DnsNameResolver(targetUri.getAuthority(), name, params, GrpcUtil.TIMER_SERVICE,
+ GrpcUtil.SHARED_CHANNEL_EXECUTOR);
} else {
return null;
}
diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java
index f604a87..0ed5d4e 100644
--- a/core/src/main/java/io/grpc/internal/GrpcUtil.java
+++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java
@@ -385,7 +385,7 @@
};
/**
- * Shared executor for managing channel timers.
+ * Shared single-threaded executor for managing channel timers.
*/
public static final Resource<ScheduledExecutorService> TIMER_SERVICE =
new Resource<ScheduledExecutorService>() {
diff --git a/core/src/test/java/io/grpc/DnsNameResolverTest.java b/core/src/test/java/io/grpc/DnsNameResolverTest.java
index e034a6d..92fad2b 100644
--- a/core/src/test/java/io/grpc/DnsNameResolverTest.java
+++ b/core/src/test/java/io/grpc/DnsNameResolverTest.java
@@ -33,24 +33,82 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertSame;
import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import io.grpc.internal.FakeClock;
+import io.grpc.internal.SharedResourceHolder.Resource;
+
+import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
import java.net.URI;
+import java.net.UnknownHostException;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
/** Unit tests for {@link DnsNameResolver}. */
@RunWith(JUnit4.class)
public class DnsNameResolverTest {
-
- private DnsNameResolverFactory factory = DnsNameResolverFactory.getInstance();
-
private static final int DEFAULT_PORT = 887;
private static final Attributes NAME_RESOLVER_PARAMS =
Attributes.newBuilder().set(NameResolver.Factory.PARAMS_DEFAULT_PORT, DEFAULT_PORT).build();
+ private final DnsNameResolverFactory factory = DnsNameResolverFactory.getInstance();
+ private final FakeClock fakeClock = new FakeClock();
+ private final Resource<ScheduledExecutorService> fakeTimerService =
+ new Resource<ScheduledExecutorService>() {
+ @Override
+ public ScheduledExecutorService create() {
+ return fakeClock.scheduledExecutorService;
+ }
+
+ @Override
+ public void close(ScheduledExecutorService instance) {
+ assertSame(fakeClock, instance);
+ }
+ };
+
+ private final Resource<ExecutorService> fakeExecutor =
+ new Resource<ExecutorService>() {
+ @Override
+ public ExecutorService create() {
+ return fakeClock.scheduledExecutorService;
+ }
+
+ @Override
+ public void close(ExecutorService instance) {
+ assertSame(fakeClock, instance);
+ }
+ };
+
+ @Mock
+ private NameResolver.Listener mockListener;
+ @Captor
+ private ArgumentCaptor<List<ResolvedServerInfo>> resultCaptor;
+ @Captor
+ private ArgumentCaptor<Status> statusCaptor;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ }
+
@Test
public void invalidDnsName() throws Exception {
testInvalidUri(new URI("dns", null, "/[invalid]", null));
@@ -73,6 +131,116 @@
"foo.googleapis.com:456", 456);
}
+ @Test
+ public void resolve() throws Exception {
+ InetAddress[] answer1 = createAddressList(2);
+ InetAddress[] answer2 = createAddressList(1);
+ String name = "foo.googleapis.com";
+ MockResolver resolver = new MockResolver(name, 81, answer1, answer2);
+ resolver.start(mockListener);
+ verify(mockListener).onUpdate(resultCaptor.capture(), any(Attributes.class));
+ assertEquals(name, resolver.invocations.poll());
+ assertAnswerMatches(answer1, 81, resultCaptor.getValue());
+ assertEquals(0, fakeClock.numPendingTasks());
+
+ resolver.refresh();
+ verify(mockListener, times(2)).onUpdate(resultCaptor.capture(), any(Attributes.class));
+ assertEquals(name, resolver.invocations.poll());
+ assertAnswerMatches(answer2, 81, resultCaptor.getValue());
+ assertEquals(0, fakeClock.numPendingTasks());
+
+ resolver.shutdown();
+ }
+
+ @Test
+ public void retry() throws Exception {
+ String name = "foo.googleapis.com";
+ UnknownHostException error = new UnknownHostException(name);
+ InetAddress[] answer = createAddressList(2);
+ MockResolver resolver = new MockResolver(name, 81, error, error, answer);
+ resolver.start(mockListener);
+ verify(mockListener).onError(statusCaptor.capture());
+ assertEquals(name, resolver.invocations.poll());
+ Status status = statusCaptor.getValue();
+ assertEquals(Status.Code.UNAVAILABLE, status.getCode());
+ assertSame(error, status.getCause());
+
+ // First retry scheduled
+ assertEquals(1, fakeClock.numPendingTasks());
+ fakeClock.forwardMillis(TimeUnit.MINUTES.toMillis(1) - 1);
+ assertEquals(1, fakeClock.numPendingTasks());
+
+ // First retry
+ fakeClock.forwardMillis(1);
+ verify(mockListener, times(2)).onError(statusCaptor.capture());
+ assertEquals(name, resolver.invocations.poll());
+ status = statusCaptor.getValue();
+ assertEquals(Status.Code.UNAVAILABLE, status.getCode());
+ assertSame(error, status.getCause());
+
+ // Second retry scheduled
+ assertEquals(1, fakeClock.numPendingTasks());
+ fakeClock.forwardMillis(TimeUnit.MINUTES.toMillis(1) - 1);
+ assertEquals(1, fakeClock.numPendingTasks());
+
+ // Second retry
+ fakeClock.forwardMillis(1);
+ assertEquals(0, fakeClock.numPendingTasks());
+ verify(mockListener).onUpdate(resultCaptor.capture(), any(Attributes.class));
+ assertEquals(name, resolver.invocations.poll());
+ assertAnswerMatches(answer, 81, resultCaptor.getValue());
+
+ verifyNoMoreInteractions(mockListener);
+ }
+
+ @Test
+ public void refreshCancelsScheduledRetry() throws Exception {
+ String name = "foo.googleapis.com";
+ UnknownHostException error = new UnknownHostException(name);
+ InetAddress[] answer = createAddressList(2);
+ MockResolver resolver = new MockResolver(name, 81, error, answer);
+ resolver.start(mockListener);
+ verify(mockListener).onError(statusCaptor.capture());
+ assertEquals(name, resolver.invocations.poll());
+ Status status = statusCaptor.getValue();
+ assertEquals(Status.Code.UNAVAILABLE, status.getCode());
+ assertSame(error, status.getCause());
+
+ // First retry scheduled
+ assertEquals(1, fakeClock.numPendingTasks());
+
+ resolver.refresh();
+ // Refresh cancelled the retry
+ assertEquals(0, fakeClock.numPendingTasks());
+ verify(mockListener).onUpdate(resultCaptor.capture(), any(Attributes.class));
+ assertEquals(name, resolver.invocations.poll());
+ assertAnswerMatches(answer, 81, resultCaptor.getValue());
+
+ verifyNoMoreInteractions(mockListener);
+ }
+
+ @Test
+ public void shutdownCancelsScheduledRetry() throws Exception {
+ String name = "foo.googleapis.com";
+ UnknownHostException error = new UnknownHostException(name);
+ MockResolver resolver = new MockResolver(name, 81, error);
+ resolver.start(mockListener);
+ verify(mockListener).onError(statusCaptor.capture());
+ assertEquals(name, resolver.invocations.poll());
+ Status status = statusCaptor.getValue();
+ assertEquals(Status.Code.UNAVAILABLE, status.getCode());
+ assertSame(error, status.getCause());
+
+ // Retry scheduled
+ assertEquals(1, fakeClock.numPendingTasks());
+
+ // Shutdown cancelled the retry
+ resolver.shutdown();
+ assertEquals(0, fakeClock.numPendingTasks());
+
+ verifyNoMoreInteractions(mockListener);
+ }
+
private void testInvalidUri(URI uri) {
try {
factory.newNameResolver(uri, NAME_RESOLVER_PARAMS);
@@ -88,4 +256,48 @@
assertEquals(expectedPort, resolver.getPort());
assertEquals(exportedAuthority, resolver.getServiceAuthority());
}
+
+ private byte lastByte = 0;
+
+ private InetAddress[] createAddressList(int n) throws UnknownHostException {
+ InetAddress[] list = new InetAddress[n];
+ for (int i = 0; i < n; i++) {
+ list[i] = InetAddress.getByAddress(new byte[] {127, 0, 0, ++lastByte});
+ }
+ return list;
+ }
+
+ private static void assertAnswerMatches(InetAddress[] addrs, int port,
+ List<ResolvedServerInfo> result) {
+ assertEquals(addrs.length, result.size());
+ for (int i = 0; i < addrs.length; i++) {
+ InetSocketAddress socketAddr = (InetSocketAddress) result.get(i).getAddress();
+ assertEquals("Addr " + i, port, socketAddr.getPort());
+ assertEquals("Addr " + i, addrs[i], socketAddr.getAddress());
+ }
+ }
+
+ private class MockResolver extends DnsNameResolver {
+ final LinkedList<Object> answers = new LinkedList<Object>();
+ final LinkedList<String> invocations = new LinkedList<String>();
+
+ MockResolver(String name, int defaultPort, Object ... answers) {
+ super(null, name, Attributes.newBuilder().set(
+ NameResolver.Factory.PARAMS_DEFAULT_PORT, defaultPort).build(), fakeTimerService,
+ fakeExecutor);
+ for (Object answer : answers) {
+ this.answers.add(answer);
+ }
+ }
+
+ @Override
+ InetAddress[] getAllByName(String host) throws UnknownHostException {
+ invocations.add(host);
+ Object answer = answers.poll();
+ if (answer instanceof UnknownHostException) {
+ throw (UnknownHostException) answer;
+ }
+ return (InetAddress[]) answer;
+ }
+ }
}
diff --git a/core/src/test/java/io/grpc/internal/FakeClock.java b/core/src/test/java/io/grpc/internal/FakeClock.java
index 9f5f8cc..d0db8f2 100644
--- a/core/src/test/java/io/grpc/internal/FakeClock.java
+++ b/core/src/test/java/io/grpc/internal/FakeClock.java
@@ -47,9 +47,9 @@
/**
* A manipulated clock that exports a {@link Ticker} and a {@link ScheduledExecutorService}.
*/
-final class FakeClock {
+public final class FakeClock {
- final ScheduledExecutorService scheduledExecutorService = new ScheduledExecutorImpl();
+ public final ScheduledExecutorService scheduledExecutorService = new ScheduledExecutorImpl();
final Ticker ticker = new Ticker() {
@Override public long read() {
return TimeUnit.MILLISECONDS.toNanos(currentTimeNanos);
@@ -183,12 +183,16 @@
}
}
- void forwardTime(long value, TimeUnit unit) {
+ public void forwardTime(long value, TimeUnit unit) {
currentTimeNanos += unit.toNanos(value);
runDueTasks();
}
- void forwardMillis(long millis) {
+ public void forwardMillis(long millis) {
forwardTime(millis, TimeUnit.MILLISECONDS);
}
+
+ public int numPendingTasks() {
+ return tasks.size();
+ }
}