feat: add SslCredentials class for mTLS ADC (#448)

feat: add SslCredentials class for mTLS ADC (linux)
diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py
new file mode 100644
index 0000000..1ce9fa5
--- /dev/null
+++ b/google/auth/transport/_mtls_helper.py
@@ -0,0 +1,116 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper functions for getting mTLS cert and key, for internal use only."""
+
+import json
+import logging
+from os import path
+import re
+import subprocess
+
+CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
+_CERT_PROVIDER_COMMAND = "cert_provider_command"
+_CERT_REGEX = re.compile(
+    b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL
+)
+
+# support various format of key files, e.g.
+# "-----BEGIN PRIVATE KEY-----...",
+# "-----BEGIN EC PRIVATE KEY-----...",
+# "-----BEGIN RSA PRIVATE KEY-----..."
+_KEY_REGEX = re.compile(
+    b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?",
+    re.DOTALL,
+)
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _check_dca_metadata_path(metadata_path):
+    """Checks for context aware metadata. If it exists, returns the absolute path;
+    otherwise returns None.
+
+    Args:
+        metadata_path (str): context aware metadata path.
+
+    Returns:
+        str: absolute path if exists and None otherwise.
+    """
+    metadata_path = path.expanduser(metadata_path)
+    if not path.exists(metadata_path):
+        _LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path)
+        return None
+    return metadata_path
+
+
+def _read_dca_metadata_file(metadata_path):
+    """Loads context aware metadata from the given path.
+
+    Args:
+        metadata_path (str): context aware metadata path.
+
+    Returns:
+        Dict[str, str]: The metadata.
+
+    Raises:
+        ValueError: If failed to parse metadata as JSON.
+    """
+    with open(metadata_path) as f:
+        metadata = json.load(f)
+
+    return metadata
+
+
+def get_client_ssl_credentials(metadata_json):
+    """Returns the client side mTLS cert and key.
+
+    Args:
+        metadata_json (Dict[str, str]): metadata JSON file which contains the cert
+            provider command.
+
+    Returns:
+        Tuple[bytes, bytes]: client certificate and key, both in PEM format.
+
+    Raises:
+        OSError: If the cert provider command failed to run.
+        RuntimeError: If the cert provider command has a runtime error.
+        ValueError: If the metadata json file doesn't contain the cert provider command or if the command doesn't produce both the client certificate and client key.
+    """
+    # TODO: implement an in-memory cache of cert and key so we don't have to
+    # run cert provider command every time.
+
+    # Check the cert provider command existence in the metadata json file.
+    if _CERT_PROVIDER_COMMAND not in metadata_json:
+        raise ValueError("Cert provider command is not found")
+
+    # Execute the command. It throws OsError in case of system failure.
+    command = metadata_json[_CERT_PROVIDER_COMMAND]
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    stdout, stderr = process.communicate()
+
+    # Check cert provider command execution error.
+    if process.returncode != 0:
+        raise RuntimeError(
+            "Cert provider command returns non-zero status code %s" % process.returncode
+        )
+
+    # Extract certificate (chain) and key.
+    cert_match = re.findall(_CERT_REGEX, stdout)
+    if len(cert_match) != 1:
+        raise ValueError("Client SSL certificate is missing or invalid")
+    key_match = re.findall(_KEY_REGEX, stdout)
+    if len(key_match) != 1:
+        raise ValueError("Client SSL key is missing or invalid")
+    return cert_match[0], key_match[0]
diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py
index fb90fbb..ca38739 100644
--- a/google/auth/transport/grpc.py
+++ b/google/auth/transport/grpc.py
@@ -17,9 +17,12 @@
 from __future__ import absolute_import
 
 from concurrent import futures
+import logging
 
 import six
 
+from google.auth.transport import _mtls_helper
+
 try:
     import grpc
 except ImportError as caught_exc:  # pragma: NO COVER
@@ -31,6 +34,8 @@
         caught_exc,
     )
 
+_LOGGER = logging.getLogger(__name__)
+
 
 class AuthMetadataPlugin(grpc.AuthMetadataPlugin):
     """A `gRPC AuthMetadataPlugin`_ that inserts the credentials into each
@@ -92,7 +97,12 @@
 
 
 def secure_authorized_channel(
-    credentials, request, target, ssl_credentials=None, **kwargs
+    credentials,
+    request,
+    target,
+    ssl_credentials=None,
+    client_cert_callback=None,
+    **kwargs
 ):
     """Creates a secure authorized gRPC channel.
 
@@ -114,11 +124,86 @@
 
         # Create a channel.
         channel = google.auth.transport.grpc.secure_authorized_channel(
-            credentials, 'speech.googleapis.com:443', request)
+            credentials, regular_endpoint, request,
+            ssl_credentials=grpc.ssl_channel_credentials())
 
         # Use the channel to create a stub.
         cloud_speech.create_Speech_stub(channel)
 
+    Usage:
+
+    There are actually a couple of options to create a channel, depending on if
+    you want to create a regular or mutual TLS channel.
+
+    First let's list the endpoints (regular vs mutual TLS) to choose from::
+
+        regular_endpoint = 'speech.googleapis.com:443'
+        mtls_endpoint = 'speech.mtls.googleapis.com:443'
+
+    Option 1: create a regular (non-mutual) TLS channel by explicitly setting
+    the ssl_credentials::
+
+        regular_ssl_credentials = grpc.ssl_channel_credentials()
+
+        channel = google.auth.transport.grpc.secure_authorized_channel(
+            credentials, regular_endpoint, request,
+            ssl_credentials=regular_ssl_credentials)
+
+    Option 2: create a mutual TLS channel by calling a callback which returns
+    the client side certificate and the key::
+
+        def my_client_cert_callback():
+            code_to_load_client_cert_and_key()
+            if loaded:
+                return (pem_cert_bytes, pem_key_bytes)
+            raise MyClientCertFailureException()
+
+        try:
+            channel = google.auth.transport.grpc.secure_authorized_channel(
+                credentials, mtls_endpoint, request,
+                client_cert_callback=my_client_cert_callback)
+        except MyClientCertFailureException:
+            # handle the exception
+
+    Option 3: use application default SSL credentials. It searches and uses
+    the command in a context aware metadata file, which is available on devices
+    with endpoint verification support.
+    See https://cloud.google.com/endpoint-verification/docs/overview::
+
+        try:
+            default_ssl_credentials = SslCredentials()
+        except:
+            # Exception can be raised if the context aware metadata is malformed.
+            # See :class:`SslCredentials` for the possible exceptions.
+
+        # Choose the endpoint based on the SSL credentials type.
+        if default_ssl_credentials.is_mtls:
+            endpoint_to_use = mtls_endpoint
+        else:
+            endpoint_to_use = regular_endpoint
+        channel = google.auth.transport.grpc.secure_authorized_channel(
+            credentials, endpoint_to_use, request,
+            ssl_credentials=default_ssl_credentials)
+
+    Option 4: not setting ssl_credentials and client_cert_callback. For devices
+    without endpoint verification support, a regular TLS channel is created;
+    otherwise, a mutual TLS channel is created, however, the call should be
+    wrapped in a try/except block in case of malformed context aware metadata.
+
+    The following code uses regular_endpoint, it works the same no matter the
+    created channle is regular or mutual TLS. Regular endpoint ignores client
+    certificate and key::
+
+        channel = google.auth.transport.grpc.secure_authorized_channel(
+            credentials, regular_endpoint, request)
+
+    The following code uses mtls_endpoint, if the created channle is regular,
+    and API mtls_endpoint is confgured to require client SSL credentials, API
+    calls using this channel will be rejected::
+
+        channel = google.auth.transport.grpc.secure_authorized_channel(
+            credentials, mtls_endpoint, request)
+
     Args:
         credentials (google.auth.credentials.Credentials): The credentials to
             add to requests.
@@ -129,10 +214,33 @@
         target (str): The host and port of the service.
         ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
             credentials. This can be used to specify different certificates.
+            This argument is mutually exclusive with client_cert_callback;
+            providing both will raise an exception.
+            If ssl_credentials and client_cert_callback are None, application
+            default SSL credentials will be used.
+        client_cert_callback (Callable[[], (bytes, bytes)]): Optional
+            callback function to obtain client certicate and key for mutual TLS
+            connection. This argument is mutually exclusive with
+            ssl_credentials; providing both will raise an exception.
+            If ssl_credentials and client_cert_callback are None, application
+            default SSL credentials will be used.
         kwargs: Additional arguments to pass to :func:`grpc.secure_channel`.
 
     Returns:
         grpc.Channel: The created gRPC channel.
+
+    Raises:
+        OSError: If the cert provider command launch fails during the application
+            default SSL credentials loading process on devices with endpoint
+            verification support.
+        RuntimeError: If the cert provider command has a runtime error during the
+            application default SSL credentials loading process on devices with
+            endpoint verification support.
+        ValueError:
+            If the context aware metadata file is malformed or if the cert provider
+            command doesn't produce both client certificate and key during the
+            application default SSL credentials loading process on devices with
+            endpoint verification support.
     """
     # Create the metadata plugin for inserting the authorization header.
     metadata_plugin = AuthMetadataPlugin(credentials, request)
@@ -140,8 +248,24 @@
     # Create a set of grpc.CallCredentials using the metadata plugin.
     google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)
 
