feat: support encrypted mTLS private keys (#496)

(1) support decrypting passphrase-protected private keys
(2) support reading an encrypted key and passphrase from the cert provider
command

Co-authored-by: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py
index 5b61460..da06d86 100644
--- a/google/auth/exceptions.py
+++ b/google/auth/exceptions.py
@@ -39,3 +39,7 @@
 class MutualTLSChannelError(GoogleAuthError):
     """Used to indicate that mutual TLS channel creation is failed, or mutual
     TLS channel credentials is missing or invalid."""
+
+
+class ClientCertError(GoogleAuthError):
+    """Used to indicate the client cert, key or passphrase is missing or invalid."""
diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py
index c518cc8..388ae3c 100644
--- a/google/auth/transport/_mtls_helper.py
+++ b/google/auth/transport/_mtls_helper.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Helper functions for getting mTLS cert and key, for internal use only."""
+"""Helper functions for getting mTLS cert and key."""
 
 import json
 import logging
@@ -20,6 +20,10 @@
 import re
 import subprocess
 
+import six
+
+from google.auth import exceptions
+
 CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
 _CERT_PROVIDER_COMMAND = "cert_provider_command"
 _CERT_REGEX = re.compile(
@@ -30,6 +34,7 @@
 # "-----BEGIN PRIVATE KEY-----...",
 # "-----BEGIN EC PRIVATE KEY-----...",
 # "-----BEGIN RSA PRIVATE KEY-----..."
+# "-----BEGIN ENCRYPTED PRIVATE KEY-----..."
 _KEY_REGEX = re.compile(
     b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?",
     re.DOTALL,
@@ -38,6 +43,11 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+_PASSPHRASE_REGEX = re.compile(  # group 1 captures the passphrase bytes
+    b"-----BEGIN PASSPHRASE-----(.+)-----END PASSPHRASE-----", re.DOTALL
+)
+
+
 def _check_dca_metadata_path(metadata_path):
     """Checks for context aware metadata. If it exists, returns the absolute path;
     otherwise returns None.
@@ -65,57 +75,109 @@
         Dict[str, str]: The metadata.
 
     Raises:
-        ValueError: If failed to parse metadata as JSON.
+        google.auth.exceptions.ClientCertError: If the metadata isn't valid JSON.
     """
-    with open(metadata_path) as f:
-        metadata = json.load(f)
+    try:
+        with open(metadata_path) as f:
+            metadata = json.load(f)
+    except ValueError as caught_exc:
+        new_exc = exceptions.ClientCertError(caught_exc)
+        six.raise_from(new_exc, caught_exc)
 
     return metadata
 
 
-def get_client_ssl_credentials(metadata_json):
-    """Returns the client side mTLS cert and key.
+def _run_cert_provider_command(command, expect_encrypted_key=False):
+    """Runs the cert provider command and returns the client side mTLS cert,
+    key and passphrase parsed from its stdout.
 
     Args:
-        metadata_json (Dict[str, str]): metadata JSON file which contains the cert
-            provider command.
+        command (List[str]): The cert provider command, passed to subprocess.Popen.
+        expect_encrypted_key (bool): If an encrypted key and passphrase are expected.
 
     Returns:
-        Tuple[bytes, bytes]: client certificate and key, both in PEM format.
+        Tuple[bytes, bytes, Optional[bytes]]: client certificate and key, both
+            in PEM format, and passphrase (None if no encrypted key is expected).
 
     Raises:
-        OSError: If the cert provider command failed to run.
-        RuntimeError: If the cert provider command has a runtime error.
-        ValueError: If the metadata json file doesn't contain the cert provider
-            command or if the command doesn't produce both the client certificate
-            and client key.
+        google.auth.exceptions.ClientCertError: If the command fails or exits
+            non-zero, or its output doesn't match the expected cert/key/passphrase.
     """
-    # TODO: implement an in-memory cache of cert and key so we don't have to
-    # run cert provider command every time.
-
-    # Check the cert provider command existence in the metadata json file.
-    if _CERT_PROVIDER_COMMAND not in metadata_json:
-        raise ValueError("Cert provider command is not found")
-
-    # Execute the command. It throws OsError in case of system failure.
-    command = metadata_json[_CERT_PROVIDER_COMMAND]
-    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    stdout, stderr = process.communicate()
+    try:
+        process = subprocess.Popen(
+            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+        )
+        stdout, stderr = process.communicate()
+    except OSError as caught_exc:
+        new_exc = exceptions.ClientCertError(caught_exc)
+        six.raise_from(new_exc, caught_exc)
 
     # Check cert provider command execution error.
     if process.returncode != 0:
-        raise RuntimeError(
+        raise exceptions.ClientCertError(
             "Cert provider command returns non-zero status code %s" % process.returncode
         )
 
-    # Extract certificate (chain) and key.
+    # Extract certificate (chain), key and passphrase.
     cert_match = re.findall(_CERT_REGEX, stdout)
     if len(cert_match) != 1:
-        raise ValueError("Client SSL certificate is missing or invalid")
+        raise exceptions.ClientCertError("Client SSL certificate is missing or invalid")
     key_match = re.findall(_KEY_REGEX, stdout)
     if len(key_match) != 1:
-        raise ValueError("Client SSL key is missing or invalid")
-    return cert_match[0], key_match[0]
+        raise exceptions.ClientCertError("Client SSL key is missing or invalid")
+    passphrase_match = re.findall(_PASSPHRASE_REGEX, stdout)
+
+    if expect_encrypted_key:
+        if len(passphrase_match) != 1:
+            raise exceptions.ClientCertError("Passphrase is missing or invalid")
+        if b"ENCRYPTED" not in key_match[0]:
+            raise exceptions.ClientCertError("Encrypted private key is expected")
+        return cert_match[0], key_match[0], passphrase_match[0].strip()
+
+    if b"ENCRYPTED" in key_match[0]:
+        raise exceptions.ClientCertError("Encrypted private key is not expected")
+    if len(passphrase_match) > 0:
+        raise exceptions.ClientCertError("Passphrase is not expected")
+    return cert_match[0], key_match[0], None
+
+
+def get_client_ssl_credentials(generate_encrypted_key=False):
+    """Returns the client side certificate, private key and passphrase.
+
+    Args:
+        generate_encrypted_key (bool): If set to True, encrypted private key
+            and passphrase will be generated; otherwise, unencrypted private key
+            will be generated and passphrase will be None.
+
+    Returns:
+        Tuple[bool, bytes, bytes, bytes]:
+            Whether context aware metadata was found; the cert and key bytes in
+            PEM format; and the passphrase bytes. Bytes not obtained are None.
+
+    Raises:
+        google.auth.exceptions.ClientCertError: If problems occur when getting
+            the cert, key and passphrase.
+    """
+    metadata_path = _check_dca_metadata_path(CONTEXT_AWARE_METADATA_PATH)
+
+    if metadata_path:
+        metadata_json = _read_dca_metadata_file(metadata_path)
+
+        if _CERT_PROVIDER_COMMAND not in metadata_json:
+            raise exceptions.ClientCertError("Cert provider command is not found")
+
+        command = metadata_json[_CERT_PROVIDER_COMMAND]
+
+        if generate_encrypted_key and "--with_passphrase" not in command:
+            command.append("--with_passphrase")
+
+        # Execute the command and validate its output.
+        cert, key, passphrase = _run_cert_provider_command(
+            command, expect_encrypted_key=generate_encrypted_key
+        )
+        return True, cert, key, passphrase
+
+    return False, None, None, None
 
 
 def get_client_cert_and_key(client_cert_callback=None):
@@ -135,20 +197,54 @@
             and key bytes both in PEM format.
 
     Raises:
-        OSError: If the cert provider command failed to run.
-        RuntimeError: If the cert provider command has a runtime error.
-        ValueError: If the metadata json file doesn't contain the cert provider
-            command or if the command doesn't produce both the client certificate
-            and client key.
+        google.auth.exceptions.ClientCertError: If problems occur when getting
+            the cert and key.
     """
     if client_cert_callback:
         cert, key = client_cert_callback()
         return True, cert, key
 
-    metadata_path = _check_dca_metadata_path(CONTEXT_AWARE_METADATA_PATH)
-    if metadata_path:
-        metadata = _read_dca_metadata_file(metadata_path)
-        cert, key = get_client_ssl_credentials(metadata)
-        return True, cert, key
+    has_cert, cert, key, _ = get_client_ssl_credentials(generate_encrypted_key=False)
+    return has_cert, cert, key
 
-    return False, None, None
+
+def decrypt_private_key(key, passphrase):
+    """A helper function to decrypt the private key with the given passphrase.
+    The google-auth library doesn't support passphrase protected private keys for
+    mutual TLS channels. This helper function can be used to decrypt a
+    passphrase protected private key in order to establish a mutual TLS channel.
+
+    For example, if you have a function which produces client cert, passphrase
+    protected private key and passphrase, you can convert it to a client cert
+    callback function accepted by google-auth::
+
+        from google.auth.transport import _mtls_helper
+
+        def your_client_cert_function():
+            return cert, encrypted_key, passphrase
+
+        # callback accepted by google-auth for mutual TLS channel.
+        def client_cert_callback():
+            cert, encrypted_key, passphrase = your_client_cert_function()
+            decrypted_key = _mtls_helper.decrypt_private_key(encrypted_key,
+                passphrase)
+            return cert, decrypted_key
+
+    Args:
+        key (bytes): The private key bytes in PEM format.
+        passphrase (bytes): The passphrase bytes.
+
+    Returns:
+        bytes: The decrypted private key in PEM format.
+
+    Raises:
+        ImportError: If pyOpenSSL is not installed.
+        OpenSSL.crypto.Error: If there is any problem decrypting the private key.
+    """
+    from OpenSSL import crypto
+
+    # First load the passphrase protected PEM key into a PKey object.
+    pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key, passphrase=passphrase)
+
+    # Then dump it as PEM bytes; no cipher is given, so the output is unencrypted.
+    return crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey)
diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py
index d62c415..13234a3 100644
--- a/google/auth/transport/grpc.py
+++ b/google/auth/transport/grpc.py
@@ -264,13 +264,10 @@
 
     def __init__(self):
         # Load client SSL credentials.
-        self._context_aware_metadata_path = _mtls_helper._check_dca_metadata_path(
+        metadata_path = _mtls_helper._check_dca_metadata_path(
             _mtls_helper.CONTEXT_AWARE_METADATA_PATH
         )
-        if self._context_aware_metadata_path:
-            self._is_mtls = True
-        else:
-            self._is_mtls = False
+        self._is_mtls = metadata_path is not None  # use mTLS if metadata exists
 
     @property
     def ssl_credentials(self):
@@ -288,16 +285,13 @@
             google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel
                 creation failed for any reason.
         """
-        if self._context_aware_metadata_path:
+        if self._is_mtls:
             try:
-                metadata = _mtls_helper._read_dca_metadata_file(
-                    self._context_aware_metadata_path
-                )
-                cert, key = _mtls_helper.get_client_ssl_credentials(metadata)
+                _, cert, key, _ = _mtls_helper.get_client_ssl_credentials()
                 self._ssl_credentials = grpc.ssl_channel_credentials(
                     certificate_chain=cert, private_key=key
                 )
-            except (OSError, RuntimeError, ValueError) as caught_exc:
+            except exceptions.ClientCertError as caught_exc:
                 new_exc = exceptions.MutualTLSChannelError(caught_exc)
                 six.raise_from(new_exc, caught_exc)
         else:
diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py
index cc0e93b..9f55bea 100644
--- a/google/auth/transport/requests.py
+++ b/google/auth/transport/requests.py
@@ -373,11 +373,9 @@
                 mtls_adapter = _MutualTlsAdapter(cert, key)
                 self.mount("https://", mtls_adapter)
         except (
+            exceptions.ClientCertError,  # client cert/key could not be obtained
             ImportError,
             OpenSSL.crypto.Error,
-            OSError,
-            RuntimeError,
-            ValueError,
         ) as caught_exc:
             new_exc = exceptions.MutualTLSChannelError(caught_exc)
             six.raise_from(new_exc, caught_exc)
diff --git a/google/auth/transport/urllib3.py b/google/auth/transport/urllib3.py
index 3771d84..3742f1a 100644
--- a/google/auth/transport/urllib3.py
+++ b/google/auth/transport/urllib3.py
@@ -316,11 +316,9 @@
             else:
                 self.http = _make_default_http()
         except (
+            exceptions.ClientCertError,  # client cert/key could not be obtained
             ImportError,
             OpenSSL.crypto.Error,
-            OSError,
-            RuntimeError,
-            ValueError,
         ) as caught_exc:
             new_exc = exceptions.MutualTLSChannelError(caught_exc)
             six.raise_from(new_exc, caught_exc)