Fixed #3008 -- expose calculate max pss salt length	 (#3014)

* Fixed #3008 -- expose calculate max pss salt length

* Fixed a few mistakes in the docs

* move all the code around

* oops

* write a unit test

* versionadded + changelog
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 65dfebe..2853085 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -6,6 +6,9 @@
 
 .. note:: This version is not yet released and is under active development.
 
+* Added
+  :func:`~cryptography.hazmat.primitives.asymmetric.padding.calculate_max_pss_salt_length`.
+
 
 1.4 - 2016-06-04
 ~~~~~~~~~~~~~~~~
diff --git a/docs/hazmat/primitives/asymmetric/rsa.rst b/docs/hazmat/primitives/asymmetric/rsa.rst
index 369f857..9321444 100644
--- a/docs/hazmat/primitives/asymmetric/rsa.rst
+++ b/docs/hazmat/primitives/asymmetric/rsa.rst
@@ -330,6 +330,20 @@
     :class:`OAEP` should be preferred for encryption and :class:`PSS` should be
     preferred for signatures.
 
+
+.. function:: calculate_max_pss_salt_length(key, hash_algorithm)
+
+    .. versionadded:: 1.5
+
+    :param key: An RSA public or private key.
+    :param hash_algorithm: A
+        :class:`cryptography.hazmat.primitives.hashes.HashAlgorithm`.
+    :returns int: The computed salt length.
+
+    Computes the length of the salt that :class:`PSS` will use if
+    :data:`PSS.MAX_LENGTH` is used.
+
+
 Mask generation functions
 -------------------------
 
@@ -341,11 +355,10 @@
         Removed the deprecated ``salt_length`` parameter.
 
     MGF1 (Mask Generation Function 1) is used as the mask generation function
-    in :class:`PSS` padding. It takes a hash algorithm and a salt length.
+    in :class:`PSS` and :class:`OAEP` padding. It takes a hash algorithm.
 
-    :param algorithm: An instance of a
-        :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`
-        provider.
+    :param algorithm: An instance of
+        :class:`~cryptography.hazmat.primitives.hashes.HashAlgorithm`.
 
 Numbers
 ~~~~~~~
diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py
index 63ba627..a85f7da 100644
--- a/src/cryptography/hazmat/backends/openssl/rsa.py
+++ b/src/cryptography/hazmat/backends/openssl/rsa.py
@@ -15,22 +15,18 @@
     AsymmetricSignatureContext, AsymmetricVerificationContext, rsa
 )
 from cryptography.hazmat.primitives.asymmetric.padding import (
-    AsymmetricPadding, MGF1, OAEP, PKCS1v15, PSS
+    AsymmetricPadding, MGF1, OAEP, PKCS1v15, PSS, calculate_max_pss_salt_length
 )
 from cryptography.hazmat.primitives.asymmetric.rsa import (
     RSAPrivateKeyWithSerialization, RSAPublicKeyWithSerialization
 )
 
 
-def _get_rsa_pss_salt_length(pss, key_size, digest_size):
+def _get_rsa_pss_salt_length(pss, key, hash_algorithm):
     salt = pss._salt_length
 
     if salt is MGF1.MAX_LENGTH or salt is PSS.MAX_LENGTH:
-        # bit length - 1 per RFC 3447
-        emlen = int(math.ceil((key_size - 1) / 8.0))
-        salt_length = emlen - digest_size - 2
-        assert salt_length >= 0
-        return salt_length
+        return calculate_max_pss_salt_length(key, hash_algorithm)
     else:
         return salt
 
@@ -220,8 +216,8 @@
                 pkey_ctx,
                 _get_rsa_pss_salt_length(
                     self._padding,
-                    self._private_key.key_size,
-                    self._hash_ctx.algorithm.digest_size
+                    self._private_key,
+                    self._hash_ctx.algorithm,
                 )
             )
             self._backend.openssl_assert(res > 0)
@@ -348,8 +344,8 @@
                 pkey_ctx,
                 _get_rsa_pss_salt_length(
                     self._padding,
-                    self._public_key.key_size,
-                    self._hash_ctx.algorithm.digest_size
+                    self._public_key,
+                    self._hash_ctx.algorithm,
                 )
             )
             self._backend.openssl_assert(res > 0)
diff --git a/src/cryptography/hazmat/primitives/asymmetric/padding.py b/src/cryptography/hazmat/primitives/asymmetric/padding.py
index c796d8e..a37c3f9 100644
--- a/src/cryptography/hazmat/primitives/asymmetric/padding.py
+++ b/src/cryptography/hazmat/primitives/asymmetric/padding.py
@@ -5,11 +5,13 @@
 from __future__ import absolute_import, division, print_function
 
 import abc
+import math
 
 import six
 
 from cryptography import utils
 from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import rsa
 
 
 @six.add_metaclass(abc.ABCMeta)
@@ -65,3 +67,13 @@
             raise TypeError("Expected instance of hashes.HashAlgorithm.")
 
         self._algorithm = algorithm
+
+
+def calculate_max_pss_salt_length(key, hash_algorithm):
+    if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
+        raise TypeError("key must be an RSA public or private key")
+    # bit length - 1 per RFC 3447
+    emlen = int(math.ceil((key.key_size - 1) / 8.0))
+    salt_length = emlen - hash_algorithm.digest_size - 2
+    assert salt_length >= 0
+    return salt_length
diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py
index 9f3008e..e4e4378 100644
--- a/tests/hazmat/primitives/test_rsa.py
+++ b/tests/hazmat/primitives/test_rsa.py
@@ -1081,6 +1081,10 @@
 
 
 class TestPSS(object):
+    def test_calculate_max_pss_salt_length(self):
+        with pytest.raises(TypeError):
+            padding.calculate_max_pss_salt_length(object(), hashes.SHA256())
+
     def test_invalid_salt_length_not_integer(self):
         with pytest.raises(TypeError):
             padding.PSS(