Backport of decimal module context management updates from rev 51694 to 2.5 release branch
diff --git a/Lib/decimal.py b/Lib/decimal.py
index 396c413..a66beef 100644
--- a/Lib/decimal.py
+++ b/Lib/decimal.py
@@ -131,7 +131,7 @@
'ROUND_FLOOR', 'ROUND_UP', 'ROUND_HALF_DOWN',
# Functions for manipulating contexts
- 'setcontext', 'getcontext'
+ 'setcontext', 'getcontext', 'localcontext'
]
import copy as _copy
@@ -458,6 +458,49 @@
del threading, local # Don't contaminate the namespace
+def localcontext(ctx=None):
+    """Return a context manager for a copy of the supplied context.
+
+    Uses a copy of the current context if no context is specified.
+    The copy is installed (and bound to any "as" target) on entry to a
+    with statement and the previous context is restored on exit:
+        def sin(x):
+            with localcontext() as ctx:
+                ctx.prec += 2
+                # Rest of sin calculation algorithm
+                # uses a precision 2 greater than normal
+            return +s # Convert result to normal precision
+
+        def sin(x):
+            with localcontext(ExtendedContext):
+                # Rest of sin calculation algorithm
+                # uses the Extended Context from the
+                # General Decimal Arithmetic Specification
+            return +s # Convert result to normal context
+
+    """
+    # The string below can't become part of the docstring until Python 2.6
+    # as the doctest module doesn't understand __future__ statements
+    """
+    >>> from __future__ import with_statement
+    >>> print getcontext().prec
+    28
+    >>> with localcontext():
+    ...     ctx = getcontext()
+    ...     ctx.prec += 2
+    ...     print ctx.prec
+    ...
+    30
+    >>> with localcontext(ExtendedContext):
+    ...     print getcontext().prec
+    ...
+    9
+    >>> print getcontext().prec
+    28
+    """
+    if ctx is None: ctx = getcontext()
+    return _ContextManager(ctx)
+
##### Decimal class ###########################################
@@ -2173,23 +2216,14 @@
del name, val, globalname, rounding_functions
-class ContextManager(object):
- """Helper class to simplify Context management.
+class _ContextManager(object):
+ """Context manager class to support localcontext().
- Sample usage:
-
- with decimal.ExtendedContext:
- s = ...
- return +s # Convert result to normal precision
-
- with decimal.getcontext() as ctx:
- ctx.prec += 2
- s = ...
- return +s
-
+ Sets a copy of the supplied context in __enter__() and restores
+ the previous decimal context in __exit__()
"""
def __init__(self, new_context):
- self.new_context = new_context
+ self.new_context = new_context.copy()
def __enter__(self):
self.saved_context = getcontext()
setcontext(self.new_context)
@@ -2248,9 +2282,6 @@
s.append('traps=[' + ', '.join([t.__name__ for t, v in self.traps.items() if v]) + ']')
return ', '.join(s) + ')'
- def get_manager(self):
- return ContextManager(self.copy())
-
def clear_flags(self):
"""Reset all flags to zero"""
for flag in self.flags:
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 2cf39ae..747785d 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -330,32 +330,6 @@
return True
self.boilerPlate(lock, locked)
-class DecimalContextTestCase(unittest.TestCase):
-
- # XXX Somebody should write more thorough tests for this
-
- def testBasic(self):
- ctx = decimal.getcontext()
- orig_context = ctx.copy()
- try:
- ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
- with decimal.ExtendedContext.get_manager():
- self.assertEqual(decimal.getcontext().prec,
- decimal.ExtendedContext.prec)
- self.assertEqual(decimal.getcontext().prec, save_prec)
- try:
- with decimal.ExtendedContext.get_manager():
- self.assertEqual(decimal.getcontext().prec,
- decimal.ExtendedContext.prec)
- 1/0
- except ZeroDivisionError:
- self.assertEqual(decimal.getcontext().prec, save_prec)
- else:
- self.fail("Didn't raise ZeroDivisionError")
- finally:
- decimal.setcontext(orig_context)
-
-
# This is needed to make the test actually run under regrtest.py!
def test_main():
run_suite(
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index f3f9215..841ea6f 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -23,6 +23,7 @@
you're working through IDLE, you can import this test module and call test_main()
with the corresponding argument.
"""
+from __future__ import with_statement
import unittest
import glob
@@ -1064,6 +1065,32 @@
self.assertNotEqual(id(c.flags), id(d.flags))
self.assertNotEqual(id(c.traps), id(d.traps))
+class WithStatementTest(unittest.TestCase):
+    # These checks are plain unit tests rather than doctests because,
+    # before Python 2.6, doctest cannot process __future__ statements.
+
+    def test_localcontext(self):
+        # No argument: the block must run under a copy of the current context
+        before = getcontext()
+        with localcontext() as entered:
+            inside = getcontext()
+        after = getcontext()
+        self.assert_(before is after, 'did not restore context correctly')
+        self.assert_(before is not inside, 'did not copy the context')
+        self.assert_(inside is entered, '__enter__ returned wrong context')
+
+    def test_localcontextarg(self):
+        # Explicit argument: the block must run under a copy of that context
+        before = getcontext()
+        supplied = Context(prec=42)
+        with localcontext(supplied) as entered:
+            inside = getcontext()
+        after = getcontext()
+        self.assert_(before is after, 'did not restore context correctly')
+        self.assert_(inside.prec == supplied.prec, 'did not set correct context')
+        self.assert_(supplied is not inside, 'did not copy the context')
+        self.assert_(inside is entered, '__enter__ returned wrong context')
+
def test_main(arith=False, verbose=None):
""" Execute the tests.
@@ -1084,6 +1111,7 @@
DecimalPythonAPItests,
ContextAPItests,
DecimalTest,
+ WithStatementTest,
]
try: