Merged revisions 70555 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r70555 | benjamin.peterson | 2009-03-23 16:50:21 -0500 (Mon, 23 Mar 2009) | 4 lines

  implement test skipping and expected failures

  patch by myself #1034053
........
diff --git a/Lib/test/test_unittest.py b/Lib/test/test_unittest.py
index 38ceb9a..74aff14 100644
--- a/Lib/test/test_unittest.py
+++ b/Lib/test/test_unittest.py
@@ -31,10 +31,27 @@
         self._events.append('addFailure')
         super().addFailure(*args)
 
+    def addSuccess(self, *args):
+        self._events.append('addSuccess')
+        super(LoggingResult, self).addSuccess(*args)
+
     def addError(self, *args):
         self._events.append('addError')
         super().addError(*args)
 
+    def addSkip(self, *args):
+        self._events.append('addSkip')
+        super(LoggingResult, self).addSkip(*args)
+
+    def addExpectedFailure(self, *args):
+        self._events.append('addExpectedFailure')
+        super(LoggingResult, self).addExpectedFailure(*args)
+
+    def addUnexpectedSuccess(self, *args):
+        self._events.append('addUnexpectedSuccess')
+        super(LoggingResult, self).addUnexpectedSuccess(*args)
+
+
 class TestEquality(object):
     # Check for a valid __eq__ implementation
     def test_eq(self):
@@ -72,6 +89,13 @@
                 self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e))
 
 
+# List subclass taking ClassTestSuite's (tests, klass) arguments; klass is ignored.
+class MyClassSuite(list):
+
+    def __init__(self, tests, klass):
+        super(MyClassSuite, self).__init__(tests)
+
+
 ################################################################
 ### /Support code
 
@@ -1233,7 +1257,7 @@
         tests = [Foo('test_1'), Foo('test_2')]
 
         loader = unittest.TestLoader()
-        loader.suiteClass = list
+        loader.classSuiteClass = MyClassSuite
         self.assertEqual(loader.loadTestsFromTestCase(Foo), tests)
 
     # It is implicit in the documentation for TestLoader.suiteClass that
@@ -1246,7 +1270,7 @@
             def foo_bar(self): pass
         m.Foo = Foo
 
-        tests = [[Foo('test_1'), Foo('test_2')]]
+        tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)]
 
         loader = unittest.TestLoader()
         loader.suiteClass = list
@@ -1265,7 +1289,7 @@
         tests = [Foo('test_1'), Foo('test_2')]
 
         loader = unittest.TestLoader()
-        loader.suiteClass = list
+        loader.classSuiteClass = MyClassSuite
         self.assertEqual(loader.loadTestsFromName('Foo', m), tests)
 
     # It is implicit in the documentation for TestLoader.suiteClass that
@@ -1278,7 +1302,7 @@
             def foo_bar(self): pass
         m.Foo = Foo
 
-        tests = [[Foo('test_1'), Foo('test_2')]]
+        tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)]
 
         loader = unittest.TestLoader()
         loader.suiteClass = list
@@ -2271,9 +2295,103 @@
         # Make run() find a result object on its own
         Foo('test').run()
 
-        expected = ['startTest', 'test', 'stopTest']
+        expected = ['startTest', 'test', 'addSuccess', 'stopTest']
         self.assertEqual(events, expected)
 
+
+class Test_TestSkipping(TestCase):
+
+    def test_skipping(self):
+        class Foo(unittest.TestCase):
+            def test_skip_me(self):
+                self.skip("skip")
+        events = []
+        result = LoggingResult(events)
+        test = Foo("test_skip_me")
+        test.run(result)
+        self.assertEqual(events, ['startTest', 'addSkip', 'stopTest'])
+        self.assertEqual(result.skipped, [(test, "skip")])
+
+        # Try letting setUp skip the test now.
+        class Foo(unittest.TestCase):
+            def setUp(self):
+                self.skip("testing")
+            def test_nothing(self): pass
+        events = []
+        result = LoggingResult(events)
+        test = Foo("test_nothing")
+        test.run(result)
+        self.assertEqual(events, ['startTest', 'addSkip', 'stopTest'])
+        self.assertEqual(result.skipped, [(test, "testing")])
+        self.assertEqual(result.testsRun, 1)
+
+    def test_skipping_decorators(self):
+        op_table = ((unittest.skipUnless, False, True),
+                    (unittest.skipIf, True, False))
+        for deco, do_skip, dont_skip in op_table:
+            class Foo(unittest.TestCase):
+                @deco(do_skip, "testing")
+                def test_skip(self): pass
+
+                @deco(dont_skip, "testing")
+                def test_dont_skip(self): pass
+            test_do_skip = Foo("test_skip")
+            test_dont_skip = Foo("test_dont_skip")
+            suite = unittest.ClassTestSuite([test_do_skip, test_dont_skip], Foo)
+            events = []
+            result = LoggingResult(events)
+            suite.run(result)
+            self.assertEqual(len(result.skipped), 1)
+            expected = ['startTest', 'addSkip', 'stopTest',
+                        'startTest', 'addSuccess', 'stopTest']
+            self.assertEqual(events, expected)
+            self.assertEqual(result.testsRun, 2)
+            self.assertEqual(result.skipped, [(test_do_skip, "testing")])
+            self.assertTrue(result.wasSuccessful())
+
+    def test_skip_class(self):
+        @unittest.skip("testing")
+        class Foo(unittest.TestCase):
+            def test_1(self):
+                record.append(1)
+        record = []
+        result = unittest.TestResult()
+        suite = unittest.ClassTestSuite([Foo("test_1")], Foo)
+        suite.run(result)
+        self.assertEqual(result.skipped, [(suite, "testing")])
+        self.assertEqual(record, [])
+
+    def test_expected_failure(self):
+        class Foo(unittest.TestCase):
+            @unittest.expectedFailure
+            def test_die(self):
+                self.fail("help me!")
+        events = []
+        result = LoggingResult(events)
+        test = Foo("test_die")
+        test.run(result)
+        self.assertEqual(events,
+                         ['startTest', 'addExpectedFailure', 'stopTest'])
+        self.assertEqual(result.expected_failures[0][0], test)
+        self.assertTrue(result.wasSuccessful())
+
+    def test_unexpected_success(self):
+        class Foo(unittest.TestCase):
+            @unittest.expectedFailure
+            def test_die(self):
+                pass
+        events = []
+        result = LoggingResult(events)
+        test = Foo("test_die")
+        test.run(result)
+        self.assertEqual(events,
+                         ['startTest', 'addUnexpectedSuccess', 'stopTest'])
+        self.assertFalse(result.failures)
+        self.assertEqual(result.unexpected_successes, [test])
+        self.assertTrue(result.wasSuccessful())
+
+
+
 class Test_Assertions(TestCase):
     def test_AlmostEqual(self):
         self.failUnlessAlmostEqual(1.00000001, 1.0)
@@ -2338,7 +2456,7 @@
 def test_main():
     support.run_unittest(Test_TestCase, Test_TestLoader,
         Test_TestSuite, Test_TestResult, Test_FunctionTestCase,
-        Test_Assertions)
+        Test_TestSkipping, Test_Assertions)
 
 if __name__ == "__main__":
     test_main()
diff --git a/Lib/unittest.py b/Lib/unittest.py
index aa6eb65..ec11328 100644
--- a/Lib/unittest.py
+++ b/Lib/unittest.py
@@ -53,6 +53,7 @@
 import traceback
 import os
 import types
+import functools
 
 ##############################################################################
 # Exported classes and functions
@@ -71,6 +72,79 @@
 def _strclass(cls):
     return "%s.%s" % (cls.__module__, cls.__name__)
 
