Merge pull request #234 from alex/padding-fixes

Made PKCS7 unpadding more constant time
diff --git a/cryptography/hazmat/primitives/padding.py b/cryptography/hazmat/primitives/padding.py
index 2dbac75..cfa90db 100644
--- a/cryptography/hazmat/primitives/padding.py
+++ b/cryptography/hazmat/primitives/padding.py
@@ -11,12 +11,58 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import cffi
+
 import six
 
 from cryptography import utils
 from cryptography.hazmat.primitives import interfaces
 
 
+_ffi = cffi.FFI()
+_ffi.cdef("""
+bool Cryptography_check_pkcs7_padding(const uint8_t *, uint8_t);
+""")
+_lib = _ffi.verify("""
+#include <stdbool.h>
+
+/* Returns the value of the input with the most-significant-bit copied to all
+   of the bits. */
+static uint8_t Cryptography_DUPLICATE_MSB_TO_ALL(uint8_t a) {
+    return (1 - (a >> (sizeof(uint8_t) * 8 - 1))) - 1;
+}
+
+/* This returns 0xFF if a < b else 0x00, but does so in a constant time
+   fashion */
+static uint8_t Cryptography_constant_time_lt(uint8_t a, uint8_t b) {
+    a -= b;
+    return Cryptography_DUPLICATE_MSB_TO_ALL(a);
+}
+
+bool Cryptography_check_pkcs7_padding(const uint8_t *data, uint8_t block_len) {
+    uint8_t i;
+    uint8_t pad_size = data[block_len - 1];
+    uint8_t mismatch = 0;
+    for (i = 0; i < block_len; i++) {
+        unsigned int mask = Cryptography_constant_time_lt(i, pad_size);
+        uint8_t b = data[block_len - 1 - i];
+        mismatch |= (mask & (pad_size ^ b));
+    }
+
+    /* Check to make sure the pad_size was within the valid range. */
+    mismatch |= ~Cryptography_constant_time_lt(0, pad_size);
+    mismatch |= Cryptography_constant_time_lt(block_len, pad_size);
+
+    /* Make sure any bits set are copied to the lowest bit */
+    mismatch |= mismatch >> 4;
+    mismatch |= mismatch >> 2;
+    mismatch |= mismatch >> 1;
+    /* Now check the low bit to see if it's set */
+    return (mismatch & 1) == 0;
+}
+""")
+
+
 class PKCS7(object):
     def __init__(self, block_size):
         if not (0 <= block_size < 256):
@@ -102,18 +148,14 @@
         if len(self._buffer) != self.block_size // 8:
             raise ValueError("Invalid padding bytes")
 
+        valid = _lib.Cryptography_check_pkcs7_padding(
+            self._buffer, self.block_size // 8
+        )
+
+        if not valid:
+            raise ValueError("Invalid padding bytes")
+
         pad_size = six.indexbytes(self._buffer, -1)
-
-        if not (0 < pad_size <= self.block_size // 8):
-            raise ValueError("Invalid padding bytes")
-
-        mismatch = 0
-        for b in six.iterbytes(self._buffer[-pad_size:]):
-            mismatch |= b ^ pad_size
-
-        if mismatch != 0:
-            raise ValueError("Invalid padding bytes")
-
         res = self._buffer[:-pad_size]
         self._buffer = None
         return res