feat: add mTLS ADC support for HTTP (#457)
feat: add mTLS ADC support for HTTP
diff --git a/tests/conftest.py b/tests/conftest.py
index 7f9a968..cf8a0f9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,12 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import sys
import mock
import pytest
+def pytest_configure():
+ """Load public certificate and private key."""
+ pytest.data_dir = os.path.join(os.path.dirname(__file__), "data")
+
+ with open(os.path.join(pytest.data_dir, "privatekey.pem"), "rb") as fh:
+ pytest.private_key_bytes = fh.read()
+
+ with open(os.path.join(pytest.data_dir, "public_cert.pem"), "rb") as fh:
+ pytest.public_cert_bytes = fh.read()
+
+
@pytest.fixture
def mock_non_existent_module(monkeypatch):
"""Mocks a non-existing module in sys.modules.
diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py
index 6e7175f..5bf1967 100644
--- a/tests/transport/test__mtls_helper.py
+++ b/tests/transport/test__mtls_helper.py
@@ -20,14 +20,6 @@
from google.auth.transport import _mtls_helper
-DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
-
-with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
- PRIVATE_KEY_BYTES = fh.read()
-
-with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
- PUBLIC_CERT_BYTES = fh.read()
-
CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]}
CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {}
@@ -49,22 +41,30 @@
def test_cert_and_key(self):
# Test single cert and single key
check_cert_and_key(
- PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES
+ pytest.public_cert_bytes + pytest.private_key_bytes,
+ pytest.public_cert_bytes,
+ pytest.private_key_bytes,
)
check_cert_and_key(
- PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES
+ pytest.private_key_bytes + pytest.public_cert_bytes,
+ pytest.public_cert_bytes,
+ pytest.private_key_bytes,
)
# Test cert chain and single key
check_cert_and_key(
- PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES,
- PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
- PRIVATE_KEY_BYTES,
+ pytest.public_cert_bytes
+ + pytest.public_cert_bytes
+ + pytest.private_key_bytes,
+ pytest.public_cert_bytes + pytest.public_cert_bytes,
+ pytest.private_key_bytes,
)
check_cert_and_key(
- PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
- PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES,
- PRIVATE_KEY_BYTES,
+ pytest.private_key_bytes
+ + pytest.public_cert_bytes
+ + pytest.public_cert_bytes,
+ pytest.public_cert_bytes + pytest.public_cert_bytes,
+ pytest.private_key_bytes,
)
def test_key(self):
@@ -82,33 +82,39 @@
/fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB
-----END EC PRIVATE KEY-----"""
- check_cert_and_key(PUBLIC_CERT_BYTES + KEY, PUBLIC_CERT_BYTES, KEY)
- check_cert_and_key(PUBLIC_CERT_BYTES + RSA_KEY, PUBLIC_CERT_BYTES, RSA_KEY)
- check_cert_and_key(PUBLIC_CERT_BYTES + EC_KEY, PUBLIC_CERT_BYTES, EC_KEY)
+ check_cert_and_key(
+ pytest.public_cert_bytes + KEY, pytest.public_cert_bytes, KEY
+ )
+ check_cert_and_key(
+ pytest.public_cert_bytes + RSA_KEY, pytest.public_cert_bytes, RSA_KEY
+ )
+ check_cert_and_key(
+ pytest.public_cert_bytes + EC_KEY, pytest.public_cert_bytes, EC_KEY
+ )
class TestCheckaMetadataPath(object):
def test_success(self):
- metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json")
+ metadata_path = os.path.join(pytest.data_dir, "context_aware_metadata.json")
returned_path = _mtls_helper._check_dca_metadata_path(metadata_path)
assert returned_path is not None
def test_failure(self):
- metadata_path = os.path.join(DATA_DIR, "not_exists.json")
+ metadata_path = os.path.join(pytest.data_dir, "not_exists.json")
returned_path = _mtls_helper._check_dca_metadata_path(metadata_path)
assert returned_path is None
class TestReadMetadataFile(object):
def test_success(self):
- metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json")
+ metadata_path = os.path.join(pytest.data_dir, "context_aware_metadata.json")
metadata = _mtls_helper._read_dca_metadata_file(metadata_path)
assert "cert_provider_command" in metadata
def test_file_not_json(self):
# read a file which is not json format.
- metadata_path = os.path.join(DATA_DIR, "privatekey.pem")
+ metadata_path = os.path.join(pytest.data_dir, "privatekey.pem")
with pytest.raises(ValueError):
_mtls_helper._read_dca_metadata_file(metadata_path)
@@ -129,21 +135,21 @@
@mock.patch("subprocess.Popen", autospec=True)
def test_success(self, mock_popen):
mock_popen.return_value = self.create_mock_process(
- PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, b""
+ pytest.public_cert_bytes + pytest.private_key_bytes, b""
)
cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
- assert cert == PUBLIC_CERT_BYTES
- assert key == PRIVATE_KEY_BYTES
+ assert cert == pytest.public_cert_bytes
+ assert key == pytest.private_key_bytes
@mock.patch("subprocess.Popen", autospec=True)
def test_success_with_cert_chain(self, mock_popen):
- PUBLIC_CERT_CHAIN_BYTES = PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES
+ PUBLIC_CERT_CHAIN_BYTES = pytest.public_cert_bytes + pytest.public_cert_bytes
mock_popen.return_value = self.create_mock_process(
- PUBLIC_CERT_CHAIN_BYTES + PRIVATE_KEY_BYTES, b""
+ PUBLIC_CERT_CHAIN_BYTES + pytest.private_key_bytes, b""
)
cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
assert cert == PUBLIC_CERT_CHAIN_BYTES
- assert key == PRIVATE_KEY_BYTES
+ assert key == pytest.private_key_bytes
def test_missing_cert_provider_command(self):
with pytest.raises(ValueError):
@@ -153,13 +159,17 @@
@mock.patch("subprocess.Popen", autospec=True)
def test_missing_cert(self, mock_popen):
- mock_popen.return_value = self.create_mock_process(PRIVATE_KEY_BYTES, b"")
+ mock_popen.return_value = self.create_mock_process(
+ pytest.private_key_bytes, b""
+ )
with pytest.raises(ValueError):
assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
@mock.patch("subprocess.Popen", autospec=True)
def test_missing_key(self, mock_popen):
- mock_popen.return_value = self.create_mock_process(PUBLIC_CERT_BYTES, b"")
+ mock_popen.return_value = self.create_mock_process(
+ pytest.public_cert_bytes, b""
+ )
with pytest.raises(ValueError):
assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
@@ -175,3 +185,45 @@
mock_popen.side_effect = OSError()
with pytest.raises(OSError):
assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA)
+
+
+class TestGetClientCertAndKey(object):
+ def test_callback_success(self):
+ callback = mock.Mock()
+ callback.return_value = (pytest.public_cert_bytes, pytest.private_key_bytes)
+
+ found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key(callback)
+ assert found_cert_key
+ assert cert == pytest.public_cert_bytes
+ assert key == pytest.private_key_bytes
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_no_metadata(self, mock_check_dca_metadata_path):
+ mock_check_dca_metadata_path.return_value = None
+
+ found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key()
+ assert not found_cert_key
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True
+ )
+ @mock.patch(
+ "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True
+ )
+ def test_use_metadata(
+ self, mock_check_dca_metadata_path, mock_get_client_ssl_credentials
+ ):
+ mock_check_dca_metadata_path.return_value = os.path.join(
+ pytest.data_dir, "context_aware_metadata.json"
+ )
+ mock_get_client_ssl_credentials.return_value = (
+ pytest.public_cert_bytes,
+ pytest.private_key_bytes,
+ )
+
+ found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key()
+ assert found_cert_key
+ assert cert == pytest.public_cert_bytes
+ assert key == pytest.private_key_bytes
diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py
index 9aafd88..3f3e14c 100644
--- a/tests/transport/test_requests.py
+++ b/tests/transport/test_requests.py
@@ -17,12 +17,14 @@
import freezegun
import mock
+import OpenSSL
import pytest
import requests
import requests.adapters
from six.moves import http_client
import google.auth.credentials
+import google.auth.transport._mtls_helper
import google.auth.transport.requests
from tests.transport import compliance
@@ -150,6 +152,34 @@
return super(TimeTickAdapterStub, self).send(request, **kwargs)
+class TestMutualTlsAdapter(object):
+ @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager")
+ @mock.patch.object(requests.adapters.HTTPAdapter, "proxy_manager_for")
+ def test_success(self, mock_proxy_manager_for, mock_init_poolmanager):
+ adapter = google.auth.transport.requests._MutualTlsAdapter(
+ pytest.public_cert_bytes, pytest.private_key_bytes
+ )
+
+ adapter.init_poolmanager()
+ mock_init_poolmanager.assert_called_with(ssl_context=adapter._ctx_poolmanager)
+
+ adapter.proxy_manager_for()
+ mock_proxy_manager_for.assert_called_with(ssl_context=adapter._ctx_proxymanager)
+
+ def test_invalid_cert_or_key(self):
+ with pytest.raises(OpenSSL.crypto.Error):
+ google.auth.transport.requests._MutualTlsAdapter(
+ b"invalid cert", b"invalid key"
+ )
+
+ @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None})
+ def test_import_error(self):
+ with pytest.raises(ImportError):
+ google.auth.transport.requests._MutualTlsAdapter(
+ pytest.public_cert_bytes, pytest.private_key_bytes
+ )
+
+
def make_response(status=http_client.OK, data=None):
response = requests.Response()
response.status_code = status
@@ -157,7 +187,7 @@
return response
-class TestAuthorizedHttp(object):
+class TestAuthorizedSession(object):
TEST_URL = "http://example.com/"
def test_constructor(self):
@@ -326,3 +356,61 @@
authed_session.request(
"GET", self.TEST_URL, timeout=60, max_allowed_time=2.9
)
+
+ def test_configure_mtls_channel_with_callback(self):
+ mock_callback = mock.Mock()
+ mock_callback.return_value = (
+ pytest.public_cert_bytes,
+ pytest.private_key_bytes,
+ )
+
+ auth_session = google.auth.transport.requests.AuthorizedSession(
+ credentials=mock.Mock()
+ )
+ auth_session.configure_mtls_channel(mock_callback)
+
+ assert auth_session.is_mtls
+ assert isinstance(
+ auth_session.adapters["https://"],
+ google.auth.transport.requests._MutualTlsAdapter,
+ )
+
+ @mock.patch(
+ "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
+ )
+ def test_configure_mtls_channel_with_metadata(self, mock_get_client_cert_and_key):
+ mock_get_client_cert_and_key.return_value = (
+ True,
+ pytest.public_cert_bytes,
+ pytest.private_key_bytes,
+ )
+
+ auth_session = google.auth.transport.requests.AuthorizedSession(
+ credentials=mock.Mock()
+ )
+ auth_session.configure_mtls_channel()
+
+ assert auth_session.is_mtls
+ assert isinstance(
+ auth_session.adapters["https://"],
+ google.auth.transport.requests._MutualTlsAdapter,
+ )
+
+ @mock.patch.object(google.auth.transport.requests._MutualTlsAdapter, "__init__")
+ @mock.patch(
+ "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
+ )
+ def test_configure_mtls_channel_non_mtls(
+ self, mock_get_client_cert_and_key, mock_adapter_ctor
+ ):
+ mock_get_client_cert_and_key.return_value = (False, None, None)
+
+ auth_session = google.auth.transport.requests.AuthorizedSession(
+ credentials=mock.Mock()
+ )
+ auth_session.configure_mtls_channel()
+
+ assert not auth_session.is_mtls
+
+ # Assert _MutualTlsAdapter constructor is not called.
+ mock_adapter_ctor.assert_not_called()
diff --git a/tests/transport/test_urllib3.py b/tests/transport/test_urllib3.py
index 8a30733..0452e91 100644
--- a/tests/transport/test_urllib3.py
+++ b/tests/transport/test_urllib3.py
@@ -13,10 +13,13 @@
# limitations under the License.
import mock
+import OpenSSL
+import pytest
from six.moves import http_client
import urllib3
import google.auth.credentials
+import google.auth.transport._mtls_helper
import google.auth.transport.urllib3
from tests.transport import compliance
@@ -77,6 +80,27 @@
self.data = data
+class TestMakeMutualTlsHttp(object):
+ def test_success(self):
+ http = google.auth.transport.urllib3._make_mutual_tls_http(
+ pytest.public_cert_bytes, pytest.private_key_bytes
+ )
+ assert isinstance(http, urllib3.PoolManager)
+
+ def test_crypto_error(self):
+ with pytest.raises(OpenSSL.crypto.Error):
+ google.auth.transport.urllib3._make_mutual_tls_http(
+ b"invalid cert", b"invalid key"
+ )
+
+ @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None})
+ def test_import_error(self):
+ with pytest.raises(ImportError):
+ google.auth.transport.urllib3._make_mutual_tls_http(
+ pytest.public_cert_bytes, pytest.private_key_bytes
+ )
+
+
class TestAuthorizedHttp(object):
TEST_URL = "http://example.com"
@@ -138,3 +162,62 @@
authed_http.headers = mock.sentinel.headers
assert authed_http.headers == http.headers
+
+ @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True)
+ def test_configure_mtls_channel_with_callback(self, mock_make_mutual_tls_http):
+ callback = mock.Mock()
+ callback.return_value = (pytest.public_cert_bytes, pytest.private_key_bytes)
+
+ authed_http = google.auth.transport.urllib3.AuthorizedHttp(
+ credentials=mock.Mock(), http=mock.Mock()
+ )
+
+ with pytest.warns(UserWarning):
+ is_mtls = authed_http.configure_mtls_channel(callback)
+
+ assert is_mtls
+ mock_make_mutual_tls_http.assert_called_once_with(
+ cert=pytest.public_cert_bytes, key=pytest.private_key_bytes
+ )
+
+ @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True)
+ @mock.patch(
+ "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
+ )
+ def test_configure_mtls_channel_with_metadata(
+ self, mock_get_client_cert_and_key, mock_make_mutual_tls_http
+ ):
+ authed_http = google.auth.transport.urllib3.AuthorizedHttp(
+ credentials=mock.Mock()
+ )
+
+ mock_get_client_cert_and_key.return_value = (
+ True,
+ pytest.public_cert_bytes,
+ pytest.private_key_bytes,
+ )
+ is_mtls = authed_http.configure_mtls_channel()
+
+ assert is_mtls
+ mock_get_client_cert_and_key.assert_called_once()
+ mock_make_mutual_tls_http.assert_called_once_with(
+ cert=pytest.public_cert_bytes, key=pytest.private_key_bytes
+ )
+
+ @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True)
+ @mock.patch(
+ "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True
+ )
+ def test_configure_mtls_channel_non_mtls(
+ self, mock_get_client_cert_and_key, mock_make_mutual_tls_http
+ ):
+ authed_http = google.auth.transport.urllib3.AuthorizedHttp(
+ credentials=mock.Mock()
+ )
+
+ mock_get_client_cert_and_key.return_value = (False, None, None)
+ is_mtls = authed_http.configure_mtls_channel()
+
+ assert not is_mtls
+ mock_get_client_cert_and_key.assert_called_once()
+ mock_make_mutual_tls_http.assert_not_called()