core: OverrideAuthorityNameResolverFactory should forward refresh() (#3061)

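The NameResolver returned by OverrideAuthorityNameResolverFactory forwarded only start() and shutdown() to its delegate, so refresh() calls never reached the underlying resolver. This change adds ForwardingNameResolver, which forwards getServiceAuthority(), start(), shutdown() and refresh() to a delegate, and makes the authority-override resolver extend it. It also moves the authority-override wrapping from ManagedChannelImpl into AbstractManagedChannelImplBuilder.getNameResolverFactory().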

This is essentially the same change as e4f1f39, which was merged into the v1.4.x branch; this PR uses the new license header.

Fixes #3061
diff --git a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java
index 386d95c..29ce532 100644
--- a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java
+++ b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java
@@ -93,6 +93,9 @@
 
   private final List<ClientInterceptor> interceptors = new ArrayList<ClientInterceptor>();
 
+  // Read through getNameResolverFactory(), which wraps it when authorityOverride is set.
+  private NameResolver.Factory nameResolverFactory = DEFAULT_NAME_RESOLVER_FACTORY;
+
   final String target;
 
   @Nullable
@@ -101,10 +104,9 @@
   @Nullable
   String userAgent;
 
+  @VisibleForTesting
   @Nullable
   String authorityOverride;
 
-  NameResolver.Factory nameResolverFactory = DEFAULT_NAME_RESOLVER_FACTORY;
-
   LoadBalancer.Factory loadBalancerFactory = DEFAULT_LOAD_BALANCER_FACTORY;
 
@@ -367,6 +369,17 @@
     return Attributes.EMPTY;
   }
 
+  /**
+   * Returns the channel's {@link NameResolver.Factory}, wrapped to apply any authority override.
+   */
+  NameResolver.Factory getNameResolverFactory() {
+    if (authorityOverride == null) {
+      return nameResolverFactory;
+    } else {
+      return new OverrideAuthorityNameResolverFactory(nameResolverFactory, authorityOverride);
+    }
+  }
+
   private static class DirectAddressNameResolverFactory extends NameResolver.Factory {
     final SocketAddress address;
     final String authority;
diff --git a/core/src/main/java/io/grpc/internal/ForwardingNameResolver.java b/core/src/main/java/io/grpc/internal/ForwardingNameResolver.java
new file mode 100644
index 0000000..768f82e
--- /dev/null
+++ b/core/src/main/java/io/grpc/internal/ForwardingNameResolver.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright 2017, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.internal;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import io.grpc.NameResolver;
+
+/**
+ * A {@link NameResolver} that forwards {@code getServiceAuthority}, {@code start},
+ * {@code shutdown} and {@code refresh} to a delegate, so that subclasses only override the
+ * methods whose behavior they change.
+ */
+abstract class ForwardingNameResolver extends NameResolver {
+  private final NameResolver delegate;
+
+  ForwardingNameResolver(NameResolver delegate) {
+    this.delegate = checkNotNull(delegate, "delegate can not be null");
+  }
+
+  @Override
+  public String getServiceAuthority() {
+    return delegate.getServiceAuthority();
+  }
+
+  @Override
+  public void start(Listener listener) {
+    delegate.start(listener);
+  }
+
+  @Override
+  public void shutdown() {
+    delegate.shutdown();
+  }
+
+  @Override
+  public void refresh() {
+    delegate.refresh();
+  }
+}
diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
index 4613676..2b5195c 100644
--- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
+++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
@@ -373,12 +373,7 @@
       Supplier<Stopwatch> stopwatchSupplier,
       List<ClientInterceptor> interceptors) {
     this.target = checkNotNull(builder.target, "target");
-    NameResolver.Factory tmpNameResolverFactory = builder.nameResolverFactory;
-    if (builder.authorityOverride != null) {
-      tmpNameResolverFactory = new OverrideAuthorityNameResolverFactory(
-          tmpNameResolverFactory, builder.authorityOverride);
-    }
-    this.nameResolverFactory = tmpNameResolverFactory;
+    this.nameResolverFactory = builder.getNameResolverFactory();
     this.nameResolverParams = checkNotNull(builder.getNameResolverParams(), "nameResolverParams");
     this.nameResolver = getNameResolver(target, nameResolverFactory, nameResolverParams);
     this.loadBalancerFactory =
diff --git a/core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java b/core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java
index 6aaab82..77a0b43 100644
--- a/core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java
+++ b/core/src/main/java/io/grpc/internal/OverrideAuthorityNameResolverFactory.java
@@ -49,21 +49,11 @@
     if (resolver == null) {
       return null;
     }
-    return new NameResolver() {
+    return new ForwardingNameResolver(resolver) {
       @Override
       public String getServiceAuthority() {
         return authorityOverride;
       }
-
-      @Override
-      public void start(Listener listener) {
-        resolver.start(listener);
-      }
-
-      @Override
-      public void shutdown() {
-        resolver.shutdown();
-      }
     };
   }
 
diff --git a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java
index 1bbc958..740049a 100644
--- a/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java
+++ b/core/src/test/java/io/grpc/internal/AbstractManagedChannelImplBuilderTest.java
@@ -16,17 +16,16 @@
 
 package io.grpc.internal;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 import com.google.common.util.concurrent.MoreExecutors;
-import io.grpc.Attributes;
 import io.grpc.CompressorRegistry;
 import io.grpc.DecompressorRegistry;
 import io.grpc.LoadBalancer;
@@ -74,22 +73,22 @@
 
   @Test
   public void nameResolverFactory_default() {
-    assertNotNull(builder.nameResolverFactory);
+    assertNotNull(builder.getNameResolverFactory());
   }
 
   @Test
   public void nameResolverFactory_normal() {
     NameResolver.Factory nameResolverFactory = mock(NameResolver.Factory.class);
     assertEquals(builder, builder.nameResolverFactory(nameResolverFactory));
-    assertEquals(nameResolverFactory, builder.nameResolverFactory);
+    assertEquals(nameResolverFactory, builder.getNameResolverFactory());
   }
 
   @Test
   public void nameResolverFactory_null() {
-    NameResolver.Factory defaultValue = builder.nameResolverFactory;
+    NameResolver.Factory defaultValue = builder.getNameResolverFactory();
     builder.nameResolverFactory(mock(NameResolver.Factory.class));
     assertEquals(builder, builder.nameResolverFactory(null));
-    assertEquals(defaultValue, builder.nameResolverFactory);
+    assertEquals(defaultValue, builder.getNameResolverFactory());
   }
 
   @Test(expected = IllegalStateException.class)
@@ -211,6 +210,15 @@
   }
 
   @Test
+  public void overrideAuthority_getNameResolverFactory() {
+    Builder builder = new Builder("target");
+    assertNull(builder.authorityOverride);
+    assertFalse(builder.getNameResolverFactory() instanceof OverrideAuthorityNameResolverFactory);
+    builder.overrideAuthority("google.com");
+    assertTrue(builder.getNameResolverFactory() instanceof OverrideAuthorityNameResolverFactory);
+  }
+
+  @Test
   public void makeTargetStringForDirectAddress_scopedIpv6() throws Exception {
     InetSocketAddress address = new InetSocketAddress("0:0:0:0:0:0:0:0%0", 10005);
     assertEquals("/0:0:0:0:0:0:0:0%0:10005", address.toString());
@@ -249,31 +257,6 @@
     assertEquals(TimeUnit.SECONDS.toMillis(30), builder.getIdleTimeoutMillis());
   }
 
-  @Test
-  public void overrideAuthorityNameResolverWrapsDelegateTest() {
-    NameResolver nameResolverMock = mock(NameResolver.class);
-    NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
-    when(wrappedFactory.newNameResolver(any(URI.class), any(Attributes.class)))
-      .thenReturn(nameResolverMock);
-    String override = "override:5678";
-    NameResolver.Factory factory =
-        new OverrideAuthorityNameResolverFactory(wrappedFactory, override);
-    NameResolver nameResolver = factory.newNameResolver(URI.create("dns:///localhost:443"),
-        Attributes.EMPTY);
-    assertNotNull(nameResolver);
-    assertEquals(override, nameResolver.getServiceAuthority());
-  }
-
-  @Test
-  public void overrideAuthorityNameResolverWontWrapNullTest() {
-    NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
-    when(wrappedFactory.newNameResolver(any(URI.class), any(Attributes.class))).thenReturn(null);
-    NameResolver.Factory factory =
-        new OverrideAuthorityNameResolverFactory(wrappedFactory, "override:5678");
-    assertEquals(null,
-        factory.newNameResolver(URI.create("dns:///localhost:443"), Attributes.EMPTY));
-  }
-
   static class Builder extends AbstractManagedChannelImplBuilder<Builder> {
     Builder(String target) {
       super(target);
diff --git a/core/src/test/java/io/grpc/internal/OverrideAuthorityNameResolverTest.java b/core/src/test/java/io/grpc/internal/OverrideAuthorityNameResolverTest.java
new file mode 100644
index 0000000..5dbdd83
--- /dev/null
+++ b/core/src/test/java/io/grpc/internal/OverrideAuthorityNameResolverTest.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright 2017, gRPC Authors All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.internal;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import io.grpc.Attributes;
+import io.grpc.NameResolver;
+import java.net.URI;
+import org.junit.Test;
+
+public class OverrideAuthorityNameResolverTest {
+  @Test
+  public void overridesAuthority() {
+    NameResolver nameResolverMock = mock(NameResolver.class);
+    NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
+    when(wrappedFactory.newNameResolver(any(URI.class), any(Attributes.class)))
+        .thenReturn(nameResolverMock);
+    String override = "override:5678";
+    NameResolver.Factory factory =
+        new OverrideAuthorityNameResolverFactory(wrappedFactory, override);
+    NameResolver nameResolver = factory.newNameResolver(URI.create("dns:///localhost:443"),
+        Attributes.EMPTY);
+    assertNotNull(nameResolver);
+    assertEquals(override, nameResolver.getServiceAuthority());
+  }
+
+  @Test
+  public void wontWrapNull() {
+    NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
+    when(wrappedFactory.newNameResolver(any(URI.class), any(Attributes.class))).thenReturn(null);
+    NameResolver.Factory factory =
+        new OverrideAuthorityNameResolverFactory(wrappedFactory, "override:5678");
+    assertNull(factory.newNameResolver(URI.create("dns:///localhost:443"), Attributes.EMPTY));
+  }
+
+  @Test
+  public void forwardsNonOverriddenCalls() {
+    NameResolver.Factory wrappedFactory = mock(NameResolver.Factory.class);
+    NameResolver mockResolver = mock(NameResolver.class);
+    when(wrappedFactory.newNameResolver(any(URI.class), any(Attributes.class)))
+        .thenReturn(mockResolver);
+    NameResolver.Factory factory =
+        new OverrideAuthorityNameResolverFactory(wrappedFactory, "override:5678");
+    NameResolver overrideResolver =
+        factory.newNameResolver(URI.create("dns:///localhost:443"), Attributes.EMPTY);
+    assertNotNull(overrideResolver);
+    NameResolver.Listener listener = mock(NameResolver.Listener.class);
+
+    overrideResolver.start(listener);
+    verify(mockResolver).start(listener);
+
+    overrideResolver.shutdown();
+    verify(mockResolver).shutdown();
+
+    overrideResolver.refresh();
+    verify(mockResolver).refresh();
+  }
+}