#4489: Add a shutil.rmtree that isn't susceptible to symlink attacks

It is used automatically on platforms supporting the necessary os.openat() and
os.unlinkat() functions. Main code by Martin von Löwis.
diff --git a/Doc/library/shutil.rst b/Doc/library/shutil.rst
index 7156116..c3eb990 100644
--- a/Doc/library/shutil.rst
+++ b/Doc/library/shutil.rst
@@ -190,14 +190,27 @@
    handled by calling a handler specified by *onerror* or, if that is omitted,
    they raise an exception.
 
+   .. warning::
+
+      The default :func:`rmtree` function is susceptible to a symlink attack:
+      given proper timing and circumstances, attackers can use it to delete
+      files they wouldn't be able to access otherwise.  Thus -- on platforms
+      that support the necessary fd-based functions :func:`os.openat` and
+      :func:`os.unlinkat` -- a safe version of :func:`rmtree` is used, which
+      isn't vulnerable.
+
    If *onerror* is provided, it must be a callable that accepts three
-   parameters: *function*, *path*, and *excinfo*. The first parameter,
-   *function*, is the function which raised the exception; it will be
-   :func:`os.path.islink`, :func:`os.listdir`, :func:`os.remove` or
-   :func:`os.rmdir`.  The second parameter, *path*, will be the path name passed
-   to *function*.  The third parameter, *excinfo*, will be the exception
-   information return by :func:`sys.exc_info`.  Exceptions raised by *onerror*
-   will not be caught.
+   parameters: *function*, *path*, and *excinfo*.
+
+   The first parameter, *function*, is the function which raised the exception;
+   it depends on the platform and implementation.  The second parameter,
+   *path*, will be the path name passed to *function*.  The third parameter,
+   *excinfo*, will be the exception information returned by
+   :func:`sys.exc_info`.  Exceptions raised by *onerror* will not be caught.
+
+   .. versionchanged:: 3.3
+      Added a safe version that is used automatically if the platform supports
+      the fd-based functions :func:`os.openat` and :func:`os.unlinkat`.
 
 
 .. function:: move(src, dst)
diff --git a/Lib/shutil.py b/Lib/shutil.py
index 1d6971d..1b05484 100644
--- a/Lib/shutil.py
+++ b/Lib/shutil.py
@@ -337,23 +337,8 @@
         raise Error(errors)
     return dst
 
-def rmtree(path, ignore_errors=False, onerror=None):
-    """Recursively delete a directory tree.
-
-    If ignore_errors is set, errors are ignored; otherwise, if onerror
-    is set, it is called to handle the error with arguments (func,
-    path, exc_info) where func is os.listdir, os.remove, or os.rmdir;
-    path is the argument to that function that caused it to fail; and
-    exc_info is a tuple returned by sys.exc_info().  If ignore_errors
-    is false and onerror is None, an exception is raised.
-
-    """
-    if ignore_errors:
-        def onerror(*args):
-            pass
-    elif onerror is None:
-        def onerror(*args):
-            raise
+# version vulnerable to race conditions
+def _rmtree_unsafe(path, onerror):
     try:
         if os.path.islink(path):
             # symlinks to directories are forbidden, see bug #1669
@@ -374,7 +359,7 @@
         except os.error:
             mode = 0
         if stat.S_ISDIR(mode):
-            rmtree(fullname, ignore_errors, onerror)
+            _rmtree_unsafe(fullname, onerror)
         else:
             try:
                 os.remove(fullname)
@@ -385,6 +370,109 @@
     except os.error:
         onerror(os.rmdir, path, sys.exc_info())
 
