feat: add helper func to for default encrypted cert (#514)
* feat: helper func to for default encrpted cert
diff --git a/google/auth/transport/mtls.py b/google/auth/transport/mtls.py
index 063b265..5b74230 100644
--- a/google/auth/transport/mtls.py
+++ b/google/auth/transport/mtls.py
@@ -58,3 +58,45 @@
return cert_bytes, key_bytes
return callback
+
+
+def default_client_encrypted_cert_source(cert_path, key_path):
+ """Get a callback which returns the default encrpyted client SSL credentials.
+
+ Args:
+ cert_path (str): The cert file path. The default client certificate will
+ be written to this file when the returned callback is called.
+ key_path (str): The key file path. The default encrypted client key will
+ be written to this file when the returned callback is called.
+
+ Returns:
+ Callable[[], [str, str, bytes]]: A callback which generates the default
+ client certificate, encrpyted private key and passphrase. It writes
+ the certificate and private key into the cert_path and key_path, and
+ returns the cert_path, key_path and passphrase bytes.
+
+ Raises:
+ google.auth.exceptions.DefaultClientCertSourceError: If any problem
+ occurs when loading or saving the client certificate and key.
+ """
+ if not has_default_client_cert_source():
+ raise exceptions.MutualTLSChannelError(
+ "Default client encrypted cert source doesn't exist"
+ )
+
+ def callback():
+ try:
+ _, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials(
+ generate_encrypted_key=True
+ )
+ with open(cert_path, "wb") as cert_file:
+ cert_file.write(cert_bytes)
+ with open(key_path, "wb") as key_file:
+ key_file.write(key_bytes)
+ except (exceptions.ClientCertError, OSError) as caught_exc:
+ new_exc = exceptions.MutualTLSChannelError(caught_exc)
+ six.raise_from(new_exc, caught_exc)
+
+ return cert_path, key_path, passphrase_bytes
+
+ return callback
diff --git a/tests/transport/test_mtls.py b/tests/transport/test_mtls.py
index d3bc391..ff70bb3 100644
--- a/tests/transport/test_mtls.py
+++ b/tests/transport/test_mtls.py
@@ -53,3 +53,31 @@
callback = mtls.default_client_cert_source()
with pytest.raises(exceptions.MutualTLSChannelError):
callback()
+
+
+@mock.patch(
+ "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+)
+@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True)
+def test_default_client_encrypted_cert_source(
+ has_default_client_cert_source, get_client_ssl_credentials
+):
+ # Test default client cert source doesn't exist.
+ has_default_client_cert_source.return_value = False
+ with pytest.raises(exceptions.MutualTLSChannelError):
+ mtls.default_client_encrypted_cert_source("cert_path", "key_path")
+
+ # The following tests will assume default client cert source exists.
+ has_default_client_cert_source.return_value = True
+
+ # Test good callback.
+ get_client_ssl_credentials.return_value = (True, b"cert", b"key", b"passphrase")
+ callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
+ with mock.patch("{}.open".format(__name__), return_value=mock.MagicMock()):
+ assert callback() == ("cert_path", "key_path", b"passphrase")
+
+ # Test bad callback which throws exception.
+ get_client_ssl_credentials.side_effect = exceptions.ClientCertError()
+ callback = mtls.default_client_encrypted_cert_source("cert_path", "key_path")
+ with pytest.raises(exceptions.MutualTLSChannelError):
+ callback()