feat: add SslCredentials class for mTLS ADC (#448)
feat: add SslCredentials class for mTLS ADC (linux)
diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py
new file mode 100644
index 0000000..1ce9fa5
--- /dev/null
+++ b/google/auth/transport/_mtls_helper.py
@@ -0,0 +1,116 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper functions for getting mTLS cert and key, for internal use only."""
+
+import json
+import logging
+from os import path
+import re
+import subprocess
+
+CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
+_CERT_PROVIDER_COMMAND = "cert_provider_command"
+_CERT_REGEX = re.compile(
+ b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL
+)
+
+# support various format of key files, e.g.
+# "-----BEGIN PRIVATE KEY-----...",
+# "-----BEGIN EC PRIVATE KEY-----...",
+# "-----BEGIN RSA PRIVATE KEY-----..."
+_KEY_REGEX = re.compile(
+ b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?",
+ re.DOTALL,
+)
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _check_dca_metadata_path(metadata_path):
+ """Checks for context aware metadata. If it exists, returns the absolute path;
+ otherwise returns None.
+
+ Args:
+ metadata_path (str): context aware metadata path.
+
+ Returns:
+ str: absolute path if exists and None otherwise.
+ """
+ metadata_path = path.expanduser(metadata_path)
+ if not path.exists(metadata_path):
+ _LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path)
+ return None
+ return metadata_path
+
+
+def _read_dca_metadata_file(metadata_path):
+ """Loads context aware metadata from the given path.
+
+ Args:
+ metadata_path (str): context aware metadata path.
+
+ Returns:
+ Dict[str, str]: The metadata.
+
+ Raises:
+ ValueError: If failed to parse metadata as JSON.
+ """
+ with open(metadata_path) as f:
+ metadata = json.load(f)
+
+ return metadata
+
+
+def get_client_ssl_credentials(metadata_json):
+ """Returns the client side mTLS cert and key.
+
+ Args:
+ metadata_json (Dict[str, str]): metadata JSON file which contains the cert
+ provider command.
+
+ Returns:
+ Tuple[bytes, bytes]: client certificate and key, both in PEM format.
+
+ Raises:
+ OSError: If the cert provider command failed to run.
+ RuntimeError: If the cert provider command has a runtime error.
+ ValueError: If the metadata json file doesn't contain the cert provider command or if the command doesn't produce both the client certificate and client key.
+ """
+ # TODO: implement an in-memory cache of cert and key so we don't have to
+ # run cert provider command every time.
+
+ # Check the cert provider command existence in the metadata json file.
+ if _CERT_PROVIDER_COMMAND not in metadata_json:
+ raise ValueError("Cert provider command is not found")
+
+ # Execute the command. It throws OsError in case of system failure.
+ command = metadata_json[_CERT_PROVIDER_COMMAND]
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = process.communicate()
+
+ # Check cert provider command execution error.
+ if process.returncode != 0:
+ raise RuntimeError(
+ "Cert provider command returns non-zero status code %s" % process.returncode
+ )
+
+ # Extract certificate (chain) and key.
+ cert_match = re.findall(_CERT_REGEX, stdout)
+ if len(cert_match) != 1:
+ raise ValueError("Client SSL certificate is missing or invalid")
+ key_match = re.findall(_KEY_REGEX, stdout)
+ if len(key_match) != 1:
+ raise ValueError("Client SSL key is missing or invalid")
+ return cert_match[0], key_match[0]
diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py
index fb90fbb..ca38739 100644
--- a/google/auth/transport/grpc.py
+++ b/google/auth/transport/grpc.py
@@ -17,9 +17,12 @@
from __future__ import absolute_import
from concurrent import futures
+import logging
import six
+from google.auth.transport import _mtls_helper
+
try:
import grpc
except ImportError as caught_exc: # pragma: NO COVER
@@ -31,6 +34,8 @@
caught_exc,
)
+_LOGGER = logging.getLogger(__name__)
+
class AuthMetadataPlugin(grpc.AuthMetadataPlugin):
"""A `gRPC AuthMetadataPlugin`_ that inserts the credentials into each
@@ -92,7 +97,12 @@
def secure_authorized_channel(
- credentials, request, target, ssl_credentials=None, **kwargs
+ credentials,
+ request,
+ target,
+ ssl_credentials=None,
+ client_cert_callback=None,
+ **kwargs
):
"""Creates a secure authorized gRPC channel.
@@ -114,11 +124,86 @@
# Create a channel.
channel = google.auth.transport.grpc.secure_authorized_channel(
- credentials, 'speech.googleapis.com:443', request)
+ credentials, regular_endpoint, request,
+ ssl_credentials=grpc.ssl_channel_credentials())
# Use the channel to create a stub.
cloud_speech.create_Speech_stub(channel)
+ Usage:
+
+ There are actually a couple of options to create a channel, depending on if
+ you want to create a regular or mutual TLS channel.
+
+ First let's list the endpoints (regular vs mutual TLS) to choose from::
+
+ regular_endpoint = 'speech.googleapis.com:443'
+ mtls_endpoint = 'speech.mtls.googleapis.com:443'
+
+ Option 1: create a regular (non-mutual) TLS channel by explicitly setting
+ the ssl_credentials::
+
+ regular_ssl_credentials = grpc.ssl_channel_credentials()
+
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, regular_endpoint, request,
+ ssl_credentials=regular_ssl_credentials)
+
+ Option 2: create a mutual TLS channel by calling a callback which returns
+ the client side certificate and the key::
+
+ def my_client_cert_callback():
+ code_to_load_client_cert_and_key()
+ if loaded:
+ return (pem_cert_bytes, pem_key_bytes)
+ raise MyClientCertFailureException()
+
+ try:
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, mtls_endpoint, request,
+ client_cert_callback=my_client_cert_callback)
+ except MyClientCertFailureException:
+ # handle the exception
+
+ Option 3: use application default SSL credentials. It searches and uses
+ the command in a context aware metadata file, which is available on devices
+ with endpoint verification support.
+ See https://cloud.google.com/endpoint-verification/docs/overview::
+
+ try:
+ default_ssl_credentials = SslCredentials()
+ except:
+ # Exception can be raised if the context aware metadata is malformed.
+ # See :class:`SslCredentials` for the possible exceptions.
+
+ # Choose the endpoint based on the SSL credentials type.
+ if default_ssl_credentials.is_mtls:
+ endpoint_to_use = mtls_endpoint
+ else:
+ endpoint_to_use = regular_endpoint
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, endpoint_to_use, request,
+ ssl_credentials=default_ssl_credentials)
+
+ Option 4: not setting ssl_credentials and client_cert_callback. For devices
+ without endpoint verification support, a regular TLS channel is created;
+ otherwise, a mutual TLS channel is created, however, the call should be
+ wrapped in a try/except block in case of malformed context aware metadata.
+
+ The following code uses regular_endpoint, it works the same no matter the
+ created channle is regular or mutual TLS. Regular endpoint ignores client
+ certificate and key::
+
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, regular_endpoint, request)
+
+ The following code uses mtls_endpoint, if the created channle is regular,
+ and API mtls_endpoint is confgured to require client SSL credentials, API
+ calls using this channel will be rejected::
+
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, mtls_endpoint, request)
+
Args:
credentials (google.auth.credentials.Credentials): The credentials to
add to requests.
@@ -129,10 +214,33 @@
target (str): The host and port of the service.
ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
credentials. This can be used to specify different certificates.
+ This argument is mutually exclusive with client_cert_callback;
+ providing both will raise an exception.
+ If ssl_credentials and client_cert_callback are None, application
+ default SSL credentials will be used.
+ client_cert_callback (Callable[[], (bytes, bytes)]): Optional
+ callback function to obtain client certicate and key for mutual TLS
+ connection. This argument is mutually exclusive with
+ ssl_credentials; providing both will raise an exception.
+ If ssl_credentials and client_cert_callback are None, application
+ default SSL credentials will be used.
kwargs: Additional arguments to pass to :func:`grpc.secure_channel`.
Returns:
grpc.Channel: The created gRPC channel.
+
+ Raises:
+ OSError: If the cert provider command launch fails during the application
+ default SSL credentials loading process on devices with endpoint
+ verification support.
+ RuntimeError: If the cert provider command has a runtime error during the
+ application default SSL credentials loading process on devices with
+ endpoint verification support.
+ ValueError:
+ If the context aware metadata file is malformed or if the cert provider
+ command doesn't produce both client certificate and key during the
+ application default SSL credentials loading process on devices with
+ endpoint verification support.
"""
# Create the metadata plugin for inserting the authorization header.
metadata_plugin = AuthMetadataPlugin(credentials, request)
@@ -140,8 +248,24 @@
# Create a set of grpc.CallCredentials using the metadata plugin.
google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)
- if ssl_credentials is None:
- ssl_credentials = grpc.ssl_channel_credentials()
+ if ssl_credentials and client_cert_callback:
+ raise ValueError(
+ "Received both ssl_credentials and client_cert_callback; "
+ "these are mutually exclusive."
+ )
+
+ # If SSL credentials are not explicitly set, try client_cert_callback and ADC.
+ if not ssl_credentials:
+ if client_cert_callback:
+ # Use the callback if provided.
+ cert, key = client_cert_callback()
+ ssl_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ # Use application default SSL credentials.
+ adc_ssl_credentils = SslCredentials()
+ ssl_credentials = adc_ssl_credentils.ssl_credentials
# Combine the ssl credentials and the authorization credentials.
composite_credentials = grpc.composite_channel_credentials(
@@ -149,3 +273,59 @@
)
return grpc.secure_channel(target, composite_credentials, **kwargs)
+
+
+class SslCredentials:
+ """Class for application default SSL credentials.
+
+ For devices with endpoint verification support, a device certificate will be
+ automatically loaded and mutual TLS will be established.
+ See https://cloud.google.com/endpoint-verification/docs/overview.
+ """
+
+ def __init__(self):
+ # Load client SSL credentials.
+ self._context_aware_metadata_path = _mtls_helper._check_dca_metadata_path(
+ _mtls_helper.CONTEXT_AWARE_METADATA_PATH
+ )
+ if self._context_aware_metadata_path:
+ self._is_mtls = True
+ else:
+ self._is_mtls = False
+
+ @property
+ def ssl_credentials(self):
+ """Get the created SSL channel credentials.
+
+ For devices with endpoint verification support, if the device certificate
+ loading has any problems, corresponding exceptions will be raised. For
+ a device without endpoint verification support, no exceptions will be
+ raised.
+
+ Returns:
+ grpc.ChannelCredentials: The created grpc channel credentials.
+
+ Raises:
+ OSError: If the cert provider command launch fails.
+ RuntimeError: If the cert provider command has a runtime error.
+ ValueError:
+ If the context aware metadata file is malformed or if the cert provider
+ command doesn't produce both the client certificate and key.
+ """
+ if self._context_aware_metadata_path:
+ metadata = _mtls_helper._read_dca_metadata_file(
+ self._context_aware_metadata_path
+ )
+ cert, key = _mtls_helper.get_client_ssl_credentials(metadata)
+ self._ssl_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_credentials = grpc.ssl_channel_credentials()
+
+ return self._ssl_credentials
+
+ @property
+ def is_mtls(self):
+ """Indicates if the created SSL channel credentials is mutual TLS."""
+ return self._is_mtls
diff --git a/tests/data/context_aware_metadata.json b/tests/data/context_aware_metadata.json
new file mode 100644
index 0000000..ec40e78
--- /dev/null
+++ b/tests/data/context_aware_metadata.json
@@ -0,0 +1,6 @@
+{
+ "cert_provider_command":[
+ "/opt/google/endpoint-verification/bin/SecureConnectHelper",
+ "--print_certificate"],
+ "device_resource_ids":["11111111-1111-1111"]
+}
diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py
new file mode 100644
index 0000000..6e7175f
--- /dev/null
+++ b/tests/transport/test__mtls_helper.py
@@ -0,0 +1,177 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+
+import mock
+import pytest
+
+from google.auth.transport import _mtls_helper
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
+
+with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
+ PRIVATE_KEY_BYTES = fh.read()
+
+with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
+ PUBLIC_CERT_BYTES = fh.read()
+
+CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]}
+
+CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {}
+
+
+def check_cert_and_key(content, expected_cert, expected_key):
+ success = True
+
+ cert_match = re.findall(_mtls_helper._CERT_REGEX, content)
+ success = success and len(cert_match) == 1 and cert_match[0] == expected_cert
+
+ key_match = re.findall(_mtls_helper._KEY_REGEX, content)
+ success = success and len(key_match) == 1 and key_match[0] == expected_key
+
+ return success
+
+
+class TestCertAndKeyRegex(object):
+ def test_cert_and_key(self):
+ # Test single cert and single key
+ check_cert_and_key(
+ PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES
+ )
+ check_cert_and_key(
+ PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES
+ )
+
+ # Test cert chain and single key
+ check_cert_and_key(
+ PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES,
+ PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
+ PRIVATE_KEY_BYTES,
+ )
+ check_cert_and_key(
+ PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
+ PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
+ PRIVATE_KEY_BYTES,
+ )
+
+ def test_key(self):
+ # Create some fake keys for regex check.
+ KEY = b"""-----BEGIN PRIVATE KEY-----
+ MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg
+ /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB
+ -----END PRIVATE KEY-----"""
+ RSA_KEY = b"""-----BEGIN RSA PRIVATE KEY-----
+ MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg
+ /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB
+ -----END RSA PRIVATE KEY-----"""
+ EC_KEY = b"""-----BEGIN EC PRIVATE KEY-----
+ MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg
+ /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB
+ -----END EC PRIVATE KEY-----"""
+
+ check_cert_and_key(PUBLIC_CERT_BYTES + KEY, PUBLIC_CERT_BYTES, KEY)
+ check_cert_and_key(PUBLIC_CERT_BYTES + RSA_KEY, PUBLIC_CERT_BYTES, RSA_KEY)
+ check_cert_and_key(PUBLIC_CERT_BYTES + EC_KEY, PUBLIC_CERT_BYTES, EC_KEY)
+
+
+class TestCheckaMetadataPath(object):
+ def test_success(self):
+ metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json")
+ returned_path = _mtls_helper._check_dca_metadata_path(metadata_path)
+ assert returned_path is not None
+
+ def test_failure(self):
+ metadata_path = os.path.join(DATA_DIR, "not_exists.json")
+ returned_path = _mtls_helper._check_dca_metadata_path(metadata_path)
+ assert returned_path is None
+
+
+class TestReadMetadataFile(object):
+ def test_success(self):
+ metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json")
+ metadata = _mtls_helper._read_dca_metadata_file(metadata_path)
+
+ assert "cert_provider_command" in metadata
+
+ def test_file_not_json(self):
+ # read a file which is not json format.
+ metadata_path = os.path.join(DATA_DIR, "privatekey.pem")
+ with pytest.raises(ValueError):
+ _mtls_helper._read_dca_metadata_file(metadata_path)
+
+
+class TestGetClientSslCredentials(object):
+ def create_mock_process(self, output, error):
+ # There are two steps to execute a script with subprocess.Popen.
+ # (1) process = subprocess.Popen([comannds])
+ # (2) stdout, stderr = process.communicate()
+ # This function creates a mock process which can be returned by a mock
+ # subprocess.Popen. The mock process returns the given output and error
+ # when mock_process.communicate() is called.
+ mock_process = mock.Mock()
+ attrs = {"communicate.return_value": (output, error), "returncode": 0}
+ mock_process.configure_mock(**attrs)
+ return mock_process
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_success(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(
+ PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, b""
+ )
+ cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ assert cert == PUBLIC_CERT_BYTES
+ assert key == PRIVATE_KEY_BYTES
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_success_with_cert_chain(self, mock_popen):
+ PUBLIC_CERT_CHAIN_BYTES = PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES
+ mock_popen.return_value = self.create_mock_process(
+ PUBLIC_CERT_CHAIN_BYTES + PRIVATE_KEY_BYTES, b""
+ )
+ cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+ assert cert == PUBLIC_CERT_CHAIN_BYTES
+ assert key == PRIVATE_KEY_BYTES
+
+ def test_missing_cert_provider_command(self):
+ with pytest.raises(ValueError):
+ assert _mtls_helper.get_client_ssl_credentials(
+ CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND
+ )
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_missing_cert(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(PRIVATE_KEY_BYTES, b"")
+ with pytest.raises(ValueError):
+ assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_missing_key(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(PUBLIC_CERT_BYTES, b"")
+ with pytest.raises(ValueError):
+ assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_cert_provider_returns_error(self, mock_popen):
+ mock_popen.return_value = self.create_mock_process(b"", b"some error")
+ mock_popen.return_value.returncode = 1
+ with pytest.raises(RuntimeError):
+ assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+
+ @mock.patch("subprocess.Popen", autospec=True)
+ def test_popen_raise_exception(self, mock_popen):
+ mock_popen.side_effect = OSError()
+ with pytest.raises(OSError):
+ assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py
index 857c32b..23e62a2 100644
--- a/tests/transport/test_grpc.py
+++ b/tests/transport/test_grpc.py
@@ -13,6 +13,7 @@
# limitations under the License.
import datetime
+import os
import time
import mock
@@ -31,6 +32,12 @@
except ImportError: # pragma: NO COVER
HAS_GRPC = False
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
+METADATA_PATH = os.path.join(DATA_DIR, "context_aware_metadata.json")
+with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
+ PRIVATE_KEY_BYTES = fh.read()
+with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
+ PUBLIC_CERT_BYTES = fh.read()
pytestmark = pytest.mark.skipif(not HAS_GRPC, reason="gRPC is unavailable.")
@@ -87,70 +94,251 @@
)
+@mock.patch(
+ "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+)
@mock.patch("grpc.composite_channel_credentials", autospec=True)
@mock.patch("grpc.metadata_call_credentials", autospec=True)
@mock.patch("grpc.ssl_channel_credentials", autospec=True)
@mock.patch("grpc.secure_channel", autospec=True)
-def test_secure_authorized_channel(
- secure_channel,
- ssl_channel_credentials,
- metadata_call_credentials,
- composite_channel_credentials,
-):
- credentials = CredentialsStub()
- request = mock.create_autospec(transport.Request)
- target = "example.com:80"
-
- channel = google.auth.transport.grpc.secure_authorized_channel(
- credentials, request, target, options=mock.sentinel.options
+class TestSecureAuthorizedChannel(object):
+ @mock.patch(
+ "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
)
-
- # Check the auth plugin construction.
- auth_plugin = metadata_call_credentials.call_args[0][0]
- assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
- assert auth_plugin._credentials == credentials
- assert auth_plugin._request == request
-
- # Check the ssl channel call.
- assert ssl_channel_credentials.called
-
- # Check the composite credentials call.
- composite_channel_credentials.assert_called_once_with(
- ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
)
+ def test_secure_authorized_channel_adc(
+ self,
+ check_dca_metadata_path,
+ read_dca_metadata_file,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = CredentialsStub()
+ request = mock.create_autospec(transport.Request)
+ target = "example.com:80"
- # Check the channel call.
- secure_channel.assert_called_once_with(
- target,
- composite_channel_credentials.return_value,
- options=mock.sentinel.options,
+ # Mock the context aware metadata and client cert/key so mTLS SSL channel
+ # will be used.
+ check_dca_metadata_path.return_value = METADATA_PATH
+ read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["some command"]
+ }
+ get_client_ssl_credentials.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)
+
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, options=mock.sentinel.options
+ )
+
+ # Check the auth plugin construction.
+ auth_plugin = metadata_call_credentials.call_args[0][0]
+ assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin)
+ assert auth_plugin._credentials == credentials
+ assert auth_plugin._request == request
+
+ # Check the ssl channel call.
+ ssl_channel_credentials.assert_called_once_with(
+ certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+ )
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+ )
+
+ # Check the channel call.
+ secure_channel.assert_called_once_with(
+ target,
+ composite_channel_credentials.return_value,
+ options=mock.sentinel.options,
+ )
+ assert channel == secure_channel.return_value
+
+ def test_secure_authorized_channel_explicit_ssl(
+ self,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+ ssl_credentials = mock.Mock()
+
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, ssl_credentials=ssl_credentials
+ )
+
+ # Since explicit SSL credentials are provided, get_client_ssl_credentials
+ # shouldn't be called.
+ assert not get_client_ssl_credentials.called
+
+ # Check the ssl channel call.
+ assert not ssl_channel_credentials.called
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_credentials, metadata_call_credentials.return_value
+ )
+
+ def test_secure_authorized_channel_mutual_exclusive(
+ self,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+ ssl_credentials = mock.Mock()
+ client_cert_callback = mock.Mock()
+
+ with pytest.raises(ValueError):
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials,
+ request,
+ target,
+ ssl_credentials=ssl_credentials,
+ client_cert_callback=client_cert_callback,
+ )
+
+ def test_secure_authorized_channel_with_client_cert_callback_success(
+ self,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+ client_cert_callback = mock.Mock()
+ client_cert_callback.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES)
+
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, client_cert_callback=client_cert_callback
+ )
+
+ client_cert_callback.assert_called_once()
+
+ # Check we are using the cert and key provided by client_cert_callback.
+ ssl_channel_credentials.assert_called_once_with(
+ certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+ )
+
+ # Check the composite credentials call.
+ composite_channel_credentials.assert_called_once_with(
+ ssl_channel_credentials.return_value, metadata_call_credentials.return_value
+ )
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True
)
- assert channel == secure_channel.return_value
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_secure_authorized_channel_with_client_cert_callback_failure(
+ self,
+ check_dca_metadata_path,
+ read_dca_metadata_file,
+ secure_channel,
+ ssl_channel_credentials,
+ metadata_call_credentials,
+ composite_channel_credentials,
+ get_client_ssl_credentials,
+ ):
+ credentials = mock.Mock()
+ request = mock.Mock()
+ target = "example.com:80"
+
+ client_cert_callback = mock.Mock()
+ client_cert_callback.side_effect = Exception("callback exception")
+
+ with pytest.raises(Exception) as excinfo:
+ google.auth.transport.grpc.secure_authorized_channel(
+ credentials, request, target, client_cert_callback=client_cert_callback
+ )
+
+ assert str(excinfo.value) == "callback exception"
-@mock.patch("grpc.composite_channel_credentials", autospec=True)
-@mock.patch("grpc.metadata_call_credentials", autospec=True)
@mock.patch("grpc.ssl_channel_credentials", autospec=True)
-@mock.patch("grpc.secure_channel", autospec=True)
-def test_secure_authorized_channel_explicit_ssl(
- secure_channel,
- ssl_channel_credentials,
- metadata_call_credentials,
- composite_channel_credentials,
-):
- credentials = mock.Mock()
- request = mock.Mock()
- target = "example.com:80"
- ssl_credentials = mock.Mock()
+@mock.patch(
+ "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+)
+@mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True)
+@mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+)
+class TestSslCredentials(object):
+ def test_no_context_aware_metadata(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_get_client_ssl_credentials,
+ mock_ssl_channel_credentials,
+ ):
+ # Mock that the metadata file doesn't exist.
+ mock_check_dca_metadata_path.return_value = None
- google.auth.transport.grpc.secure_authorized_channel(
- credentials, request, target, ssl_credentials=ssl_credentials
- )
+ ssl_credentials = google.auth.transport.grpc.SslCredentials()
- # Check the ssl channel call.
- assert not ssl_channel_credentials.called
+ # Since no context aware metadata is found, we wouldn't call
+ # get_client_ssl_credentials, and the SSL channel credentials created is
+ # non mTLS.
+ assert ssl_credentials.ssl_credentials is not None
+ assert not ssl_credentials.is_mtls
+ mock_get_client_ssl_credentials.assert_not_called()
+ mock_ssl_channel_credentials.assert_called_once_with()
- # Check the composite credentials call.
- composite_channel_credentials.assert_called_once_with(
- ssl_credentials, metadata_call_credentials.return_value
- )
+ def test_get_client_ssl_credentials_failure(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_get_client_ssl_credentials,
+ mock_ssl_channel_credentials,
+ ):
+ mock_check_dca_metadata_path.return_value = METADATA_PATH
+ mock_read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["some command"]
+ }
+
+ # Mock that client cert and key are not loaded and exception is raised.
+ mock_get_client_ssl_credentials.side_effect = ValueError()
+
+ with pytest.raises(ValueError):
+ assert google.auth.transport.grpc.SslCredentials().ssl_credentials
+
+ def test_get_client_ssl_credentials_success(
+ self,
+ mock_check_dca_metadata_path,
+ mock_read_dca_metadata_file,
+ mock_get_client_ssl_credentials,
+ mock_ssl_channel_credentials,
+ ):
+ mock_check_dca_metadata_path.return_value = METADATA_PATH
+ mock_read_dca_metadata_file.return_value = {
+ "cert_provider_command": ["some command"]
+ }
+ mock_get_client_ssl_credentials.return_value = (
+ PUBLIC_CERT_BYTES,
+ PRIVATE_KEY_BYTES,
+ )
+
+ ssl_credentials = google.auth.transport.grpc.SslCredentials()
+
+ assert ssl_credentials.ssl_credentials is not None
+ assert ssl_credentials.is_mtls
+ mock_get_client_ssl_credentials.assert_called_once()
+ mock_ssl_channel_credentials.assert_called_once_with(
+ certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES
+ )