-    if ssl_credentials is None:
-        ssl_credentials = grpc.ssl_channel_credentials()
+    if ssl_credentials and client_cert_callback:
+        raise ValueError(
+            "Received both ssl_credentials and client_cert_callback; "
+            "these are mutually exclusive."
+        )
+
+    # If SSL credentials are not explicitly set, try client_cert_callback and ADC.
+    if not ssl_credentials:
+        if client_cert_callback:
+            # Use the callback if provided.
+            cert, key = client_cert_callback()
+            ssl_credentials = grpc.ssl_channel_credentials(
+                certificate_chain=cert, private_key=key
+            )
+        else:
+            # Use application default SSL credentials.
+            adc_ssl_credentils = SslCredentials()
+            ssl_credentials = adc_ssl_credentils.ssl_credentials
 
     # Combine the ssl credentials and the authorization credentials.
     composite_credentials = grpc.composite_channel_credentials(
@@ -149,3 +273,59 @@
     )
 
     return grpc.secure_channel(target, composite_credentials, **kwargs)
+
+
+class SslCredentials:
+    """Class for application default SSL credentials.
+
+    For devices with endpoint verification support, a device certificate will be
+    automatically loaded and mutual TLS will be established.
+    See https://cloud.google.com/endpoint-verification/docs/overview.
+    """
+
+    def __init__(self):
+        # Load client SSL credentials.
+        self._context_aware_metadata_path = _mtls_helper._check_dca_metadata_path(
+            _mtls_helper.CONTEXT_AWARE_METADATA_PATH
+        )
+        if self._context_aware_metadata_path:
+            self._is_mtls = True
+        else:
+            self._is_mtls = False
+
+    @property
+    def ssl_credentials(self):
+        """Get the created SSL channel credentials.
+
+        For devices with endpoint verification support, if the device certificate
+        loading has any problems, corresponding exceptions will be raised. For
+        a device without endpoint verification support, no exceptions will be
+        raised.
+
+        Returns:
+            grpc.ChannelCredentials: The created grpc channel credentials.
+
+        Raises:
+            OSError: If the cert provider command launch fails.
+            RuntimeError: If the cert provider command has a runtime error.
+            ValueError:
+                If the context aware metadata file is malformed or if the cert provider
+                command doesn't produce both the client certificate and key.
+        """
+        if self._context_aware_metadata_path:
+            metadata = _mtls_helper._read_dca_metadata_file(
+                self._context_aware_metadata_path
+            )
+            cert, key = _mtls_helper.get_client_ssl_credentials(metadata)
+            self._ssl_credentials = grpc.ssl_channel_credentials(
+                certificate_chain=cert, private_key=key
+            )
+        else:
+            self._ssl_credentials = grpc.ssl_channel_credentials()
+
+        return self._ssl_credentials
+
+    @property
+    def is_mtls(self):
+        """Indicates if the created SSL channel credentials is mutual TLS."""
+        return self._is_mtls
diff --git a/tests/data/context_aware_metadata.json b/tests/data/context_aware_metadata.json
new file mode 100644
index 0000000..ec40e78
--- /dev/null
+++ b/tests/data/context_aware_metadata.json
@@ -0,0 +1,6 @@
+{
+  "cert_provider_command":[
+    "/opt/google/endpoint-verification/bin/SecureConnectHelper",
+    "--print_certificate"],
+  "device_resource_ids":["11111111-1111-1111"]
+}
diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py
new file mode 100644
index 0000000..6e7175f
--- /dev/null
+++ b/tests/transport/test__mtls_helper.py
@@ -0,0 +1,177 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+
+import mock
+import pytest
+
+from google.auth.transport import _mtls_helper
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
+
+with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
+    PRIVATE_KEY_BYTES = fh.read()
+
+with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
+    PUBLIC_CERT_BYTES = fh.read()
+
+CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]}
+
+CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {}
+
+
+def check_cert_and_key(content, expected_cert, expected_key):
+    success = True
+
+    cert_match = re.findall(_mtls_helper._CERT_REGEX, content)
+    success = success and len(cert_match) == 1 and cert_match[0] == expected_cert
+
+    key_match = re.findall(_mtls_helper._KEY_REGEX, content)
+    success = success and len(key_match) == 1 and key_match[0] == expected_key
+
+    return success
+
+
+class TestCertAndKeyRegex(object):
+    def test_cert_and_key(self):
+        # Test single cert and single key
+        check_cert_and_key(
+            PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES
+        )
+        check_cert_and_key(
+            PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES
+        )
+
+        # Test cert chain and single key
+        check_cert_and_key(
+            PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES,
+            PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
+            PRIVATE_KEY_BYTES,
+        )
+        check_cert_and_key(
+            PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
+            PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
+            PRIVATE_KEY_BYTES,
+        )
+
+    def test_key(self):
+        # Create some fake keys for regex check.
+        KEY = b"""-----BEGIN PRIVATE KEY-----
+        MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg
+        /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB
+        -----END PRIVATE KEY-----"""
+        RSA_KEY = b"""-----BEGIN RSA PRIVATE KEY-----
+        MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg
+        /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB
+        -----END RSA PRIVATE KEY-----"""
+        EC_KEY = b"""-----BEGIN EC PRIVATE KEY-----
+        MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg
+        /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB
+        -----END EC PRIVATE KEY-----"""
+
+        check_cert_and_key(PUBLIC_CERT_BYTES + KEY, PUBLIC_CERT_BYTES, KEY)
+        check_cert_and_key(PUBLIC_CERT_BYTES + RSA_KEY, PUBLIC_CERT_BYTES, RSA_KEY)
+        check_cert_and_key(PUBLIC_CERT_BYTES + EC_KEY, PUBLIC_CERT_BYTES, EC_KEY)
+
+
+class TestCheckaMetadataPath(object):
+    def test_success(self):
+        metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json")
+        returned_path = _mtls_helper._check_dca_metadata_path(metadata_path)
+        assert returned_path is not None
+
+    def test_failure(self):
+        metadata_path = os.path.join(DATA_DIR, "not_exists.json")
+        returned_path = _mtls_helper._check_dca_metadata_path(metadata_path)
+        assert returned_path is None
+
+
+class TestReadMetadataFile(object):
+    def test_success(self):
+        metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json")
+        metadata = _mtls_helper._read_dca_metadata_file(metadata_path)
+
+        assert "cert_provider_command" in metadata
+
+    def test_file_not_json(self):
+        # read a file which is not json format.
+        metadata_path = os.path.join(DATA_DIR, "privatekey.pem")
+        with pytest.raises(ValueError):
+            _mtls_helper._read_dca_metadata_file(metadata_path)
+
+
+class TestGetClientSslCredentials(object):
+    def create_mock_process(self, output, error):
+        # There are two steps to execute a script with subprocess.Popen.
+        # (1) process = subprocess.Popen([comannds])
+        # (2) stdout, stderr = process.communicate()
+        # This function creates a mock process which can be returned by a mock
+        # subprocess.Popen. The mock process returns the given output and error
+        # when mock_process.communicate() is called.
+        mock_process = mock.Mock()
+        attrs = {"communicate.return_value": (output, error), "returncode": 0}
+        mock_process.configure_mock(**attrs)
+        return mock_process
+
+    @mock.patch("subprocess.Popen", autospec=True)
+    def test_success(self, mock_popen):
+        mock_popen.return_value = self.create_mock_process(
+            PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, b""
+        )
+        cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+        assert cert == PUBLIC_CERT_BYTES
+        assert key == PRIVATE_KEY_BYTES
+
+    @mock.patch("subprocess.Popen", autospec=True)
+    def test_success_with_cert_chain(self, mock_popen):
+        PUBLIC_CERT_CHAIN_BYTES = PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES
+        mock_popen.return_value = self.create_mock_process(
+            PUBLIC_CERT_CHAIN_BYTES + PRIVATE_KEY_BYTES, b""
+        )
+        cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+        assert cert == PUBLIC_CERT_CHAIN_BYTES
+        assert key == PRIVATE_KEY_BYTES
+
+    def test_missing_cert_provider_command(self):
+        with pytest.raises(ValueError):
+            assert _mtls_helper.get_client_ssl_credentials(
+                CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND
+            )
+
+    @mock.patch("subprocess.Popen", autospec=True)
+    def test_missing_cert(self, mock_popen):
+        mock_popen.return_value = self.create_mock_process(PRIVATE_KEY_BYTES, b"")
+        with pytest.raises(ValueError):
+            assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+
+    @mock.patch("subprocess.Popen", autospec=True)
+    def test_missing_key(self, mock_popen):
+        mock_popen.return_value = self.create_mock_process(PUBLIC_CERT_BYTES, b"")
+        with pytest.raises(ValueError):
+            assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+
+    @mock.patch("subprocess.Popen", autospec=True)
+    def test_cert_provider_returns_error(self, mock_popen):
+        mock_popen.return_value = self.create_mock_process(b"", b"some error")
+        mock_popen.return_value.returncode = 1
+        with pytest.raises(RuntimeError):
+            assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+
+    @mock.patch("subprocess.Popen", autospec=True)
+    def test_popen_raise_exception(self, mock_popen):
+        mock_popen.side_effect = OSError()
+        with pytest.raises(OSError):
+            assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py
index 857c32b..23e62a2 100644
--- a/tests/transport/test_grpc.py
+++ b/tests/transport/test_grpc.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import datetime
+import os
 import time
 
 import mock
