No longer require string operations to find bit lengths.
Now that we no longer support Python 2.6, we can use int.bit_length()
instead. Co-authored by @adamantike.
diff --git a/rsa/common.py b/rsa/common.py
index 34142cc..4f8c0d9 100644
--- a/rsa/common.py
+++ b/rsa/common.py
@@ -31,9 +31,6 @@
Number of bits needed to represent a integer excluding any prefix
0 bits.
- As per definition from https://wiki.python.org/moin/BitManipulation and
- to match the behavior of the Python 3 API.
-
Usage::
>>> bit_size(1023)
@@ -50,41 +47,11 @@
:returns:
Returns the number of bits in the integer.
"""
- if num == 0:
- return 0
- if num < 0:
- num = -num
- # Make sure this is an int and not a float.
- num & 1
-
- hex_num = "%x" % num
- return ((len(hex_num) - 1) * 4) + {
- '0': 0, '1': 1, '2': 2, '3': 2,
- '4': 3, '5': 3, '6': 3, '7': 3,
- '8': 4, '9': 4, 'a': 4, 'b': 4,
- 'c': 4, 'd': 4, 'e': 4, 'f': 4,
- }[hex_num[0]]
-
-
-def _bit_size(number):
- """
- Returns the number of bits required to hold a specific long number.
- """
- if number < 0:
- raise ValueError('Only nonnegative numbers possible: %s' % number)
-
- if number == 0:
- return 0
-
- # This works, even with very large numbers. When using math.log(number, 2),
- # you'll get rounding errors and it'll fail.
- bits = 0
- while number:
- bits += 1
- number >>= 1
-
- return bits
+ try:
+ return num.bit_length()
+ except AttributeError:
+ raise TypeError('bit_size(num) only supports integers, not %r' % type(num))
def byte_size(number):
diff --git a/tests/test_common.py b/tests/test_common.py
index ef32f61..e26e004 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -18,7 +18,7 @@
import unittest
import struct
from rsa._compat import byte, b
-from rsa.common import byte_size, bit_size, _bit_size, inverse
+from rsa.common import byte_size, bit_size, inverse
class TestByte(unittest.TestCase):
@@ -69,12 +69,21 @@
self.assertEqual(bit_size((1 << 1024) + 1), 1025)
self.assertEqual(bit_size((1 << 1024) - 1), 1024)
- self.assertEqual(_bit_size(1023), 10)
- self.assertEqual(_bit_size(1024), 11)
- self.assertEqual(_bit_size(1025), 11)
- self.assertEqual(_bit_size(1 << 1024), 1025)
- self.assertEqual(_bit_size((1 << 1024) + 1), 1025)
- self.assertEqual(_bit_size((1 << 1024) - 1), 1024)
+ def test_negative_values(self):
+ self.assertEqual(bit_size(-1023), 10)
+ self.assertEqual(bit_size(-1024), 11)
+ self.assertEqual(bit_size(-1025), 11)
+ self.assertEqual(bit_size(-1 << 1024), 1025)
+ self.assertEqual(bit_size(-((1 << 1024) + 1)), 1025)
+ self.assertEqual(bit_size(-((1 << 1024) - 1)), 1024)
+
+ def test_bad_type(self):
+ self.assertRaises(TypeError, bit_size, [])
+ self.assertRaises(TypeError, bit_size, ())
+ self.assertRaises(TypeError, bit_size, dict())
+ self.assertRaises(TypeError, bit_size, "")
+ self.assertRaises(TypeError, bit_size, None)
+ self.assertRaises(TypeError, bit_size, 0.0)
class TestInverse(unittest.TestCase):