Close #22457: Honour load_tests in the start_dir of discovery.

We were not honouring load_tests in a package/__init__.py when that package was
the start_dir parameter, though we did when it was a child package. The fix
required a little care since it introduces the possibility of infinite
recursion: a package's load_tests may itself call discover on that package.
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index 811bedf..8c10ad1 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -65,6 +65,9 @@
     def __init__(self):
         super(TestLoader, self).__init__()
         self.errors = []
+        # Packages currently being loaded (in load_tests or while recursing
+        # into them) by discovery, to avoid infinite re-entrancy.
+        self._loading_packages = set()
 
     def loadTestsFromTestCase(self, testCaseClass):
         """Return a suite of all tests cases contained in testCaseClass"""
@@ -229,9 +232,13 @@
 
         If a test package name (directory with '__init__.py') matches the
         pattern then the package will be checked for a 'load_tests' function. If
-        this exists then it will be called with loader, tests, pattern.
+        this exists then it will be called with (loader, tests, pattern) unless
+        the package has already had load_tests called from the same discovery
+        invocation, in which case the package module object is not scanned for
+        tests - this ensures that infinite recursion does not happen when a
+        package's load_tests uses discover to further discover child tests.

-        If load_tests exists then discovery does  *not* recurse into the package,
+        If load_tests exists then discovery does *not* recurse into the package,
         load_tests is responsible for loading all tests in the package.
 
         The pattern is deliberately not stored as a loader attribute so that
@@ -355,69 +362,110 @@
 
     def _find_tests(self, start_dir, pattern, namespace=False):
         """Used by discovery. Yields test suites it loads."""
+        # Handle the __init__ in this package
+        name = self._get_name_from_path(start_dir)
+        # name is '.' when start_dir == top_level_dir (and top_level_dir is by
+        # definition not a package).
+        if name != '.' and name not in self._loading_packages:
+            # name is in self._loading_packages while we have called into
+            # loadTestsFromModule with name.
+            tests, should_recurse = self._find_test_path(
+                start_dir, pattern, namespace)
+            if tests is not None:
+                yield tests
+            if not should_recurse:
+                # Either an error occurred (or the package was skipped), or
+                # load_tests was used by the package.
+                return
+        # Handle the contents.
         paths = sorted(os.listdir(start_dir))
-
         for path in paths:
             full_path = os.path.join(start_dir, path)
-            if os.path.isfile(full_path):
-                if not VALID_MODULE_NAME.match(path):
-                    # valid Python identifiers only
-                    continue
-                if not self._match_path(path, full_path, pattern):
-                    continue
-                # if the test file matches, load it
+            tests, should_recurse = self._find_test_path(
+                full_path, pattern, namespace)
+            if tests is not None:
+                yield tests
+            if should_recurse:
+                # we found a package that didn't use load_tests.
                 name = self._get_name_from_path(full_path)
+                self._loading_packages.add(name)
                 try:
-                    module = self._get_module_from_name(name)
-                except case.SkipTest as e:
-                    yield _make_skipped_test(name, e, self.suiteClass)
-                except:
-                    error_case, error_message = \
-                        _make_failed_import_test(name, self.suiteClass)
-                    self.errors.append(error_message)
-                    yield error_case
-                else:
-                    mod_file = os.path.abspath(getattr(module, '__file__', full_path))
-                    realpath = _jython_aware_splitext(os.path.realpath(mod_file))
-                    fullpath_noext = _jython_aware_splitext(os.path.realpath(full_path))
-                    if realpath.lower() != fullpath_noext.lower():
-                        module_dir = os.path.dirname(realpath)
-                        mod_name = _jython_aware_splitext(os.path.basename(full_path))
-                        expected_dir = os.path.dirname(full_path)
-                        msg = ("%r module incorrectly imported from %r. Expected %r. "
-                               "Is this module globally installed?")
-                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
-                    yield self.loadTestsFromModule(module, pattern=pattern)
-            elif os.path.isdir(full_path):
-                if (not namespace and
-                    not os.path.isfile(os.path.join(full_path, '__init__.py'))):
-                    continue
+                    yield from self._find_tests(full_path, pattern, namespace)
+                finally:
+                    self._loading_packages.discard(name)
 
-                load_tests = None
-                tests = None
-                name = self._get_name_from_path(full_path)
+    def _find_test_path(self, full_path, pattern, namespace=False):
+        """Used by discovery.
+
+        Loads tests from a single file, or a directory's __init__.py when
+        passed the directory.
+
+        Returns a tuple (None_or_tests_from_file, should_recurse).
+        """
+        basename = os.path.basename(full_path)
+        if os.path.isfile(full_path):
+            if not VALID_MODULE_NAME.match(basename):
+                # valid Python identifiers only
+                return None, False
+            if not self._match_path(basename, full_path, pattern):
+                return None, False
+            # if the test file matches, load it
+            name = self._get_name_from_path(full_path)
+            try:
+                module = self._get_module_from_name(name)
+            except case.SkipTest as e:
+                return _make_skipped_test(name, e, self.suiteClass), False
+            except:
+                error_case, error_message = \
+                    _make_failed_import_test(name, self.suiteClass)
+                self.errors.append(error_message)
+                return error_case, False
+            else:
+                mod_file = os.path.abspath(
+                    getattr(module, '__file__', full_path))
+                realpath = _jython_aware_splitext(
+                    os.path.realpath(mod_file))
+                fullpath_noext = _jython_aware_splitext(
+                    os.path.realpath(full_path))
+                if realpath.lower() != fullpath_noext.lower():
+                    module_dir = os.path.dirname(realpath)
+                    mod_name = _jython_aware_splitext(
+                        os.path.basename(full_path))
+                    expected_dir = os.path.dirname(full_path)
+                    msg = ("%r module incorrectly imported from %r. Expected "
+                           "%r. Is this module globally installed?")
+                    raise ImportError(
+                        msg % (mod_name, module_dir, expected_dir))
+                return self.loadTestsFromModule(module, pattern=pattern), False
+        elif os.path.isdir(full_path):
+            if (not namespace and
+                not os.path.isfile(os.path.join(full_path, '__init__.py'))):
+                return None, False
+
+            load_tests = None
+            tests = None
+            name = self._get_name_from_path(full_path)
+            try:
+                package = self._get_module_from_name(name)
+            except case.SkipTest as e:
+                return _make_skipped_test(name, e, self.suiteClass), False
+            except:
+                error_case, error_message = \
+                    _make_failed_import_test(name, self.suiteClass)
+                self.errors.append(error_message)
+                return error_case, False
+            else:
+                load_tests = getattr(package, 'load_tests', None)
+                # Guard re-entry: the package's load_tests may call discover.
+                self._loading_packages.add(name)
                 try:
-                    package = self._get_module_from_name(name)
-                except case.SkipTest as e:
-                    yield _make_skipped_test(name, e, self.suiteClass)
-                except:
-                    error_case, error_message = \
-                        _make_failed_import_test(name, self.suiteClass)
-                    self.errors.append(error_message)
-                    yield error_case
-                else:
-                    load_tests = getattr(package, 'load_tests', None)
                     tests = self.loadTestsFromModule(package, pattern=pattern)
-                    if tests is not None:
-                        # tests loaded from package file
-                        yield tests
-
                     if load_tests is not None:
-                        # loadTestsFromModule(package) has load_tests for us.
-                        continue
-                    # recurse into the package
-                    yield from self._find_tests(full_path, pattern,
-                                                namespace=namespace)
+                        # loadTestsFromModule(package) has loaded tests for us.
+                        return tests, False
+                    return tests, True
+                finally:
+                    self._loading_packages.discard(name)
 
 
 defaultTestLoader = TestLoader()