feat: encrypted mtls private key support (#496)
(1) support encrypted private key decryption
(2) support reading encrypted key and passphrase from cert provider
command
Co-authored-by: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py
index 5b61460..da06d86 100644
--- a/google/auth/exceptions.py
+++ b/google/auth/exceptions.py
@@ -39,3 +39,7 @@
class MutualTLSChannelError(GoogleAuthError):
"""Used to indicate that mutual TLS channel creation is failed, or mutual
TLS channel credentials is missing or invalid."""
+
+
+class ClientCertError(GoogleAuthError):
+ """Used to indicate that client certificate is missing or invalid."""
diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py
index c518cc8..388ae3c 100644
--- a/google/auth/transport/_mtls_helper.py
+++ b/google/auth/transport/_mtls_helper.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Helper functions for getting mTLS cert and key, for internal use only."""
+"""Helper functions for getting mTLS cert and key."""
import json
import logging
@@ -20,6 +20,10 @@
import re
import subprocess
+import six
+
+from google.auth import exceptions
+
CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
_CERT_PROVIDER_COMMAND = "cert_provider_command"
_CERT_REGEX = re.compile(
@@ -30,6 +34,7 @@
# "-----BEGIN PRIVATE KEY-----...",
# "-----BEGIN EC PRIVATE KEY-----...",
# "-----BEGIN RSA PRIVATE KEY-----..."
+# "-----BEGIN ENCRYPTED PRIVATE KEY-----"
_KEY_REGEX = re.compile(
b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?",
re.DOTALL,
@@ -38,6 +43,11 @@
_LOGGER = logging.getLogger(__name__)
+_PASSPHRASE_REGEX = re.compile(
+ b"-----BEGIN PASSPHRASE-----(.+)-----END PASSPHRASE-----", re.DOTALL
+)
+
+
def _check_dca_metadata_path(metadata_path):
"""Checks for context aware metadata. If it exists, returns the absolute path;
otherwise returns None.
@@ -65,57 +75,109 @@
Dict[str, str]: The metadata.
Raises:
- ValueError: If failed to parse metadata as JSON.
+ google.auth.exceptions.ClientCertError: If failed to parse metadata as JSON.
"""
- with open(metadata_path) as f:
- metadata = json.load(f)
+ try:
+ with open(metadata_path) as f:
+ metadata = json.load(f)
+ except ValueError as caught_exc:
+ new_exc = exceptions.ClientCertError(caught_exc)
+ six.raise_from(new_exc, caught_exc)
return metadata
-def get_client_ssl_credentials(metadata_json):
- """Returns the client side mTLS cert and key.
+def _run_cert_provider_command(command, expect_encrypted_key=False):
+ """Run the provided command, and return client side mTLS cert, key and
+ passphrase.
Args:
- metadata_json (Dict[str, str]): metadata JSON file which contains the cert
- provider command.
+ command (List[str]): cert provider command.
+ expect_encrypted_key (bool): If encrypted private key is expected.
Returns:
- Tuple[bytes, bytes]: client certificate and key, both in PEM format.
+ Tuple[bytes, bytes, bytes]: client certificate bytes in PEM format, key
+ bytes in PEM format and passphrase bytes.
Raises:
- OSError: If the cert provider command failed to run.
- RuntimeError: If the cert provider command has a runtime error.
- ValueError: If the metadata json file doesn't contain the cert provider
- command or if the command doesn't produce both the client certificate
- and client key.
+ google.auth.exceptions.ClientCertError: if problems occurs when running
+ the cert provider command or generating cert, key and passphrase.
"""
- # TODO: implement an in-memory cache of cert and key so we don't have to
- # run cert provider command every time.
-
- # Check the cert provider command existence in the metadata json file.
- if _CERT_PROVIDER_COMMAND not in metadata_json:
- raise ValueError("Cert provider command is not found")
-
- # Execute the command. It throws OsError in case of system failure.
- command = metadata_json[_CERT_PROVIDER_COMMAND]
- process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- stdout, stderr = process.communicate()
+ try:
+ process = subprocess.Popen(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ stdout, stderr = process.communicate()
+ except OSError as caught_exc:
+ new_exc = exceptions.ClientCertError(caught_exc)
+ six.raise_from(new_exc, caught_exc)
# Check cert provider command execution error.
if process.returncode != 0:
- raise RuntimeError(
+ raise exceptions.ClientCertError(
"Cert provider command returns non-zero status code %s" % process.returncode
)
- # Extract certificate (chain) and key.
+ # Extract certificate (chain), key and passphrase.
cert_match = re.findall(_CERT_REGEX, stdout)
if len(cert_match) != 1:
- raise ValueError("Client SSL certificate is missing or invalid")
+ raise exceptions.ClientCertError("Client SSL certificate is missing or invalid")
key_match = re.findall(_KEY_REGEX, stdout)
if len(key_match) != 1:
- raise ValueError("Client SSL key is missing or invalid")
- return cert_match[0], key_match[0]
+ raise exceptions.ClientCertError("Client SSL key is missing or invalid")
+ passphrase_match = re.findall(_PASSPHRASE_REGEX, stdout)
+
+ if expect_encrypted_key:
+ if len(passphrase_match) != 1:
+ raise exceptions.ClientCertError("Passphrase is missing or invalid")
+ if b"ENCRYPTED" not in key_match[0]:
+ raise exceptions.ClientCertError("Encrypted private key is expected")
+ return cert_match[0], key_match[0], passphrase_match[0].strip()
+
+ if b"ENCRYPTED" in key_match[0]:
+ raise exceptions.ClientCertError("Encrypted private key is not expected")
+ if len(passphrase_match) > 0:
+ raise exceptions.ClientCertError("Passphrase is not expected")
+ return cert_match[0], key_match[0], None
+
+
+def get_client_ssl_credentials(generate_encrypted_key=False):
+ """Returns the client side certificate, private key and passphrase.
+
+ Args:
+ generate_encrypted_key (bool): If set to True, encrypted private key
+ and passphrase will be generated; otherwise, unencrypted private key
+ will be generated and passphrase will be None.
+
+ Returns:
+ Tuple[bool, bytes, bytes, bytes]:
+ A boolean indicating if cert, key and passphrase are obtained, the
+ cert bytes and key bytes both in PEM format, and passphrase bytes.
+
+ Raises:
+ google.auth.exceptions.ClientCertError: if problems occurs when getting
+ the cert, key and passphrase.
+ """
+ metadata_path = _check_dca_metadata_path(CONTEXT_AWARE_METADATA_PATH)
+
+ if metadata_path:
+ metadata_json = _read_dca_metadata_file(metadata_path)
+
+ if _CERT_PROVIDER_COMMAND not in metadata_json:
+ raise exceptions.ClientCertError("Cert provider command is not found")
+
+ command = metadata_json[_CERT_PROVIDER_COMMAND]
+
+ if generate_encrypted_key and "--with_passphrase" not in command:
+ command.append("--with_passphrase")
+
+ # Execute the command.
+ cert, key, passphrase = _run_cert_provider_command(
+ command, expect_encrypted_key=generate_encrypted_key
+ )
+ return True, cert, key, passphrase
+
+ return False, None, None, None
def get_client_cert_and_key(client_cert_callback=None):
@@ -135,20 +197,54 @@
and key bytes both in PEM format.
Raises:
- OSError: If the cert provider command failed to run.
- RuntimeError: If the cert provider command has a runtime error.
- ValueError: If the metadata json file doesn't contain the cert provider
- command or if the command doesn't produce both the client certificate
- and client key.
+ google.auth.exceptions.ClientCertError: if problems occurs when getting
+ the cert and key.
"""
if client_cert_callback:
cert, key = client_cert_callback()
return True, cert, key
- metadata_path = _check_dca_metadata_path(CONTEXT_AWARE_METADATA_PATH)
- if metadata_path:
- metadata = _read_dca_metadata_file(metadata_path)
- cert, key = get_client_ssl_credentials(metadata)
- return True, cert, key
+ has_cert, cert, key, _ = get_client_ssl_credentials(generate_encrypted_key=False)
+ return has_cert, cert, key
- return False, None, None
+
+def decrypt_private_key(key, passphrase):
+ """A helper function to decrypt the private key with the given passphrase.
+ google-auth library doesn't support passphrase protected private key for
+ mutual TLS channel. This helper function can be used to decrypt the
+ passphrase protected private key in order to estalish mutual TLS channel.
+
+ For example, if you have a function which produces client cert, passphrase
+ protected private key and passphrase, you can convert it to a client cert
+ callback function accepted by google-auth::
+
+ from google.auth.transport import _mtls_helper
+
+ def your_client_cert_function():
+ return cert, encrypted_key, passphrase
+
+ # callback accepted by google-auth for mutual TLS channel.
+ def client_cert_callback():
+ cert, encrypted_key, passphrase = your_client_cert_function()
+ decrypted_key = _mtls_helper.decrypt_private_key(encrypted_key,
+ passphrase)
+ return cert, decrypted_key
+
+ Args:
+ key (bytes): The private key bytes in PEM format.
+ passphrase (bytes): The passphrase bytes.
+
+ Returns:
+ bytes: The decrypted private key in PEM format.
+
+ Raises:
+ ImportError: If pyOpenSSL is not installed.
+ OpenSSL.crypto.Error: If there is any problem decrypting the private key.
+ """
+ from OpenSSL import crypto
+
+ # First convert encrypted_key_bytes to PKey object
+ pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key, passphrase=passphrase)
+
+ # Then dump the decrypted key bytes
+ return crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)
diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py
index d62c415..13234a3 100644
--- a/google/auth/transport/grpc.py
+++ b/google/auth/transport/grpc.py
@@ -264,13 +264,10 @@
def __init__(self):
# Load client SSL credentials.
- self._context_aware_metadata_path = _mtls_helper._check_dca_metadata_path(
+ metadata_path = _mtls_helper._check_dca_metadata_path(
_mtls_helper.CONTEXT_AWARE_METADATA_PATH
)
- if self._context_aware_metadata_path:
- self._is_mtls = True
- else:
- self._is_mtls = False
+ self._is_mtls = metadata_path is not None
@property
def ssl_credentials(self):
@@ -288,16 +285,13 @@
google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
creation failed for any reason.
"""
- if self._context_aware_metadata_path:
+ if self._is_mtls:
try:
- metadata = _mtls_helper._read_dca_metadata_file(
- self._context_aware_metadata_path
- )
- cert, key = _mtls_helper.get_client_ssl_credentials(metadata)
+ _, cert, key, _ = _mtls_helper.get_client_ssl_credentials()
self._ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
- except (OSError, RuntimeError, ValueError) as caught_exc:
+ except exceptions.ClientCertError as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
six.raise_from(new_exc, caught_exc)
else:
diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py
index cc0e93b..9f55bea 100644
--- a/google/auth/transport/requests.py
+++ b/google/auth/transport/requests.py
@@ -373,11 +373,9 @@
mtls_adapter = _MutualTlsAdapter(cert, key)
self.mount("https://", mtls_adapter)
except (
+ exceptions.ClientCertError,
ImportError,
OpenSSL.crypto.Error,
- OSError,
- RuntimeError,
- ValueError,
) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
six.raise_from(new_exc, caught_exc)
diff --git a/google/auth/transport/urllib3.py b/google/auth/transport/urllib3.py
index 3771d84..3742f1a 100644
--- a/google/auth/transport/urllib3.py
+++ b/google/auth/transport/urllib3.py
@@ -316,11 +316,9 @@
else:
self.http = _make_default_http()
except (
+ exceptions.ClientCertError,
ImportError,
OpenSSL.crypto.Error,
- OSError,
- RuntimeError,
- ValueError,
) as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
six.raise_from(new_exc, caught_exc)
diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py
index 5bf1967..04d0b56 100644
--- a/tests/transport/test__mtls_helper.py
+++ b/tests/transport/test__mtls_helper.py
@@ -16,14 +16,34 @@
import re
import mock
+from OpenSSL import crypto
import pytest
+from google.auth import exceptions
from google.auth.transport import _mtls_helper
CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]}
CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {}
+ENCRYPTED_EC_PRIVATE_KEY = b"""-----BEGIN ENCRYPTED PRIVATE KEY-----
+MIHkME8GCSqGSIb3DQEFDTBCMCkGCSqGSIb3DQEFDDAcBAgl2/yVgs1h3QICCAAw
+DAYIKoZIhvcNAgkFADAVBgkrBgEEAZdVAQIECJk2GRrvxOaJBIGQXIBnMU4wmciT
+uA6yD8q0FxuIzjG7E2S6tc5VRgSbhRB00eBO3jWmO2pBybeQW+zVioDcn50zp2ts
+wYErWC+LCm1Zg3r+EGnT1E1GgNoODbVQ3AEHlKh1CGCYhEovxtn3G+Fjh7xOBrNB
+saVVeDb4tHD4tMkiVVUBrUcTZPndP73CtgyGHYEphasYPzEz3+AU
+-----END ENCRYPTED PRIVATE KEY-----"""
+
+EC_PUBLIC_KEY = b"""-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEvCNi1NoDY1oMqPHIgXI8RBbTYGi/
+brEjbre1nSiQW11xRTJbVeETdsuP0EAu2tG3PcRhhwDfeJ8zXREgTBurNw==
+-----END PUBLIC KEY-----"""
+
+PASSPHRASE = b"""-----BEGIN PASSPHRASE-----
+password
+-----END PASSPHRASE-----"""
+PASSPHRASE_VALUE = b"password"
+
def check_cert_and_key(content, expected_cert, expected_key):
success = True
@@ -115,11 +135,11 @@
def test_file_not_json(self):
# read a file which is not json format.
metadata_path = os.path.join(pytest.data_dir, "privatekey.pem")
- with pytest.raises(ValueError):
+ with pytest.raises(exceptions.ClientCertError):
_mtls_helper._read_dca_metadata_file(metadata_path)
-class TestGetClientSslCredentials(object):
+class TestRunCertProviderCommand(object):
def create_mock_process(self, output, error):
# There are two steps to execute a script with subprocess.Popen.
# (1) process = subprocess.Popen([comannds])
@@ -137,9 +157,20 @@
mock_popen.return_value = self.create_mock_process(
pytest.public_cert_bytes + pytest.private_key_bytes, b""
)
- cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ cert, key, passphrase = _mtls_helper._run_cert_provider_command(["command"])
assert cert == pytest.public_cert_bytes
assert key == pytest.private_key_bytes
+ assert passphrase is None
+
+ mock_popen.return_value = self.create_mock_process(
+ pytest.public_cert_bytes + ENCRYPTED_EC_PRIVATE_KEY + PASSPHRASE, b""
+ )
+ cert, key, passphrase = _mtls_helper._run_cert_provider_command(
+ ["command"], expect_encrypted_key=True
+ )
+ assert cert == pytest.public_cert_bytes
+ assert key == ENCRYPTED_EC_PRIVATE_KEY
+ assert passphrase == PASSPHRASE_VALUE
@mock.patch("subprocess.Popen", autospec=True)
def test_success_with_cert_chain(self, mock_popen):
@@ -147,44 +178,185 @@
mock_popen.return_value = self.create_mock_process(
PUBLIC_CERT_CHAIN_BYTES + pytest.private_key_bytes, b""
)
- cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ cert, key, passphrase = _mtls_helper._run_cert_provider_command(["command"])
assert cert == PUBLIC_CERT_CHAIN_BYTES
assert key == pytest.private_key_bytes
+ assert passphrase is None
- def test_missing_cert_provider_command(self):
- with pytest.raises(ValueError):
- assert _mtls_helper.get_client_ssl_credentials(
- CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND
- )
+ mock_popen.return_value = self.create_mock_process(
+ PUBLIC_CERT_CHAIN_BYTES + ENCRYPTED_EC_PRIVATE_KEY + PASSPHRASE, b""
+ )
+ cert, key, passphrase = _mtls_helper._run_cert_provider_command(
+ ["command"], expect_encrypted_key=True
+ )
+ assert cert == PUBLIC_CERT_CHAIN_BYTES
+ assert key == ENCRYPTED_EC_PRIVATE_KEY
+ assert passphrase == PASSPHRASE_VALUE
@mock.patch("subprocess.Popen", autospec=True)
def test_missing_cert(self, mock_popen):
mock_popen.return_value = self.create_mock_process(
pytest.private_key_bytes, b""
)
- with pytest.raises(ValueError):
- assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(["command"])
+
+ mock_popen.return_value = self.create_mock_process(
+ ENCRYPTED_EC_PRIVATE_KEY + PASSPHRASE, b""
+ )
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(
+ ["command"], expect_encrypted_key=True
+ )
@mock.patch("subprocess.Popen", autospec=True)
def test_missing_key(self, mock_popen):
mock_popen.return_value = self.create_mock_process(
pytest.public_cert_bytes, b""
)
- with pytest.raises(ValueError):
- assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(["command"])
+
+ mock_popen.return_value = self.create_mock_process(
+ pytest.public_cert_bytes + PASSPHRASE, b""
+ )
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(
+ ["command"], expect_encrypted_key=True
+ )
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_missing_passphrase(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(
+ pytest.public_cert_bytes + ENCRYPTED_EC_PRIVATE_KEY, b""
+ )
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(
+ ["command"], expect_encrypted_key=True
+ )
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_passphrase_not_expected(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(
+ pytest.public_cert_bytes + pytest.private_key_bytes + PASSPHRASE, b""
+ )
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(["command"])
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_encrypted_key_expected(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(
+ pytest.public_cert_bytes + pytest.private_key_bytes + PASSPHRASE, b""
+ )
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(
+ ["command"], expect_encrypted_key=True
+ )
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_unencrypted_key_expected(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(
+ pytest.public_cert_bytes + ENCRYPTED_EC_PRIVATE_KEY, b""
+ )
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(["command"])
@mock.patch("subprocess.Popen", autospec=True)
def test_cert_provider_returns_error(self, mock_popen):
mock_popen.return_value = self.create_mock_process(b"", b"some error")
mock_popen.return_value.returncode = 1
- with pytest.raises(RuntimeError):
- assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(["command"])
@mock.patch("subprocess.Popen", autospec=True)
def test_popen_raise_exception(self, mock_popen):
mock_popen.side_effect = OSError()
- with pytest.raises(OSError):
- assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper._run_cert_provider_command(["command"])
+
+
+class TestGetClientSslCredentials(object):
+ @mock.patch(
+ "google.auth.transport._mtls_helper._run_cert_provider_command", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_success(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_run_cert_provider_command,
+ ):
+ mock_check_dca_metadata_path.return_value = True
+ mock_read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["command"]
+ }
+ mock_run_cert_provider_command.return_value = (b"cert", b"key", None)
+ has_cert, cert, key, passphrase = _mtls_helper.get_client_ssl_credentials()
+ assert has_cert
+ assert cert == b"cert"
+ assert key == b"key"
+ assert passphrase is None
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_success_without_metadata(self, mock_check_dca_metadata_path):
+ mock_check_dca_metadata_path.return_value = False
+ has_cert, cert, key, passphrase = _mtls_helper.get_client_ssl_credentials()
+ assert not has_cert
+ assert cert is None
+ assert key is None
+ assert passphrase is None
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper._run_cert_provider_command", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_success_with_encrypted_key(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_run_cert_provider_command,
+ ):
+ mock_check_dca_metadata_path.return_value = True
+ mock_read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["command"]
+ }
+ mock_run_cert_provider_command.return_value = (b"cert", b"key", b"passphrase")
+ has_cert, cert, key, passphrase = _mtls_helper.get_client_ssl_credentials(
+ generate_encrypted_key=True
+ )
+ assert has_cert
+ assert cert == b"cert"
+ assert key == b"key"
+ assert passphrase == b"passphrase"
+ mock_run_cert_provider_command.assert_called_once_with(
+ ["command", "--with_passphrase"], expect_encrypted_key=True
+ )
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_missing_cert_command(
+ self, mock_check_dca_metadata_path, mock_read_dca_metadata_file
+ ):
+ mock_check_dca_metadata_path.return_value = True
+ mock_read_dca_metadata_file.return_value = {}
+ with pytest.raises(exceptions.ClientCertError):
+ _mtls_helper.get_client_ssl_credentials()
class TestGetClientCertAndKey(object):
@@ -198,32 +370,38 @@
assert key == pytest.private_key_bytes
@mock.patch(
- "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
- )
- def test_no_metadata(self, mock_check_dca_metadata_path):
- mock_check_dca_metadata_path.return_value = None
-
- found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key()
- assert not found_cert_key
-
- @mock.patch(
"google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
)
- @mock.patch(
- "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
- )
- def test_use_metadata(
- self, mock_check_dca_metadata_path, mock_get_client_ssl_credentials
- ):
- mock_check_dca_metadata_path.return_value = os.path.join(
- pytest.data_dir, "context_aware_metadata.json"
- )
+ def test_use_metadata(self, mock_get_client_ssl_credentials):
mock_get_client_ssl_credentials.return_value = (
+ True,
pytest.public_cert_bytes,
pytest.private_key_bytes,
+ None,
)
found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key()
assert found_cert_key
assert cert == pytest.public_cert_bytes
assert key == pytest.private_key_bytes
+
+
+class TestDecryptPrivateKey(object):
+ def test_success(self):
+ decrypted_key = _mtls_helper.decrypt_private_key(
+ ENCRYPTED_EC_PRIVATE_KEY, PASSPHRASE_VALUE
+ )
+ private_key = crypto.load_privatekey(crypto.FILETYPE_PEM, decrypted_key)
+ public_key = crypto.load_publickey(crypto.FILETYPE_PEM, EC_PUBLIC_KEY)
+ x509 = crypto.X509()
+ x509.set_pubkey(public_key)
+
+ # Test the decrypted key works by signing and verification.
+ signature = crypto.sign(private_key, b"data", "sha256")
+ crypto.verify(x509, signature, b"data", "sha256")
+
+ def test_crypto_error(self):
+ with pytest.raises(crypto.Error):
+ _mtls_helper.decrypt_private_key(
+ ENCRYPTED_EC_PRIVATE_KEY, b"wrong_password"
+ )
diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py
index 5c61f96..c3da76d 100644
--- a/tests/transport/test_grpc.py
+++ b/tests/transport/test_grpc.py
@@ -129,7 +129,12 @@
read_dca_metadata_file.return_value = {
"cert_provider_command": ["some command"]
}
- get_client_ssl_credentials.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)
+ get_client_ssl_credentials.return_value = (
+ True,
+ PUBLIC_CERT_BYTES,
+ PRIVATE_KEY_BYTES,
+ None,
+ )
channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, request, target, options=mock.sentinel.options
@@ -314,7 +319,7 @@
}
# Mock that client cert and key are not loaded and exception is raised.
- mock_get_client_ssl_credentials.side_effect = ValueError()
+ mock_get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
with pytest.raises(exceptions.MutualTLSChannelError):
assert google.auth.transport.grpc.SslCredentials().ssl_credentials
@@ -331,8 +336,10 @@
"cert_provider_command": ["some command"]
}
mock_get_client_ssl_credentials.return_value = (
+ True,
PUBLIC_CERT_BYTES,
PRIVATE_KEY_BYTES,
+ None,
)
ssl_credentials = google.auth.transport.grpc.SslCredentials()
diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py
index ed388d4..77e1527 100644
--- a/tests/transport/test_requests.py
+++ b/tests/transport/test_requests.py
@@ -429,7 +429,7 @@
"google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
)
def test_configure_mtls_channel_exceptions(self, mock_get_client_cert_and_key):
- mock_get_client_cert_and_key.side_effect = ValueError()
+ mock_get_client_cert_and_key.side_effect = exceptions.ClientCertError()
auth_session = google.auth.transport.requests.AuthorizedSession(
credentials=mock.Mock()
diff --git a/tests/transport/test_urllib3.py b/tests/transport/test_urllib3.py
index a25fcd7..1a1c0a1 100644
--- a/tests/transport/test_urllib3.py
+++ b/tests/transport/test_urllib3.py
@@ -233,7 +233,7 @@
credentials=mock.Mock()
)
- mock_get_client_cert_and_key.side_effect = ValueError()
+ mock_get_client_cert_and_key.side_effect = exceptions.ClientCertError()
with pytest.raises(exceptions.MutualTLSChannelError):
authed_http.configure_mtls_channel()