Merged revisions 72905 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r72905 | benjamin.peterson | 2009-05-24 19:48:58 -0500 (Sun, 24 May 2009) | 4 lines

  make class-skipping decorators behave the same as skipping every test in the class

  This removes ClassTestSuite along with a good deal of hackery.
........
diff --git a/Lib/test/test_unittest.py b/Lib/test/test_unittest.py
index e7097cc..ea33180 100644
--- a/Lib/test/test_unittest.py
+++ b/Lib/test/test_unittest.py
@@ -106,7 +106,7 @@
 # List subclass we can add attributes to.
 class MyClassSuite(list):
 
-    def __init__(self, tests, klass):
+    def __init__(self, tests):
         super(MyClassSuite, self).__init__(tests)
 
 
@@ -1271,7 +1271,7 @@
         tests = [Foo('test_1'), Foo('test_2')]
 
         loader = unittest.TestLoader()
-        loader.classSuiteClass = MyClassSuite
+        loader.suiteClass = list
         self.assertEqual(loader.loadTestsFromTestCase(Foo), tests)
 
     # It is implicit in the documentation for TestLoader.suiteClass that
@@ -1284,7 +1284,7 @@
             def foo_bar(self): pass
         m.Foo = Foo
 
-        tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)]
+        tests = [[Foo('test_1'), Foo('test_2')]]
 
         loader = unittest.TestLoader()
         loader.suiteClass = list
@@ -1303,7 +1303,7 @@
         tests = [Foo('test_1'), Foo('test_2')]
 
         loader = unittest.TestLoader()
-        loader.classSuiteClass = MyClassSuite
+        loader.suiteClass = list
         self.assertEqual(loader.loadTestsFromName('Foo', m), tests)
 
     # It is implicit in the documentation for TestLoader.suiteClass that
@@ -1316,7 +1316,7 @@
             def foo_bar(self): pass
         m.Foo = Foo
 
-        tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)]
+        tests = [[Foo('test_1'), Foo('test_2')]]
 
         loader = unittest.TestLoader()
         loader.suiteClass = list
@@ -2842,7 +2842,7 @@
                 def test_dont_skip(self): pass
             test_do_skip = Foo("test_skip")
             test_dont_skip = Foo("test_dont_skip")
-            suite = unittest.ClassTestSuite([test_do_skip, test_dont_skip], Foo)
+            suite = unittest.TestSuite([test_do_skip, test_dont_skip])
             events = []
             result = LoggingResult(events)
             suite.run(result)
@@ -2861,9 +2861,10 @@
                 record.append(1)
         record = []
         result = unittest.TestResult()
-        suite = unittest.ClassTestSuite([Foo("test_1")], Foo)
+        test = Foo("test_1")
+        suite = unittest.TestSuite([test])
         suite.run(result)
-        self.assertEqual(result.skipped, [(suite, "testing")])
+        self.assertEqual(result.skipped, [(test, "testing")])
         self.assertEqual(record, [])
 
     def test_expected_failure(self):
diff --git a/Lib/unittest.py b/Lib/unittest.py
index c6d893e..cdccd8c 100644
--- a/Lib/unittest.py
+++ b/Lib/unittest.py
@@ -59,7 +59,7 @@
 ##############################################################################
 # Exported classes and functions
 ##############################################################################
-__all__ = ['TestResult', 'TestCase', 'TestSuite', 'ClassTestSuite',
+__all__ = ['TestResult', 'TestCase', 'TestSuite',
            'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main',
            'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
            'expectedFailure']
@@ -459,6 +459,13 @@
 
         self._result = result
         result.startTest(self)
+        if getattr(self.__class__, "__unittest_skip__", False):
+            # If the whole class was skipped.
+            try:
+                result.addSkip(self, self.__class__.__unittest_skip_why__)
+            finally:
+                result.stopTest(self)
+            return
         testMethod = getattr(self, self._testMethodName)
         try:
             success = False
@@ -1129,37 +1136,6 @@
             test.debug()
 
 
-class ClassTestSuite(TestSuite):
-    """
-    Suite of tests derived from a single TestCase class.
-    """
-
-    def __init__(self, tests, class_collected_from):
-        super(ClassTestSuite, self).__init__(tests)
-        self.collected_from = class_collected_from
-
-    def id(self):
-        module = getattr(self.collected_from, "__module__", None)
-        if module is not None:
-            return "{0}.{1}".format(module, self.collected_from.__name__)
-        return self.collected_from.__name__
-
-    def run(self, result):
-        if getattr(self.collected_from, "__unittest_skip__", False):
-            # ClassTestSuite result pretends to be a TestCase enough to be
-            # reported.
-            result.startTest(self)
-            try:
-                result.addSkip(self, self.collected_from.__unittest_skip_why__)
-            finally:
-                result.stopTest(self)
-        else:
-            result = super(ClassTestSuite, self).run(result)
-        return result
-
-    shortDescription = id
-
-
 class FunctionTestCase(TestCase):
     """A test case that wraps a test function.
 
@@ -1245,7 +1221,6 @@
     testMethodPrefix = 'test'
     sortTestMethodsUsing = staticmethod(three_way_cmp)
     suiteClass = TestSuite
-    classSuiteClass = ClassTestSuite
 
     def loadTestsFromTestCase(self, testCaseClass):
         """Return a suite of all tests cases contained in testCaseClass"""
@@ -1255,8 +1230,7 @@
         testCaseNames = self.getTestCaseNames(testCaseClass)
         if not testCaseNames and hasattr(testCaseClass, 'runTest'):
             testCaseNames = ['runTest']
-        suite = self.classSuiteClass(map(testCaseClass, testCaseNames),
-                                     testCaseClass)
+        suite = self.suiteClass(map(testCaseClass, testCaseNames))
         return suite
 
     def loadTestsFromModule(self, module):