#4489: Add a shutil.rmtree that isn't susceptible to symlink attacks

It is used automatically on platforms supporting the necessary fd-based
functions such as os.openat() and os.unlinkat(). Main code by Martin von Löwis.
diff --git a/Lib/shutil.py b/Lib/shutil.py
index 1d6971d..1b05484 100644
--- a/Lib/shutil.py
+++ b/Lib/shutil.py
@@ -337,23 +337,8 @@
         raise Error(errors)
     return dst
 
-def rmtree(path, ignore_errors=False, onerror=None):
-    """Recursively delete a directory tree.
-
-    If ignore_errors is set, errors are ignored; otherwise, if onerror
-    is set, it is called to handle the error with arguments (func,
-    path, exc_info) where func is os.listdir, os.remove, or os.rmdir;
-    path is the argument to that function that caused it to fail; and
-    exc_info is a tuple returned by sys.exc_info().  If ignore_errors
-    is false and onerror is None, an exception is raised.
-
-    """
-    if ignore_errors:
-        def onerror(*args):
-            pass
-    elif onerror is None:
-        def onerror(*args):
-            raise
+# Path-based version, vulnerable to symlink race conditions
+def _rmtree_unsafe(path, onerror):
     try:
         if os.path.islink(path):
             # symlinks to directories are forbidden, see bug #1669
@@ -374,7 +359,7 @@
         except os.error:
             mode = 0
         if stat.S_ISDIR(mode):
-            rmtree(fullname, ignore_errors, onerror)
+            _rmtree_unsafe(fullname, onerror)
         else:
             try:
                 os.remove(fullname)
@@ -385,6 +370,121 @@
     except os.error:
         onerror(os.rmdir, path, sys.exc_info())
 
+# Version using fd-based APIs to protect against races
+def _rmtree_safe_fd(topfd, path, onerror):
+    """Delete everything inside the directory open as *topfd*.
+
+    Every operation is relative to *topfd* and symlinks are never
+    followed, so swapping a path component for a symlink while this runs
+    cannot redirect the deletion outside the tree.  *path* is only used
+    to build the names passed to *onerror*.  The directory behind *topfd*
+    is not removed here; the caller does that.
+    """
+    names = []
+    try:
+        names = os.flistdir(topfd)
+    except os.error:
+        onerror(os.flistdir, path, sys.exc_info())
+    for name in names:
+        fullname = os.path.join(path, name)
+        try:
+            # AT_SYMLINK_NOFOLLOW: a symlink to a directory must be
+            # unlinked, never descended into.
+            orig_st = os.fstatat(topfd, name, os.AT_SYMLINK_NOFOLLOW)
+            mode = orig_st.st_mode
+        except os.error:
+            mode = 0
+        if stat.S_ISDIR(mode):
+            try:
+                dirfd = os.openat(topfd, name, os.O_RDONLY)
+            except os.error:
+                onerror(os.openat, fullname, sys.exc_info())
+            else:
+                try:
+                    if os.path.samestat(orig_st, os.fstat(dirfd)):
+                        _rmtree_safe_fd(dirfd, fullname, onerror)
+                        try:
+                            # Remove the emptied directory relative to
+                            # topfd, not via its (raceable) full path.
+                            os.unlinkat(topfd, name, os.AT_REMOVEDIR)
+                        except os.error:
+                            onerror(os.rmdir, fullname, sys.exc_info())
+                    else:
+                        try:
+                            # The entry was swapped (e.g. for a symlink)
+                            # between fstatat() and openat().
+                            raise OSError("Cannot call rmtree on a "
+                                          "symbolic link")
+                        except OSError:
+                            onerror(os.path.islink, fullname,
+                                    sys.exc_info())
+                finally:
+                    os.close(dirfd)
+        else:
+            try:
+                os.unlinkat(topfd, name)
+            except os.error:
+                onerror(os.unlinkat, fullname, sys.exc_info())
+
+# Use the fd-based version only if every function and flag it needs exists.
+_use_fd_functions = all(hasattr(os, name) for name in
+                        ('openat', 'unlinkat', 'fstatat', 'flistdir',
+                         'AT_REMOVEDIR', 'AT_SYMLINK_NOFOLLOW'))
+
+def rmtree(path, ignore_errors=False, onerror=None):
+    """Recursively delete a directory tree.
+
+    If ignore_errors is set, errors are ignored; otherwise, if onerror
+    is set, it is called to handle the error with arguments (func,
+    path, exc_info) where func is the function that failed, such as
+    os.listdir, os.remove or os.rmdir (or os.lstat, os.open, os.flistdir,
+    os.openat or os.unlinkat when the fd-based implementation is used,
+    and os.path.islink when a symbolic link is refused); path is the path
+    that was being processed; and exc_info is a tuple returned by
+    sys.exc_info().  If ignore_errors is false and onerror is None, an
+    exception is raised.
+
+    """
+    if ignore_errors:
+        def onerror(*args):
+            pass
+    elif onerror is None:
+        def onerror(*args):
+            raise
+    if _use_fd_functions:
+        # Note: To guard against symlink races, we use the standard
+        # lstat()/open()/fstat() trick.
+        try:
+            orig_st = os.lstat(path)
+        except Exception:
+            onerror(os.lstat, path, sys.exc_info())
+            return
+        try:
+            fd = os.open(path, os.O_RDONLY)
+        except Exception:
+            onerror(os.open, path, sys.exc_info())
+            return
+        try:
+            if os.path.samestat(orig_st, os.fstat(fd)):
+                # Not a symlink, and not swapped since lstat().  If path is
+                # not a directory, flistdir() fails and that error goes
+                # through onerror instead of being raised unconditionally.
+                _rmtree_safe_fd(fd, path, onerror)
+                try:
+                    os.rmdir(path)
+                except os.error:
+                    onerror(os.rmdir, path, sys.exc_info())
+            else:
+                try:
+                    # symlinks to directories are forbidden, see bug #1669
+                    raise OSError("Cannot call rmtree on a symbolic link")
+                except OSError:
+                    onerror(os.path.islink, path, sys.exc_info())
+        finally:
+            os.close(fd)
+    else:
+        return _rmtree_unsafe(path, onerror)
+
 
 def _basename(path):
     # A basename() variant which first strips the trailing slash, if present.