Merged revision 80830 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r80830 | tarek.ziade | 2010-05-06 00:15:31 +0200 (Thu, 06 May 2010) | 1 line

  Fixed #4265: shutil.copyfile() leaked file descriptors when the disk was full
........
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index 8e75003..9f4871b 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -798,8 +798,132 @@
             shutil.rmtree(TESTFN, ignore_errors=True)
 
 
+class TestCopyFile(unittest.TestCase):
+    """Check that shutil.copyfile() closes both files on every error path.
+
+    shutil.open is temporarily replaced by a fake that hands out Faux
+    context managers (or raises), so no real files are opened (#4265).
+    """
+
+    _delete = False  # set once shutil.open has been patched
+
+    class Faux(object):
+        """Minimal file-like context manager that records how it was used."""
+        _entered = False
+        _exited_with = None
+        _raised = False
+        def __init__(self, raise_in_exit=False, suppress_at_exit=True):
+            self._raise_in_exit = raise_in_exit
+            self._suppress_at_exit = suppress_at_exit
+        def read(self, *args):
+            # An empty read ends copyfileobj()'s loop at once.
+            return ''
+        def write(self, data):
+            pass
+        def __enter__(self):
+            self._entered = True
+            # Return self so copyfile() copies between the fakes instead of
+            # failing with AttributeError on None.read().
+            return self
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            self._exited_with = exc_type, exc_val, exc_tb
+            if self._raise_in_exit:
+                self._raised = True
+                raise IOError("Cannot close")
+            return self._suppress_at_exit
+
+    def tearDown(self):
+        # Remove the patched name so shutil falls back to the builtin open().
+        if self._delete:
+            del shutil.open
+
+    def _set_shutil_open(self, func):
+        # A module global shadows the builtin open() for shutil only.
+        shutil.open = func
+        self._delete = True
+
+    def test_w_source_open_fails(self):
+        def _open(filename, mode='r'):
+            if filename == 'srcfile':
+                raise IOError('Cannot open "srcfile"')
+            self.fail('unexpected open(%r)' % (filename,))
+
+        self._set_shutil_open(_open)
+
+        self.assertRaises(IOError, shutil.copyfile, 'srcfile', 'destfile')
+
+    def test_w_dest_open_fails(self):
+
+        srcfile = self.Faux()
+
+        def _open(filename, mode='r'):
+            if filename == 'srcfile':
+                return srcfile
+            if filename == 'destfile':
+                raise IOError('Cannot open "destfile"')
+            self.fail('unexpected open(%r)' % (filename,))
+
+        self._set_shutil_open(_open)
+
+        # srcfile suppresses the error, so copyfile() itself returns.
+        shutil.copyfile('srcfile', 'destfile')
+        self.assertTrue(srcfile._entered)
+        self.assertIs(srcfile._exited_with[0], IOError)
+        self.assertEqual(srcfile._exited_with[1].args,
+                         ('Cannot open "destfile"',))
+
+    def test_w_dest_close_fails(self):
+
+        srcfile = self.Faux()
+        destfile = self.Faux(True)
+
+        def _open(filename, mode='r'):
+            if filename == 'srcfile':
+                return srcfile
+            if filename == 'destfile':
+                return destfile
+            self.fail('unexpected open(%r)' % (filename,))
+
+        self._set_shutil_open(_open)
+
+        shutil.copyfile('srcfile', 'destfile')
+        self.assertTrue(srcfile._entered)
+        self.assertTrue(destfile._entered)
+        # The copy itself succeeded; only closing destfile failed ...
+        self.assertIsNone(destfile._exited_with[0])
+        self.assertTrue(destfile._raised)
+        # ... and srcfile was still closed, seeing that close error.
+        self.assertIs(srcfile._exited_with[0], IOError)
+        self.assertEqual(srcfile._exited_with[1].args,
+                         ('Cannot close',))
+
+    def test_w_source_close_fails(self):
+
+        srcfile = self.Faux(True)
+        destfile = self.Faux()
+
+        def _open(filename, mode='r'):
+            if filename == 'srcfile':
+                return srcfile
+            if filename == 'destfile':
+                return destfile
+            self.fail('unexpected open(%r)' % (filename,))
+
+        self._set_shutil_open(_open)
+
+        self.assertRaises(IOError,
+                          shutil.copyfile, 'srcfile', 'destfile')
+        self.assertTrue(srcfile._entered)
+        self.assertTrue(destfile._entered)
+        # destfile was closed cleanly before srcfile's close failed.
+        self.assertIsNone(destfile._exited_with[0])
+        self.assertFalse(destfile._raised)
+        self.assertIsNone(srcfile._exited_with[0])
+        self.assertTrue(srcfile._raised)
+
+
 def test_main():
-    support.run_unittest(TestShutil, TestMove)
+    support.run_unittest(TestShutil, TestMove, TestCopyFile)
 
 if __name__ == '__main__':
     test_main()