refactor AES keywrap into a wrap core and unwrap core (#3901)

* refactor AES keywrap into a wrap core and unwrap core

This refactor makes adding AES keywrap with padding much simpler.

* remove an unneeded arg
diff --git a/src/cryptography/hazmat/primitives/keywrap.py b/src/cryptography/hazmat/primitives/keywrap.py
index 6e79ab6..702a693 100644
--- a/src/cryptography/hazmat/primitives/keywrap.py
+++ b/src/cryptography/hazmat/primitives/keywrap.py
@@ -12,20 +12,9 @@
 from cryptography.hazmat.primitives.constant_time import bytes_eq
 
 
-def aes_key_wrap(wrapping_key, key_to_wrap, backend):
-    if len(wrapping_key) not in [16, 24, 32]:
-        raise ValueError("The wrapping key must be a valid AES key length")
-
-    if len(key_to_wrap) < 16:
-        raise ValueError("The key to wrap must be at least 16 bytes")
-
-    if len(key_to_wrap) % 8 != 0:
-        raise ValueError("The key to wrap must be a multiple of 8 bytes")
-
+def _wrap_core(wrapping_key, a, r, backend):
     # RFC 3394 Key Wrap - 2.2.1 (index method)
     encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor()
-    a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
-    r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)]
     n = len(r)
     for j in range(6):
         for i in range(n):
@@ -44,22 +33,24 @@
     return a + b"".join(r)
 
 
-def aes_key_unwrap(wrapping_key, wrapped_key, backend):
-    if len(wrapped_key) < 24:
-        raise ValueError("Must be at least 24 bytes")
-
-    if len(wrapped_key) % 8 != 0:
-        raise ValueError("The wrapped key must be a multiple of 8 bytes")
-
+def aes_key_wrap(wrapping_key, key_to_wrap, backend):
     if len(wrapping_key) not in [16, 24, 32]:
         raise ValueError("The wrapping key must be a valid AES key length")
 
+    if len(key_to_wrap) < 16:
+        raise ValueError("The key to wrap must be at least 16 bytes")
+
+    if len(key_to_wrap) % 8 != 0:
+        raise ValueError("The key to wrap must be a multiple of 8 bytes")
+
+    a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
+    r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)]
+    return _wrap_core(wrapping_key, a, r, backend)
+
+
+def _unwrap_core(wrapping_key, a, r, backend):
     # Implement RFC 3394 Key Unwrap - 2.2.2 (index method)
     decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor()
-    aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
-
-    r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)]
-    a = r.pop(0)
     n = len(r)
     for j in reversed(range(6)):
         for i in reversed(range(n)):
@@ -74,7 +65,23 @@
             r[i] = b[-8:]
 
     assert decryptor.finalize() == b""
+    return a, r
 
+
+def aes_key_unwrap(wrapping_key, wrapped_key, backend):
+    if len(wrapped_key) < 24:
+        raise ValueError("Must be at least 24 bytes")
+
+    if len(wrapped_key) % 8 != 0:
+        raise ValueError("The wrapped key must be a multiple of 8 bytes")
+
+    if len(wrapping_key) not in [16, 24, 32]:
+        raise ValueError("The wrapping key must be a valid AES key length")
+
+    aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
+    r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)]
+    a = r.pop(0)
+    a, r = _unwrap_core(wrapping_key, a, r, backend)
     if not bytes_eq(a, aiv):
         raise InvalidUnwrap()