Improve generics, nullability for NavigationProvider API

Avoid some unnecessary casts and null checks in the consumers
of the NavigationProvider APIs (namely, Navigator subclasses).

Test: Existing tests pass
Change-Id: Ic0c7c65f0eecc75f197492f13ccacfccf87cc4c2
diff --git a/navigation/runtime/src/androidTest/java/android/arch/navigation/BaseNavControllerTest.java b/navigation/runtime/src/androidTest/java/android/arch/navigation/BaseNavControllerTest.java
index 205de13..5ec0912 100644
--- a/navigation/runtime/src/androidTest/java/android/arch/navigation/BaseNavControllerTest.java
+++ b/navigation/runtime/src/androidTest/java/android/arch/navigation/BaseNavControllerTest.java
@@ -119,7 +119,7 @@
         BaseNavigationActivity activity = launchActivity();
         NavController navController = activity.getNavController();
         navController.setGraph(R.navigation.nav_simple);
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.start_test));
         assertThat(navigator.mBackStack.size(), is(1));
@@ -218,7 +218,7 @@
 
         navController.navigate(R.id.second_test, args);
 
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         args = navigator.mBackStack.peekLast().second;
         assertThat(args, is(notNullValue(Bundle.class)));
@@ -231,7 +231,7 @@
         BaseNavigationActivity activity = launchActivity();
         NavController navController = activity.getNavController();
         navController.setGraph(R.navigation.nav_simple);
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.start_test));
         assertThat(navigator.mBackStack.size(), is(1));
@@ -250,7 +250,7 @@
         BaseNavigationActivity activity = launchActivity();
         NavController navController = activity.getNavController();
         navController.setGraph(R.navigation.nav_simple);
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.start_test));
         assertThat(navigator.mBackStack.size(), is(1));
@@ -271,7 +271,7 @@
         NavController navController = activity.getNavController();
         navController.setGraph(R.navigation.nav_simple);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.start_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(1));
 
@@ -287,7 +287,7 @@
         navController.setGraph(R.navigation.nav_simple);
         navController.navigate(R.id.second_test);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.second_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(2));
 
@@ -303,7 +303,7 @@
         navController.setGraph(R.navigation.nav_simple);
         navController.navigate(R.id.second_test);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.second_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(2));
 
@@ -319,7 +319,7 @@
         navController.setGraph(R.navigation.nav_simple);
         navController.navigate(R.id.second_test);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.second_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(2));
 
@@ -352,7 +352,7 @@
         navController.setGraph(R.navigation.nav_simple);
         navController.navigate(R.id.second_test);
         assertThat(navController.getCurrentDestination().getId(), is(R.id.second_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(2));
 
@@ -372,7 +372,7 @@
         args.putString(TEST_OVERRIDDEN_VALUE_ARG, TEST_OVERRIDDEN_VALUE_ARG_VALUE);
         navController.navigate(R.id.second, args);
 
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         Bundle returnedArgs = navigator.mBackStack.peekLast().second;
         assertThat(returnedArgs, is(notNullValue(Bundle.class)));
@@ -395,7 +395,7 @@
         NavController navController = activity.getNavController();
 
         assertThat(navController.getCurrentDestination().getId(), is(R.id.deep_link_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(2));
 
@@ -430,7 +430,7 @@
         NavController navController = activity.getNavController();
 
         assertThat(navController.getCurrentDestination().getId(), is(R.id.deep_link_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(2));
         assertThat(navigator.mBackStack.peekLast().second.getString(TEST_ARG), is(TEST_ARG_VALUE));
@@ -455,7 +455,7 @@
         navController.setGraph(R.navigation.nav_deep_link);
 
         assertThat(navController.getCurrentDestination().getId(), is(R.id.deep_link_test));
-        TestNavigator navigator = (TestNavigator) navController.getNavigatorProvider()
+        TestNavigator navigator = navController.getNavigatorProvider()
                 .getNavigator(TestNavigator.class);
         assertThat(navigator.mBackStack.size(), is(2));
         assertThat(navigator.mBackStack.peekLast().second.getString(TEST_ARG), is(TEST_ARG_VALUE));
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/ActivityNavigator.java b/navigation/runtime/src/main/java/android/arch/navigation/ActivityNavigator.java
index 30195b7..f25fc5e 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/ActivityNavigator.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/ActivityNavigator.java
@@ -165,9 +165,7 @@
          *                          will be associated with.
          */
         public Destination(@NonNull NavigatorProvider navigatorProvider) {
-            //noinspection unchecked
-            this((Navigator<? extends Destination>) navigatorProvider
-                    .getNavigator(ActivityNavigator.class));
+            this(navigatorProvider.getNavigator(ActivityNavigator.class));
         }
 
         /**
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/FragmentNavigator.java b/navigation/runtime/src/main/java/android/arch/navigation/FragmentNavigator.java
index 48f8701..708ad39 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/FragmentNavigator.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/FragmentNavigator.java
@@ -167,9 +167,7 @@
          *                          will be associated with.
          */
         public Destination(@NonNull NavigatorProvider navigatorProvider) {
-            //noinspection unchecked
-            this((Navigator<? extends Destination>) navigatorProvider
-                    .getNavigator(FragmentNavigator.class));
+            this(navigatorProvider.getNavigator(FragmentNavigator.class));
         }
 
         /**
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/NavController.java b/navigation/runtime/src/main/java/android/arch/navigation/NavController.java
index 4e5f125..8f015b3 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/NavController.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/NavController.java
@@ -73,18 +73,19 @@
     private final Deque<NavDestination> mBackStack = new ArrayDeque<>();
 
     private final SimpleNavigatorProvider mNavigatorProvider = new SimpleNavigatorProvider() {
+        @Nullable
         @Override
-        public void addNavigator(String name, Navigator<? extends NavDestination> navigator) {
-            Navigator previousNavigator = getNavigator(name);
-            super.addNavigator(name, navigator);
+        public Navigator<? extends NavDestination> addNavigator(@NonNull String name,
+                @NonNull Navigator<? extends NavDestination> navigator) {
+            Navigator<? extends NavDestination> previousNavigator =
+                    super.addNavigator(name, navigator);
             if (previousNavigator != navigator) {
                 if (previousNavigator != null) {
                     previousNavigator.removeOnNavigatorNavigatedListener(mOnNavigatedListener);
                 }
-                if (navigator != null) {
-                    navigator.addOnNavigatorNavigatedListener(mOnNavigatedListener);
-                }
+                navigator.addOnNavigatorNavigatedListener(mOnNavigatedListener);
             }
+            return previousNavigator;
         }
     };
 
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/NavDeepLinkBuilder.java b/navigation/runtime/src/main/java/android/arch/navigation/NavDeepLinkBuilder.java
index 75239e3..545ee56 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/NavDeepLinkBuilder.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/NavDeepLinkBuilder.java
@@ -275,10 +275,14 @@
             addNavigator(new NavGraphNavigator(context));
         }
 
+        @NonNull
         @Override
-        public Navigator<? extends NavDestination> getNavigator(String name) {
-            Navigator<? extends NavDestination> navigator = super.getNavigator(name);
-            return navigator != null ? navigator : mDestNavigator;
+        public Navigator<? extends NavDestination> getNavigator(@NonNull String name) {
+            try {
+                return super.getNavigator(name);
+            } catch (IllegalStateException e) {
+                return mDestNavigator;
+            }
         }
     }
 }
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/NavGraph.java b/navigation/runtime/src/main/java/android/arch/navigation/NavGraph.java
index b416987..dd9fea8 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/NavGraph.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/NavGraph.java
@@ -51,9 +51,7 @@
      *                          will be associated with.
      */
     public NavGraph(@NonNull NavigatorProvider navigatorProvider) {
-        //noinspection unchecked
-        this((Navigator<? extends NavGraph>) navigatorProvider
-                .getNavigator(NavGraphNavigator.class));
+        this(navigatorProvider.getNavigator(NavGraphNavigator.class));
     }
 
     /**
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/NavInflater.java b/navigation/runtime/src/main/java/android/arch/navigation/NavInflater.java
index 5432dfe..403d374 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/NavInflater.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/NavInflater.java
@@ -72,18 +72,6 @@
     }
 
     /**
-     * Retrieve a Navigator with the given name from the {@link NavigatorProvider} used to
-     * construct this class.
-     *
-     * @param name
-     * @return
-     */
-    @Nullable
-    public Navigator getNavigator(@NonNull String name) {
-        return mNavigatorProvider.getNavigator(name);
-    }
-
-    /**
      * Inflates {@link NavGraph navigation graph} as specified in the application manifest.
      *
      * <p>Applications may declare a graph resource in their manifest instead of declaring
@@ -144,12 +132,7 @@
 
     private NavDestination inflate(Resources res, XmlResourceParser parser, AttributeSet attrs)
             throws XmlPullParserException, IOException {
-        String navigatorName = parser.getName();
-        Navigator navigator = getNavigator(parser.getName());
-        if (navigator == null) {
-            throw new IllegalArgumentException("Could not inflate " + navigatorName
-                    + ". You must call NavController.addNavigator() for each navigation type.");
-        }
+        Navigator navigator = mNavigatorProvider.getNavigator(parser.getName());
         final NavDestination dest = navigator.createDestination();
 
         dest.onInflate(mContext, attrs);
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/NavigatorProvider.java b/navigation/runtime/src/main/java/android/arch/navigation/NavigatorProvider.java
index 558fedc..9056287 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/NavigatorProvider.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/NavigatorProvider.java
@@ -31,11 +31,15 @@
      * @param navigatorClass class of the navigator to return
      * @return the registered navigator with the given {@link Navigator.Name}
      *
+     * @throws IllegalArgumentException if the Navigator does not have a
+     * {@link Navigator.Name Navigator.Name annotation}
+     * @throws IllegalStateException if the Navigator has not been added
+     *
      * @see #addNavigator(Navigator)
      */
-    @Nullable
-    Navigator<? extends NavDestination> getNavigator(
-            @NonNull Class<? extends Navigator> navigatorClass);
+    @NonNull
+    <D extends NavDestination, T extends Navigator<? extends D>> T getNavigator(
+            @NonNull Class<T> navigatorClass);
 
     /**
      * Retrieves a registered {@link Navigator} by name.
@@ -43,10 +47,13 @@
      * @param name name of the navigator to return
      * @return the registered navigator with the given name
      *
+     * @throws IllegalStateException if the Navigator has not been added
+     *
      * @see #addNavigator(String, Navigator)
      */
-    @Nullable
-    Navigator<? extends NavDestination> getNavigator(@NonNull String name);
+    @NonNull
+    <D extends NavDestination, T extends Navigator<? extends D>> T getNavigator(
+            @NonNull String name);
 
     /**
      * Register a navigator using the name provided by the
@@ -55,8 +62,12 @@
      * already registered, this new navigator will replace it.
      *
      * @param navigator navigator to add
+     * @return the previously added Navigator for the name provided by the
+     * {@link Navigator.Name Navigator.Name annotation}, if any
      */
-    void addNavigator(@NonNull Navigator<? extends NavDestination> navigator);
+    @Nullable
+    Navigator<? extends NavDestination> addNavigator(
+            @NonNull Navigator<? extends NavDestination> navigator);
 
     /**
      * Register a navigator by name. {@link NavDestination destinations} may refer to any
@@ -65,6 +76,9 @@
      *
      * @param name name for this navigator
      * @param navigator navigator to add
+     * @return the previously added Navigator for the given name, if any
      */
-    void addNavigator(@NonNull String name, @NonNull Navigator<? extends NavDestination> navigator);
+    @Nullable
+    Navigator<? extends NavDestination> addNavigator(@NonNull String name,
+            @NonNull Navigator<? extends NavDestination> navigator);
 }
diff --git a/navigation/runtime/src/main/java/android/arch/navigation/SimpleNavigatorProvider.java b/navigation/runtime/src/main/java/android/arch/navigation/SimpleNavigatorProvider.java
index 514ec96..2e25943 100644
--- a/navigation/runtime/src/main/java/android/arch/navigation/SimpleNavigatorProvider.java
+++ b/navigation/runtime/src/main/java/android/arch/navigation/SimpleNavigatorProvider.java
@@ -17,6 +17,7 @@
 package android.arch.navigation;
 
 import android.support.annotation.NonNull;
+import android.support.annotation.Nullable;
 
 import java.util.HashMap;
 
@@ -28,9 +29,10 @@
     private final HashMap<String, Navigator<? extends NavDestination>> mNavigators =
             new HashMap<>();
 
+    @NonNull
     @Override
-    public Navigator<? extends NavDestination> getNavigator(
-            @NonNull Class<? extends Navigator> navigatorClass) {
+    public <D extends NavDestination, T extends Navigator<? extends D>> T getNavigator(
+            @NonNull Class<T> navigatorClass) {
         Navigator.Name annotation = navigatorClass.getAnnotation(Navigator.Name.class);
         String name = annotation != null ? annotation.value() : null;
         if (!validateName(name)) {
@@ -41,17 +43,27 @@
         return getNavigator(name);
     }
 
+    @NonNull
     @Override
-    public Navigator<? extends NavDestination> getNavigator(@NonNull String name) {
+    public <D extends NavDestination, T extends Navigator<? extends D>> T getNavigator(
+            @NonNull String name) {
         if (!validateName(name)) {
             throw new IllegalArgumentException("navigator name cannot be an empty string");
         }
 
-        return mNavigators.get(name);
+        Navigator<? extends NavDestination> navigator = mNavigators.get(name);
+        if (navigator == null) {
+            throw new IllegalStateException("Could not find Navigator with name \"" + name
+                    + "\". You must call NavController.addNavigator() for each navigation type.");
+        }
+        //noinspection unchecked
+        return (T) navigator;
     }
 
+    @Nullable
     @Override
-    public void addNavigator(@NonNull Navigator<? extends NavDestination> navigator) {
+    public Navigator<? extends NavDestination> addNavigator(
+            @NonNull Navigator<? extends NavDestination> navigator) {
         Navigator.Name annotation = navigator.getClass().getAnnotation(Navigator.Name.class);
         String name = annotation != null ? annotation.value() : null;
         if (!validateName(name)) {
@@ -59,16 +71,17 @@
                     + navigator.getClass().getSimpleName());
         }
 
-        addNavigator(name, navigator);
+        return addNavigator(name, navigator);
     }
 
+    @Nullable
     @Override
-    public void addNavigator(@NonNull String name,
+    public Navigator<? extends NavDestination> addNavigator(@NonNull String name,
             @NonNull Navigator<? extends NavDestination> navigator) {
         if (!validateName(name)) {
             throw new IllegalArgumentException("navigator name cannot be an empty string");
         }
-        mNavigators.put(name, navigator);
+        return mNavigators.put(name, navigator);
     }
 
     private boolean validateName(String name) {
diff --git a/navigation/runtime/src/test/java/android/arch/navigation/SimpleNavigatorProviderTest.java b/navigation/runtime/src/test/java/android/arch/navigation/SimpleNavigatorProviderTest.java
index 9d0dfba..b1c5777 100644
--- a/navigation/runtime/src/test/java/android/arch/navigation/SimpleNavigatorProviderTest.java
+++ b/navigation/runtime/src/test/java/android/arch/navigation/SimpleNavigatorProviderTest.java
@@ -17,8 +17,8 @@
 package android.arch.navigation;
 
 import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.fail;
 
 import android.support.test.filters.SmallTest;
 
@@ -30,11 +30,17 @@
 @RunWith(JUnit4.class)
 @SmallTest
 public class SimpleNavigatorProviderTest {
-    @Test(expected = IllegalArgumentException.class)
+    @Test
     public void addWithMissingAnnotationName() {
         SimpleNavigatorProvider provider = new SimpleNavigatorProvider();
         Navigator navigator = new NoNameNavigator();
-        provider.addNavigator(navigator);
+        try {
+            provider.addNavigator(navigator);
+            fail("Adding a provider with no @Navigator.Name should cause an "
+                    + "IllegalArgumentException");
+        } catch (IllegalArgumentException e) {
+            // Expected
+        }
     }
 
     @Test
@@ -51,15 +57,26 @@
         Navigator navigator = new EmptyNavigator();
         provider.addNavigator("name", navigator);
         assertThat(provider.getNavigator("name"), is(navigator));
-        assertThat(provider.getNavigator(EmptyNavigator.class), is(nullValue(Navigator.class)));
+        try {
+            provider.getNavigator(EmptyNavigator.class);
+            fail("getNavigator(Class) with an invalid name should cause an IllegalStateException");
+        } catch (IllegalStateException e) {
+            // Expected
+        }
     }
 
-    @Test(expected = IllegalArgumentException.class)
+    @Test
     public void addWithExplicitNameGetWithMissingAnnotationName() {
         SimpleNavigatorProvider provider = new SimpleNavigatorProvider();
         Navigator navigator = new NoNameNavigator();
         provider.addNavigator("name", navigator);
-        assertThat(provider.getNavigator(NoNameNavigator.class), is(navigator));
+        try {
+            provider.getNavigator(NoNameNavigator.class);
+            fail("getNavigator(Class) with no @Navigator.Name should cause "
+                    + "an IllegalArgumentException");
+        } catch (IllegalArgumentException e) {
+            // Expected
+        }
     }
 
     @Test