@@ -31,6 +32,12 @@
 except ImportError:  # pragma: NO COVER
     HAS_GRPC = False
 
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
+METADATA_PATH = os.path.join(DATA_DIR, "context_aware_metadata.json")
+with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
+    PRIVATE_KEY_BYTES = fh.read()
+with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
+    PUBLIC_CERT_BYTES = fh.read()
 
 pytestmark = pytest.mark.skipif(not HAS_GRPC, reason="gRPC is unavailable.")
 
@@ -87,70 +94,251 @@
         )
 
 
+@mock.patch(
+    "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+)
 @mock.patch("grpc.composite_channel_credentials", autospec=True)
 @mock.patch("grpc.metadata_call_credentials", autospec=True)
 @mock.patch("grpc.ssl_channel_credentials", autospec=True)
 @mock.patch("grpc.secure_channel", autospec=True)
-def test_secure_authorized_channel(
-    secure_channel,
-    ssl_channel_credentials,
-    metadata_call_credentials,
-    composite_channel_credentials,
-):
-    credentials = CredentialsStub()
-    request = mock.create_autospec(transport.Request)
-    target = "example.com:80"
-
-    channel = google.auth.transport.grpc.secure_authorized_channel(
-        credentials, request, target, options=mock.sentinel.options
+class TestSecureAuthorizedChannel(object):
+    @mock.patch(
+        "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
     )
-
-    # Check the auth plugin construction.
-    auth_plugin = metadata_call_credentials.call_args[0][0]
-    assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
-    assert auth_plugin._credentials == credentials
-    assert auth_plugin._request == request
-
-    # Check the ssl channel call.
-    assert ssl_channel_credentials.called
-
-    # Check the composite credentials call.
-    composite_channel_credentials.assert_called_once_with(
-        ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+    @mock.patch(
+        "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
     )
+    def test_secure_authorized_channel_adc(
+        self,
+        check_dca_metadata_path,
+        read_dca_metadata_file,
+        secure_channel,
+        ssl_channel_credentials,
+        metadata_call_credentials,
+        composite_channel_credentials,
+        get_client_ssl_credentials,
+    ):
+        credentials = CredentialsStub()
+        request = mock.create_autospec(transport.Request)
+        target = "example.com:80"
 
-    # Check the channel call.
-    secure_channel.assert_called_once_with(
-        target,
-        composite_channel_credentials.return_value,
-        options=mock.sentinel.options,
+        # Mock the context aware metadata and client cert/key so mTLS SSL channel
+        # will be used.
+        check_dca_metadata_path.return_value = METADATA_PATH
+        read_dca_metadata_file.return_value = {
+            "cert_provider_command": ["some command"]
+        }
+        get_client_ssl_credentials.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)
+
+        channel = google.auth.transport.grpc.secure_authorized_channel(
+            credentials, request, target, options=mock.sentinel.options
+        )
+
+        # Check the auth plugin construction.
+        auth_plugin = metadata_call_credentials.call_args[0][0]
+        assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
+        assert auth_plugin._credentials == credentials
+        assert auth_plugin._request == request
+
+        # Check the ssl channel call.
+        ssl_channel_credentials.assert_called_once_with(
+            certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+        )
+
+        # Check the composite credentials call.
+        composite_channel_credentials.assert_called_once_with(
+            ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+        )
+
+        # Check the channel call.
+        secure_channel.assert_called_once_with(
+            target,
+            composite_channel_credentials.return_value,
+            options=mock.sentinel.options,
+        )
+        assert channel == secure_channel.return_value
+
+    def test_secure_authorized_channel_explicit_ssl(
+        self,
+        secure_channel,
+        ssl_channel_credentials,
+        metadata_call_credentials,
+        composite_channel_credentials,
+        get_client_ssl_credentials,
+    ):
+        credentials = mock.Mock()
+        request = mock.Mock()
+        target = "example.com:80"
+        ssl_credentials = mock.Mock()
+
+        google.auth.transport.grpc.secure_authorized_channel(
+            credentials, request, target, ssl_credentials=ssl_credentials
+        )
+
+        # Since explicit SSL credentials are provided, get_client_ssl_credentials
+        # shouldn't be called.
+        assert not get_client_ssl_credentials.called
+
+        # Check the ssl channel call.
+        assert not ssl_channel_credentials.called
+
+        # Check the composite credentials call.
+        composite_channel_credentials.assert_called_once_with(
+            ssl_credentials, metadata_call_credentials.return_value
+        )
+
+    def test_secure_authorized_channel_mutual_exclusive(
+        self,
+        secure_channel,
+        ssl_channel_credentials,
+        metadata_call_credentials,
+        composite_channel_credentials,
+        get_client_ssl_credentials,
+    ):
+        credentials = mock.Mock()
+        request = mock.Mock()
+        target = "example.com:80"
+        ssl_credentials = mock.Mock()
+        client_cert_callback = mock.Mock()
+
+        with pytest.raises(ValueError):
+            google.auth.transport.grpc.secure_authorized_channel(
+                credentials,
+                request,
+                target,
+                ssl_credentials=ssl_credentials,
+                client_cert_callback=client_cert_callback,
+            )
+
+    def test_secure_authorized_channel_with_client_cert_callback_success(
+        self,
+        secure_channel,
+        ssl_channel_credentials,
+        metadata_call_credentials,
+        composite_channel_credentials,
+        get_client_ssl_credentials,
+    ):
+        credentials = mock.Mock()
+        request = mock.Mock()
+        target = "example.com:80"
+        client_cert_callback = mock.Mock()
+        client_cert_callback.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)
+
+        google.auth.transport.grpc.secure_authorized_channel(
+            credentials, request, target, client_cert_callback=client_cert_callback
+        )
+
+        client_cert_callback.assert_called_once()
+
+        # Check we are using the cert and key provided by client_cert_callback.
+        ssl_channel_credentials.assert_called_once_with(
+            certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+        )
+
+        # Check the composite credentials call.
+        composite_channel_credentials.assert_called_once_with(
+            ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+        )
+
+    @mock.patch(
+        "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
     )
-    assert channel == secure_channel.return_value
+    @mock.patch(
+        "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+    )
+    def test_secure_authorized_channel_with_client_cert_callback_failure(
+        self,
+        check_dca_metadata_path,
+        read_dca_metadata_file,
+        secure_channel,
+        ssl_channel_credentials,
+        metadata_call_credentials,
+        composite_channel_credentials,
+        get_client_ssl_credentials,
+    ):
+        credentials = mock.Mock()
+        request = mock.Mock()
+        target = "example.com:80"
+
+        client_cert_callback = mock.Mock()
+        client_cert_callback.side_effect = Exception("callback exception")
+
+        with pytest.raises(Exception) as excinfo:
+            google.auth.transport.grpc.secure_authorized_channel(
+                credentials, request, target, client_cert_callback=client_cert_callback
+            )
+
+        assert str(excinfo.value) == "callback exception"
 
 
-@mock.patch("grpc.composite_channel_credentials", autospec=True)
-@mock.patch("grpc.metadata_call_credentials", autospec=True)
 @mock.patch("grpc.ssl_channel_credentials", autospec=True)
-@mock.patch("grpc.secure_channel", autospec=True)
-def test_secure_authorized_channel_explicit_ssl(
-    secure_channel,
-    ssl_channel_credentials,
-    metadata_call_credentials,
-    composite_channel_credentials,
-):
-    credentials = mock.Mock()
-    request = mock.Mock()
-    target = "example.com:80"
-    ssl_credentials = mock.Mock()
+@mock.patch(
+    "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+)
+@mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True)
+@mock.patch(
+    "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+)
+class TestSslCredentials(object):
+    def test_no_context_aware_metadata(
+        self,
+        mock_check_dca_metadata_path,
+        mock_read_dca_metadata_file,
+        mock_get_client_ssl_credentials,
+        mock_ssl_channel_credentials,
+    ):
+        # Mock that the metadata file doesn't exist.
+        mock_check_dca_metadata_path.return_value = None
 
-    google.auth.transport.grpc.secure_authorized_channel(
-        credentials, request, target, ssl_credentials=ssl_credentials
-    )
+        ssl_credentials = google.auth.transport.grpc.SslCredentials()
 
-    # Check the ssl channel call.
-    assert not ssl_channel_credentials.called
+        # Since no context aware metadata is found, we wouldn't call
+        # get_client_ssl_credentials, and the SSL channel credentials created is
+        # non mTLS.
+        assert ssl_credentials.ssl_credentials is not None
+        assert not ssl_credentials.is_mtls
+        mock_get_client_ssl_credentials.assert_not_called()
+        mock_ssl_channel_credentials.assert_called_once_with()
 
-    # Check the composite credentials call.
-    composite_channel_credentials.assert_called_once_with(
-        ssl_credentials, metadata_call_credentials.return_value
-    )
+    def test_get_client_ssl_credentials_failure(
+        self,
+        mock_check_dca_metadata_path,
+        mock_read_dca_metadata_file,
+        mock_get_client_ssl_credentials,
+        mock_ssl_channel_credentials,
+    ):
+        mock_check_dca_metadata_path.return_value = METADATA_PATH
+        mock_read_dca_metadata_file.return_value = {
+            "cert_provider_command": ["some command"]
+        }
+
+        # Mock that client cert and key are not loaded and exception is raised.
+        mock_get_client_ssl_credentials.side_effect = ValueError()
+
+        with pytest.raises(ValueError):
+            assert google.auth.transport.grpc.SslCredentials().ssl_credentials
+
+    def test_get_client_ssl_credentials_success(
+        self,
+        mock_check_dca_metadata_path,
+        mock_read_dca_metadata_file,
+        mock_get_client_ssl_credentials,
+        mock_ssl_channel_credentials,
+    ):
+        mock_check_dca_metadata_path.return_value = METADATA_PATH
+        mock_read_dca_metadata_file.return_value = {
+            "cert_provider_command": ["some command"]
+        }
+        mock_get_client_ssl_credentials.return_value = (
+            PUBLIC_CERT_BYTES,
+            PRIVATE_KEY_BYTES,
+        )
+
+        ssl_credentials = google.auth.transport.grpc.SslCredentials()
+
+        assert ssl_credentials.ssl_credentials is not None
+        assert ssl_credentials.is_mtls
+        mock_get_client_ssl_credentials.assert_called_once()
+        mock_ssl_channel_credentials.assert_called_once_with(
+            certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+        )