bpo-45234: Fix FileNotFoundError raised instead of IsADirectoryError in shutil.copyfile() (GH-28421) (GH-28508)
This was a regression introduced by the fix for bpo-43219.
(cherry picked from commit b7eac52b466f697d3e89f47508e0df0196a98970)
Co-authored-by: andrei kulakov <andrei.avk@gmail.com>
diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py
index 7bf60fd..7669b94 100644
--- a/Lib/test/test_shutil.py
+++ b/Lib/test/test_shutil.py
@@ -1151,6 +1151,28 @@ def test_copy_return_value(self):
rv = fn(src, os.path.join(dst_dir, 'bar'))
self.assertEqual(rv, os.path.join(dst_dir, 'bar'))
+ def test_copy_dir(self):
+ self._test_copy_dir(shutil.copy)
+
+ def test_copy2_dir(self):
+ self._test_copy_dir(shutil.copy2)
+
+ def _test_copy_dir(self, copy_func):
+ src_dir = self.mkdtemp()
+ src_file = os.path.join(src_dir, 'foo')
+ dir2 = self.mkdtemp()
+ dst = os.path.join(src_dir, 'does_not_exist/')
+ write_file(src_file, 'foo')
+ if sys.platform == "win32":
+ err = PermissionError
+ else:
+ err = IsADirectoryError
+ self.assertRaises(err, copy_func, dir2, src_dir)
+
+ # must raise *err* (src is a directory), not FileNotFoundError (dst does not exist)
+ self.assertRaises(err, copy_func, dir2, dst)
+ copy_func(src_file, dir2) # should not raise exceptions
+
### shutil.copyfile
@os_helper.skip_unless_symlink
@@ -1259,6 +1281,24 @@ def test_copyfile_nonexistent_dir(self):
write_file(src_file, 'foo')
self.assertRaises(FileNotFoundError, shutil.copyfile, src_file, dst)
+ def test_copyfile_copy_dir(self):
+ # bpo-45234
+ # Test that copyfile() raises the proper exception when src and/or
+ # dst is a directory.
+ src_dir = self.mkdtemp()
+ src_file = os.path.join(src_dir, 'foo')
+ dir2 = self.mkdtemp()
+ dst = os.path.join(src_dir, 'does_not_exist/')
+ write_file(src_file, 'foo')
+ if sys.platform == "win32":
+ err = PermissionError
+ else:
+ err = IsADirectoryError
+
+ self.assertRaises(err, shutil.copyfile, src_dir, dst)
+ self.assertRaises(err, shutil.copyfile, src_file, src_dir)
+ self.assertRaises(err, shutil.copyfile, dir2, src_dir)
+
class TestArchives(BaseTest, unittest.TestCase):