+
+class SkipTest(Exception):
+    """
+    Raise this exception in a test to skip it.
+
+    Usually you can use TestCase.skip() or one of the skipping decorators
+    instead of raising this directly.
+    """
+    pass
+
+class _ExpectedFailure(Exception):
+    """
+    Raise this when a test is expected to fail.
+
+    This is an implementation detail.
+    """
+
+    def __init__(self, exc_info):
+        super(_ExpectedFailure, self).__init__()
+        self.exc_info = exc_info
+
+class _UnexpectedSuccess(Exception):
+    """
+    The test was supposed to fail, but it didn't!
+    """
+    pass
+
+def _id(obj):
+    return obj
+
+def skip(reason):
+    """
+    Unconditionally skip a test.
+    """
+    def decorator(test_item):
+        if isinstance(test_item, type) and issubclass(test_item, TestCase):
+            test_item.__unittest_skip__ = True
+            test_item.__unittest_skip_why__ = reason
+            return test_item
+        @functools.wraps(test_item)
+        def skip_wrapper(*args, **kwargs):
+            raise SkipTest(reason)
+        return skip_wrapper
+    return decorator
+
+def skipIf(condition, reason):
+    """
+    Skip a test if the condition is true.
+    """
+    if condition:
+        return skip(reason)
+    return _id
+
+def skipUnless(condition, reason):
+    """
+    Skip a test unless the condition is true.
+    """
+    if not condition:
+        return skip(reason)
+    return _id
+
+
+def expectedFailure(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except Exception:
+            raise _ExpectedFailure(sys.exc_info())
+        raise _UnexpectedSuccess
+    return wrapper
+
+
 __unittest = 1
 
 class TestResult(object):
@@ -88,6 +162,9 @@
         self.failures = []
         self.errors = []
         self.testsRun = 0
+        self.skipped = []
+        self.expected_failures = []
+        self.unexpected_successes = []
         self.shouldStop = False
 
     def startTest(self, test):
@@ -113,6 +190,19 @@
         "Called when a test has completed successfully"
         pass
 
+    def addSkip(self, test, reason):
+        """Called when a test is skipped."""
+        self.skipped.append((test, reason))
+
+    def addExpectedFailure(self, test, err):
+        """Called when an expected failure/error occurred."""
+        self.expected_failures.append(
+            (test, self._exc_info_to_string(err, test)))
+
+    def addUnexpectedSuccess(self, test):
+        """Called when a test was expected to fail, but succeeded."""
+        self.unexpected_successes.append(test)
+
     def wasSuccessful(self):
         "Tells whether or not this result was a success"
         return len(self.failures) == len(self.errors) == 0
@@ -273,25 +363,36 @@
         try:
             try:
                 self.setUp()
+            except SkipTest as e:
+                result.addSkip(self, str(e))
+                return
             except Exception:
                 result.addError(self, self._exc_info())
                 return
 
-            ok = False
+            success = False
             try:
                 testMethod()
-                ok = True
             except self.failureException:
                 result.addFailure(self, self._exc_info())
+            except _ExpectedFailure as e:
+                result.addExpectedFailure(self, e.exc_info)
+            except _UnexpectedSuccess:
+                result.addUnexpectedSuccess(self)
+            except SkipTest as e:
+                result.addSkip(self, str(e))
             except Exception:
                 result.addError(self, self._exc_info())
+            else:
+                success = True
 
             try:
                 self.tearDown()
             except Exception:
                 result.addError(self, self._exc_info())
-                ok = False
-            if ok: result.addSuccess(self)
+                success = False
+            if success:
+                result.addSuccess(self)
         finally:
             result.stopTest(self)
 
@@ -311,6 +412,10 @@
         """
         return sys.exc_info()
 
+    def skip(self, reason):
+        """Skip this test."""
+        raise SkipTest(reason)
+
     def fail(self, msg=None):
         """Fail immediately, with the given message."""
         raise self.failureException(msg)
@@ -418,8 +523,8 @@
     __str__ = __repr__
 
     def __eq__(self, other):
-        if type(self) is not type(other):
-            return False
+        if not isinstance(other, self.__class__):
+            return NotImplemented
         return self._tests == other._tests
 
     def __ne__(self, other):
@@ -464,6 +569,37 @@
         for test in self._tests: test.debug()
 
 
+class ClassTestSuite(TestSuite):
+    """
+    Suite of tests derived from a single TestCase class.
+    """
+
+    def __init__(self, tests, class_collected_from):
+        super(ClassTestSuite, self).__init__(tests)
+        self.collected_from = class_collected_from
+
+    def id(self):
+        module = getattr(self.collected_from, "__module__", None)
+        if module is not None:
+            return "{0}.{1}".format(module, self.collected_from.__name__)
+        return self.collected_from.__name__
+
+    def run(self, result):
+        if getattr(self.collected_from, "__unittest_skip__", False):
+            # The suite itself pretends to be a TestCase well enough to be
+            # reported to the result as a single skipped test.
+            result.startTest(self)
+            try:
+                result.addSkip(self, self.collected_from.__unittest_skip_why__)
+            finally:
+                result.stopTest(self)
+        else:
+            result = super(ClassTestSuite, self).run(result)
+        return result
+
+    shortDescription = id
+
+
 class FunctionTestCase(TestCase):
     """A test case that wraps a test function.
 
@@ -550,6 +686,7 @@
     testMethodPrefix = 'test'
     sortTestMethodsUsing = staticmethod(three_way_cmp)
     suiteClass = TestSuite
+    classSuiteClass = ClassTestSuite
 
     def loadTestsFromTestCase(self, testCaseClass):
         """Return a suite of all tests cases contained in testCaseClass"""
@@ -559,7 +696,9 @@
         testCaseNames = self.getTestCaseNames(testCaseClass)
         if not testCaseNames and hasattr(testCaseClass, 'runTest'):
             testCaseNames = ['runTest']
-        return self.suiteClass(map(testCaseClass, testCaseNames))
+        suite = self.classSuiteClass(map(testCaseClass, testCaseNames),
+                                     testCaseClass)
+        return suite
 
     def loadTestsFromModule(self, module):
         """Return a suite of all tests cases contained in the given module"""
@@ -739,6 +878,30 @@
             self.stream.write('F')
             self.stream.flush()
 
+    def addSkip(self, test, reason):
+        TestResult.addSkip(self, test, reason)
+        if self.showAll:
+            self.stream.writeln("skipped {0!r}".format(reason))
+        elif self.dots:
+            self.stream.write("s")
+            self.stream.flush()
+
+    def addExpectedFailure(self, test, err):
+        TestResult.addExpectedFailure(self, test, err)
+        if self.showAll:
+            self.stream.writeln("expected failure")
+        elif self.dots:
+            self.stream.write(".")
+            self.stream.flush()
+
+    def addUnexpectedSuccess(self, test):
+        TestResult.addUnexpectedSuccess(self, test)
+        if self.showAll:
+            self.stream.writeln("unexpected success")
+        elif self.dots:
+            self.stream.write(".")
+            self.stream.flush()
+
     def printErrors(self):
         if self.dots or self.showAll:
             self.stream.writeln()
@@ -780,17 +943,28 @@
         self.stream.writeln("Ran %d test%s in %.3fs" %
                             (run, run != 1 and "s" or "", timeTaken))
         self.stream.writeln()
+        results = map(len, (result.expected_failures,
+                            result.unexpected_successes,
+                            result.skipped))
+        expected_fails, unexpected_successes, skipped = results
+        infos = []
         if not result.wasSuccessful():
-            self.stream.write("FAILED (")
+            self.stream.write("FAILED")
             failed, errored = len(result.failures), len(result.errors)
             if failed:
-                self.stream.write("failures=%d" % failed)
+                infos.append("failures=%d" % failed)
             if errored:
-                if failed: self.stream.write(", ")
-                self.stream.write("errors=%d" % errored)
-            self.stream.writeln(")")
+                infos.append("errors=%d" % errored)
         else:
-            self.stream.writeln("OK")
+            self.stream.write("OK")
+        if skipped:
+            infos.append("skipped=%d" % skipped)
+        if expected_fails:
+            infos.append("expected failures=%d" % expected_fails)
+        if unexpected_successes:
+            infos.append("unexpected successes=%d" % unexpected_successes)
+        if infos:
+            self.stream.writeln(" (%s)" % (", ".join(infos),))
         return result
 
 
@@ -844,9 +1018,9 @@
 
     def parseArgs(self, argv):
         import getopt
+        long_opts = ['help','verbose','quiet']
         try:
-            options, args = getopt.getopt(argv[1:], 'hHvq',
-                                          ['help','verbose','quiet'])
+            options, args = getopt.getopt(argv[1:], 'hHvq', long_opts)
             for opt, value in options:
                 if opt in ('-h','-H','--help'):
                     self.usageExit()