+# Version using fd-based APIs to protect against races
+def _rmtree_safe_fd(topfd, path, onerror):
+    """Remove the contents of the directory open as *topfd*, then *path*.
+
+    topfd is a file descriptor for the directory being emptied, path is
+    the path name it was opened from (used to build the names reported
+    to onerror and for the final os.rmdir()), and onerror is called as
+    onerror(func, path, exc_info) whenever an operation fails.
+    """
+    names = []
+    try:
+        names = os.flistdir(topfd)
+    except os.error:
+        onerror(os.flistdir, path, sys.exc_info())
+    for name in names:
+        fullname = os.path.join(path, name)
+        try:
+            # Do not follow symlinks: a link to a directory must be
+            # unlinked like a file, never descended into.
+            orig_st = os.fstatat(topfd, name, os.AT_SYMLINK_NOFOLLOW)
+            mode = orig_st.st_mode
+        except os.error:
+            # Treat entries that cannot be examined as non-directories.
+            mode = 0
+        if stat.S_ISDIR(mode):
+            try:
+                dirfd = os.openat(topfd, name, os.O_RDONLY)
+            except os.error:
+                onerror(os.openat, fullname, sys.exc_info())
+            else:
+                try:
+                    # Only recurse if what was opened is still the
+                    # directory that was examined above.
+                    if os.path.samestat(orig_st, os.fstat(dirfd)):
+                        _rmtree_safe_fd(dirfd, fullname, onerror)
+                finally:
+                    os.close(dirfd)
+        else:
+            try:
+                os.unlinkat(topfd, name)
+            except os.error:
+                onerror(os.unlinkat, fullname, sys.exc_info())
+    try:
+        os.rmdir(path)
+    except os.error:
+        onerror(os.rmdir, path, sys.exc_info())
+
+# The fd-based version needs every one of these os attributes.
+_use_fd_functions = all(hasattr(os, name) for name in
+                        ('flistdir', 'fstatat', 'openat', 'unlinkat',
+                         'AT_SYMLINK_NOFOLLOW'))
+def rmtree(path, ignore_errors=False, onerror=None):
+    """Recursively delete a directory tree.
+
+    If ignore_errors is set, errors are ignored; otherwise, if onerror
+    is set, it is called to handle the error with arguments (func,
+    path, exc_info) where func is the function that failed (which one
+    depends on the platform and implementation); path is the path name
+    that function was processing when it failed; and exc_info is a
+    tuple returned by sys.exc_info().  If ignore_errors is false and
+    onerror is None, an exception is raised.
+
+    """
+    if ignore_errors:
+        def onerror(*args):
+            pass
+    elif onerror is None:
+        def onerror(*args):
+            raise
+    if _use_fd_functions:
+        # Note: To guard against symlink races, we use the standard
+        # lstat()/open()/fstat() trick.
+        try:
+            orig_st = os.lstat(path)
+        except Exception:
+            onerror(os.lstat, path, sys.exc_info())
+            return
+        if stat.S_ISLNK(orig_st.st_mode):
+            try:
+                # symlinks to directories are forbidden, see bug #1669
+                raise OSError("Cannot call rmtree on a symbolic link")
+            except OSError:
+                onerror(os.path.islink, path, sys.exc_info())
+            # can't continue even if onerror hook returns
+            return
+        try:
+            fd = os.open(path, os.O_RDONLY)
+        except Exception:
+            # Report the call that actually failed.
+            onerror(os.open, path, sys.exc_info())
+            return
+        try:
+            if (stat.S_ISDIR(orig_st.st_mode) and
+                os.path.samestat(orig_st, os.fstat(fd))):
+                _rmtree_safe_fd(fd, path, onerror)
+            elif (stat.S_ISREG(orig_st.st_mode)):
+                raise NotADirectoryError(20,
+                                         "Not a directory: '{}'".format(path))
+        finally:
+            os.close(fd)
+    else:
+        return _rmtree_unsafe(path, onerror)
+
 
 def _basename(path):
     # A basename() variant which first strips the trailing slash, if present.
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index 1929237..9c0c52c 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -120,29 +120,36 @@
         def test_on_error(self):
             self.errorState = 0
             os.mkdir(TESTFN)
-            self.childpath = os.path.join(TESTFN, 'a')
-            support.create_empty_file(self.childpath)
+            self.child_file_path = os.path.join(TESTFN, 'a')
+            self.child_dir_path = os.path.join(TESTFN, 'b')
+            support.create_empty_file(self.child_file_path)
+            os.mkdir(self.child_dir_path)
             old_dir_mode = os.stat(TESTFN).st_mode
-            old_child_mode = os.stat(self.childpath).st_mode
+            old_child_file_mode = os.stat(self.child_file_path).st_mode
+            old_child_dir_mode = os.stat(self.child_dir_path).st_mode
             # Make unwritable.
