Merged revisions 80476 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r80476 | michael.foord | 2010-04-25 20:02:46 +0100 (Sun, 25 Apr 2010) | 1 line

  Add unittest.removeHandler, usable as a function or a decorator, which removes the SIGINT handler installed by unittest.installHandler. Includes tests and docs.
........
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py
index e84299e..201a3f0 100644
--- a/Lib/unittest/__init__.py
+++ b/Lib/unittest/__init__.py
@@ -48,7 +48,7 @@
            'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main',
            'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
            'expectedFailure', 'TextTestResult', 'installHandler',
-           'registerResult', 'removeResult']
+           'registerResult', 'removeResult', 'removeHandler']
 
 # Expose obsolete functions for backwards compatibility
 __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases'])
@@ -63,7 +63,7 @@
                      findTestCases)
 from .main import TestProgram, main
 from .runner import TextTestRunner, TextTestResult
-from .signals import installHandler, registerResult, removeResult
+from .signals import installHandler, registerResult, removeResult, removeHandler
 
 # deprecated
 _TextTestResult = TextTestResult
diff --git a/Lib/unittest/signals.py b/Lib/unittest/signals.py
index 0651cf2..fc31043 100644
--- a/Lib/unittest/signals.py
+++ b/Lib/unittest/signals.py
@@ -1,6 +1,8 @@
 import signal
 import weakref
 
+from functools import wraps
+
 __unittest = True
 
 
@@ -36,3 +38,20 @@
         default_handler = signal.getsignal(signal.SIGINT)
         _interrupt_handler = _InterruptHandler(default_handler)
         signal.signal(signal.SIGINT, _interrupt_handler)
+
+
+def removeHandler(method=None):  # undo installHandler(); also a decorator
+    if method is not None:  # decorator form: remove only for each call
+        @wraps(method)
+        def inner(*args, **kwargs):
+            initial = signal.getsignal(signal.SIGINT)  # handler at call time
+            removeHandler()
+            try:
+                return method(*args, **kwargs)
+            finally:
+                signal.signal(signal.SIGINT, initial)  # restore even on error
+        return inner
+    # Plain call: reinstate the SIGINT handler saved in _interrupt_handler.
+    global _interrupt_handler  # NOTE(review): not reset to None; confirm intent
+    if _interrupt_handler is not None:  # nothing to undo if no handler object
+        signal.signal(signal.SIGINT, _interrupt_handler.default_handler)
diff --git a/Lib/unittest/test/test_break.py b/Lib/unittest/test/test_break.py
index 4f89e87..0e09dfb 100644
--- a/Lib/unittest/test/test_break.py
+++ b/Lib/unittest/test/test_break.py
@@ -227,3 +227,24 @@
         self.assertEqual(p.result, result)
 
         self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler)
+
+    def testRemoveHandler(self):  # removal restores the pre-install handler
+        default_handler = signal.getsignal(signal.SIGINT)
+        unittest.installHandler()
+        unittest.removeHandler()
+        self.assertEqual(signal.getsignal(signal.SIGINT), default_handler)
+
+        # Calling removeHandler again must leave the handler unchanged.
+        unittest.removeHandler()
+        self.assertEqual(signal.getsignal(signal.SIGINT), default_handler)
+
+    def testRemoveHandlerAsDecorator(self):  # removed only during the call
+        default_handler = signal.getsignal(signal.SIGINT)
+        unittest.installHandler()
+
+        @unittest.removeHandler
+        def test():  # runs with the unittest handler removed
+            self.assertEqual(signal.getsignal(signal.SIGINT), default_handler)
+
+        test()  # on return the unittest handler must be back in place
+        self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler)