move tag_length to the AESCCM constructor (#3783)

* move tag_length to the AESCCM constructor

* review feedback
diff --git a/docs/hazmat/primitives/aead.rst b/docs/hazmat/primitives/aead.rst
index 94b08f0..6b13edc 100644
--- a/docs/hazmat/primitives/aead.rst
+++ b/docs/hazmat/primitives/aead.rst
@@ -78,7 +78,7 @@
             when the ciphertext has been changed, but will also occur when the
             key, nonce, or associated data are wrong.
 
-.. class:: AESCCM(key)
+.. class:: AESCCM(key, tag_length=16)
 
     .. versionadded:: 2.0
 
@@ -93,6 +93,10 @@
     cipher utilizing Counter with CBC-MAC (CCM) (specified in :rfc:`3610`).
 
     :param bytes key: A 128, 192, or 256-bit key. This **must** be kept secret.
+    :param int tag_length: The length of the authentication tag. This
+        defaults to 16 bytes and it is **strongly** recommended that you
+        do not make it shorter unless absolutely necessary. Valid tag
+        lengths are 4, 6, 8, 12, 14, and 16.
 
     :raises cryptography.exceptions.UnsupportedAlgorithm: If the version of
         OpenSSL does not support AES-CCM.
@@ -119,7 +123,7 @@
 
         :returns bytes: The generated key.
 
-    .. method:: encrypt(nonce, data, associated_data, tag_length=16)
+    .. method:: encrypt(nonce, data, associated_data)
 
         .. warning::
 
@@ -138,13 +142,9 @@
         :param bytes data: The data to encrypt.
         :param bytes associated_data: Additional data that should be
             authenticated with the key, but is not encrypted. Can be ``None``.
-        :param int tag_length: The length of the authentication tag. This
-            defaults to 16 bytes and it is **strongly** recommended that you
-            do not make it shorter unless absolutely necessary. Valid tag
-            lengths are 4, 6, 8, 12, 14, and 16.
         :returns bytes: The ciphertext bytes with the tag appended.
 
-    .. method:: decrypt(nonce, data, associated_data, tag_length=16)
+    .. method:: decrypt(nonce, data, associated_data)
 
         Decrypts the ``data`` and authenticates the ``associated_data``. If you
         called encrypt with ``associated_data`` you must pass the same
@@ -156,10 +156,6 @@
         :param bytes data: The data to decrypt (with tag appended).
         :param bytes associated_data: Additional data to authenticate. Can be
             ``None`` if none was passed during encryption.
-        :param int tag_length: The length of the authentication tag. This
-            defaults to 16 bytes. You only need to change this if your existing
-            ciphertext has a shorter tag. Valid tag lengths are 4, 6, 8, 12,
-            14, and 16.
         :returns bytes: The original plaintext.
         :raises cryptography.exceptions.InvalidTag: If the authentication tag
             doesn't validate this exception will be raised. This will occur
diff --git a/src/cryptography/hazmat/primitives/ciphers/aead.py b/src/cryptography/hazmat/primitives/ciphers/aead.py
index 189cb5b..e2c5e38 100644
--- a/src/cryptography/hazmat/primitives/ciphers/aead.py
+++ b/src/cryptography/hazmat/primitives/ciphers/aead.py
@@ -56,12 +56,20 @@
 
 
 class AESCCM(object):
-    def __init__(self, key):
+    def __init__(self, key, tag_length=16):
         utils._check_bytes("key", key)
         if len(key) not in (16, 24, 32):
             raise ValueError("AESCCM key must be 128, 192, or 256 bits.")
 
         self._key = key
+        if not isinstance(tag_length, int):
+            raise TypeError("tag_length must be an integer")
+
+        if tag_length not in (4, 6, 8, 12, 14, 16):
+            raise ValueError("Invalid tag_length")
+
+        self._tag_length = tag_length
+
         if not backend.aead_cipher_supported(self):
             raise exceptions.UnsupportedAlgorithm(
                 "AESCCM is not supported by this version of OpenSSL",
@@ -78,23 +86,23 @@
 
         return os.urandom(bit_length // 8)
 
-    def encrypt(self, nonce, data, associated_data, tag_length=16):
+    def encrypt(self, nonce, data, associated_data):
         if associated_data is None:
             associated_data = b""
 
-        self._check_params(nonce, data, associated_data, tag_length)
+        self._check_params(nonce, data, associated_data)
         self._validate_lengths(nonce, len(data))
         return aead._encrypt(
-            backend, self, nonce, data, associated_data, tag_length
+            backend, self, nonce, data, associated_data, self._tag_length
         )
 
-    def decrypt(self, nonce, data, associated_data, tag_length=16):
+    def decrypt(self, nonce, data, associated_data):
         if associated_data is None:
             associated_data = b""
 
-        self._check_params(nonce, data, associated_data, tag_length)
+        self._check_params(nonce, data, associated_data)
         return aead._decrypt(
-            backend, self, nonce, data, associated_data, tag_length
+            backend, self, nonce, data, associated_data, self._tag_length
         )
 
     def _validate_lengths(self, nonce, data_len):
@@ -104,13 +112,7 @@
         if 2 ** (8 * l) < data_len:
             raise ValueError("Nonce too long for data")
 
-    def _check_params(self, nonce, data, associated_data, tag_length):
-        if not isinstance(tag_length, int):
-            raise TypeError("tag_length must be an integer")
-
-        if tag_length not in (4, 6, 8, 12, 14, 16):
-            raise ValueError("Invalid tag_length")
-
+    def _check_params(self, nonce, data, associated_data):
         utils._check_bytes("nonce", nonce)
         utils._check_bytes("data", data)
         utils._check_bytes("associated_data", associated_data)
diff --git a/tests/hazmat/primitives/test_aead.py b/tests/hazmat/primitives/test_aead.py
index 9700a1a..27374da 100644
--- a/tests/hazmat/primitives/test_aead.py
+++ b/tests/hazmat/primitives/test_aead.py
@@ -178,14 +178,14 @@
 
     def test_invalid_tag_length(self, backend):
         key = AESCCM.generate_key(128)
-        aesccm = AESCCM(key)
-        pt = b"hello"
-        nonce = os.urandom(12)
         with pytest.raises(ValueError):
-            aesccm.encrypt(nonce, pt, None, tag_length=7)
+            AESCCM(key, tag_length=7)
 
         with pytest.raises(ValueError):
-            aesccm.encrypt(nonce, pt, None, tag_length=2)
+            AESCCM(key, tag_length=2)
+
+        with pytest.raises(TypeError):
+            AESCCM(key, tag_length="notanint")
 
     def test_invalid_nonce_length(self, backend):
         key = AESCCM.generate_key(128)
@@ -217,14 +217,14 @@
         adata = binascii.unhexlify(vector["adata"])[:vector["alen"]]
         ct = binascii.unhexlify(vector["ct"])
         pt = binascii.unhexlify(vector["payload"])[:vector["plen"]]
-        aesccm = AESCCM(key)
+        aesccm = AESCCM(key, vector["tlen"])
         if vector.get('fail'):
             with pytest.raises(InvalidTag):
-                aesccm.decrypt(nonce, ct, adata, vector["tlen"])
+                aesccm.decrypt(nonce, ct, adata)
         else:
-            computed_pt = aesccm.decrypt(nonce, ct, adata, vector["tlen"])
+            computed_pt = aesccm.decrypt(nonce, ct, adata)
             assert computed_pt == pt
-            assert aesccm.encrypt(nonce, pt, adata, vector["tlen"]) == ct
+            assert aesccm.encrypt(nonce, pt, adata) == ct
 
     def test_roundtrip(self, backend):
         key = AESCCM.generate_key(128)
@@ -232,8 +232,8 @@
         pt = b"encrypt me"
         ad = b"additional"
         nonce = os.urandom(12)
-        ct = aesccm.encrypt(nonce, pt, ad, 16)
-        computed_pt = aesccm.decrypt(nonce, ct, ad, 16)
+        ct = aesccm.encrypt(nonce, pt, ad)
+        computed_pt = aesccm.decrypt(nonce, ct, ad)
         assert computed_pt == pt
 
     def test_nonce_too_long(self, backend):
@@ -243,23 +243,21 @@
         # pt can be no more than 65536 bytes when nonce is 13 bytes
         nonce = os.urandom(13)
         with pytest.raises(ValueError):
-            aesccm.encrypt(nonce, pt, None, 16)
+            aesccm.encrypt(nonce, pt, None)
 
     @pytest.mark.parametrize(
-        ("nonce", "data", "associated_data", "tag_length"),
+        ("nonce", "data", "associated_data"),
         [
-            [object(), b"data", b"", 16],
-            [b"0" * 12, object(), b"", 16],
-            [b"0" * 12, b"data", object(), 16],
-            [b"0" * 12, b"data", b"", object()]
+            [object(), b"data", b""],
+            [b"0" * 12, object(), b""],
+            [b"0" * 12, b"data", object()],
         ]
     )
-    def test_params_not_bytes(self, nonce, data, associated_data, tag_length,
-                              backend):
+    def test_params_not_bytes(self, nonce, data, associated_data, backend):
         key = AESCCM.generate_key(128)
         aesccm = AESCCM(key)
         with pytest.raises(TypeError):
-            aesccm.encrypt(nonce, data, associated_data, tag_length)
+            aesccm.encrypt(nonce, data, associated_data)
 
     def test_bad_key(self, backend):
         with pytest.raises(TypeError):