-            os.chmod(self.childpath, stat.S_IREAD)
-            os.chmod(TESTFN, stat.S_IREAD)
+            new_mode = stat.S_IREAD|stat.S_IEXEC
+            os.chmod(self.child_file_path, new_mode)
+            os.chmod(self.child_dir_path, new_mode)
+            os.chmod(TESTFN, new_mode)
 
             shutil.rmtree(TESTFN, onerror=self.check_args_to_onerror)
             # Test whether onerror has actually been called.
-            self.assertEqual(self.errorState, 2,
-                             "Expected call to onerror function did not happen.")
+            self.assertEqual(self.errorState, 3,
+                             "Expected call to onerror function did not "
+                             "happen.")
 
             # Make writable again.
             os.chmod(TESTFN, old_dir_mode)
-            os.chmod(self.childpath, old_child_mode)
+            os.chmod(self.child_file_path, old_child_file_mode)
+            os.chmod(self.child_dir_path, old_child_dir_mode)
 
             # Clean up.
             shutil.rmtree(TESTFN)
 
     def check_args_to_onerror(self, func, arg, exc):
         # test_rmtree_errors deliberately runs rmtree
-        # on a directory that is chmod 400, which will fail.
+        # on a directory that is chmod 500, which will fail.
         # This function is run when shutil.rmtree fails.
         # 99.9% of the time it initially fails to remove
         # a file in the directory, so the first time through
@@ -151,20 +158,39 @@
         # FUSE experienced a failure earlier in the process
         # at os.listdir.  The first failure may legally
         # be either.
-        if self.errorState == 0:
-            if func is os.remove:
-                self.assertEqual(arg, self.childpath)
+        if 0 <= self.errorState < 2:
+            if (func is os.remove or
+                hasattr(os, 'unlinkat') and func is os.unlinkat):
+                self.assertIn(arg, [self.child_file_path, self.child_dir_path])
             else:
-                self.assertIs(func, os.listdir,
-                              "func must be either os.remove or os.listdir")
-                self.assertEqual(arg, TESTFN)
+                if self.errorState == 1:
+                    self.assertEqual(func, os.rmdir)
+                else:
+                    self.assertIs(func, os.listdir, "func must be os.listdir")
+                self.assertIn(arg, [TESTFN, self.child_dir_path])
             self.assertTrue(issubclass(exc[0], OSError))
-            self.errorState = 1
+            self.errorState += 1
         else:
             self.assertEqual(func, os.rmdir)
             self.assertEqual(arg, TESTFN)
             self.assertTrue(issubclass(exc[0], OSError))
-            self.errorState = 2
+            self.errorState = 3
+
+    def test_rmtree_does_not_choke_on_failing_lstat(self):
+        try:
+            orig_lstat = os.lstat
+            def raiser(fn):
+                if fn != TESTFN:
+                    raise OSError()
+                else:
+                    return orig_lstat(fn)
+            os.lstat = raiser
+
+            os.mkdir(TESTFN)
+            write_file((TESTFN, 'foo'), 'foo')
+            shutil.rmtree(TESTFN)
+        finally:
+            os.lstat = orig_lstat
 
     @unittest.skipUnless(hasattr(os, 'chmod'), 'requires os.chmod')
     @support.skip_unless_symlink
@@ -464,7 +490,7 @@
         # When called on a file instead of a directory, don't delete it.
         handle, path = tempfile.mkstemp()
         os.close(handle)
-        self.assertRaises(OSError, shutil.rmtree, path)
+        self.assertRaises(NotADirectoryError, shutil.rmtree, path)
         os.remove(path)
 
     def test_copytree_simple(self):
@@ -629,6 +655,7 @@
             os.mkdir(src)
             os.symlink(src, dst)
             self.assertRaises(OSError, shutil.rmtree, dst)
+            shutil.rmtree(dst, ignore_errors=True)
         finally:
             shutil.rmtree(TESTFN, ignore_errors=True)
 
diff --git a/Misc/NEWS b/Misc/NEWS
index d76aeeb..718b76e 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -43,6 +43,10 @@
 Library
 -------
 
+- Issue #4489: Add a shutil.rmtree that isn't susceptible to symlink attacks.
+  It is used automatically on platforms supporting the necessary os.openat()
+  and os.unlinkat() functions. Main code by Martin von Löwis.
+
 - Issue #15114: the strict mode of HTMLParser and the HTMLParseError exception
   are deprecated now that the parser is able to parse invalid markup.