Fix gRPC to call credentials.before_request (#116)
diff --git a/google/auth/credentials.py b/google/auth/credentials.py
index 2358b1d..8570957 100644
--- a/google/auth/credentials.py
+++ b/google/auth/credentials.py
@@ -104,8 +104,9 @@
Args:
request (google.auth.transport.Request): The object used to make
HTTP requests.
- method (str): The request's HTTP method.
- url (str): The request's URI.
+ method (str): The request's HTTP method or the RPC method being
+ invoked.
+ url (str): The request's URI or the RPC service's URI.
headers (Mapping): The request's headers.
"""
# pylint: disable=unused-argument
diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py
index e6a5eb7..81d5658 100644
--- a/google/auth/transport/grpc.py
+++ b/google/auth/transport/grpc.py
@@ -17,6 +17,7 @@
from __future__ import absolute_import
import grpc
+import six
class AuthMetadataPlugin(grpc.AuthMetadataPlugin):
@@ -40,19 +41,21 @@
self._credentials = credentials
self._request = request
- def _get_authorization_headers(self):
+ def _get_authorization_headers(self, context):
"""Gets the authorization headers for a request.
Returns:
Sequence[Tuple[str, str]]: A list of request headers (key, value)
to add to the request.
"""
- if self._credentials.expired or not self._credentials.valid:
- self._credentials.refresh(self._request)
+ headers = {}
+ self._credentials.before_request(
+ self._request,
+ context.method_name,
+ context.service_url,
+ headers)
- return [
- ('authorization', 'Bearer {}'.format(self._credentials.token))
- ]
+ return list(six.iteritems(headers))
def __call__(self, context, callback):
"""Passes authorization metadata into the given callback.
@@ -62,7 +65,7 @@
callback (grpc.AuthMetadataPluginCallback): The callback that will
be invoked to pass in the authorization metadata.
"""
- callback(self._get_authorization_headers(), None)
+ callback(self._get_authorization_headers(context), None)
def secure_authorized_channel(
diff --git a/system_tests/test_grpc.py b/system_tests/test_grpc.py
index 7d436c5..73467fe 100644
--- a/system_tests/test_grpc.py
+++ b/system_tests/test_grpc.py
@@ -15,23 +15,40 @@
import google.auth
import google.auth.credentials
import google.auth.transport.grpc
-from google.cloud.gapic.pubsub.v1 import publisher_api
+from google.cloud.gapic.pubsub.v1 import publisher_client
-def test_grpc_request(http_request):
+def test_grpc_request_with_regular_credentials(http_request):
credentials, project_id = google.auth.default()
credentials = google.auth.credentials.with_scopes_if_required(
credentials, ['https://www.googleapis.com/auth/pubsub'])
- target = '{}:{}'.format(
- publisher_api.PublisherApi.SERVICE_ADDRESS,
- publisher_api.PublisherApi.DEFAULT_SERVICE_PORT)
-
channel = google.auth.transport.grpc.secure_authorized_channel(
- credentials, http_request, target)
+ credentials,
+ http_request,
+ publisher_client.PublisherClient.SERVICE_ADDRESS)
# Create a pub/sub client.
- client = publisher_api.PublisherApi(channel=channel)
+ client = publisher_client.PublisherClient(channel=channel)
+
+ # list the topics and drain the iterator to test that an authorized API
+ # call works.
+ list_topics_iter = client.list_topics(
+ project='projects/{}'.format(project_id))
+ list(list_topics_iter)
+
+
+def test_grpc_request_with_jwt_credentials(http_request):
+ credentials, project_id = google.auth.default()
+ credentials = credentials.to_jwt_credentials()
+
+ channel = google.auth.transport.grpc.secure_authorized_channel(
+ credentials,
+ http_request,
+ publisher_client.PublisherClient.SERVICE_ADDRESS)
+
+ # Create a pub/sub client.
+ client = publisher_client.PublisherClient(channel=channel)
# list the topics and drain the iterator to test that an authorized API
# call works.
diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py
index 15a301f..7a3cc0a 100644
--- a/tests/transport/test_grpc.py
+++ b/tests/transport/test_grpc.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import mock
+import datetime
+import mock
import pytest
+from google.auth import credentials
try:
import google.auth.transport.grpc
HAS_GRPC = True
@@ -26,11 +28,11 @@
pytestmark = pytest.mark.skipif(not HAS_GRPC, reason='gRPC is unavailable.')
-class MockCredentials(object):
+class MockCredentials(credentials.Credentials):
def __init__(self, token='token'):
+ super(MockCredentials, self).__init__()
self.token = token
- self.valid = True
- self.expired = False
+ self.expiry = None
def refresh(self, request):
self.token += '1'
@@ -54,7 +56,7 @@
def test_call_refresh(self):
credentials = MockCredentials()
- credentials.expired = True
+ credentials.expiry = datetime.datetime.min
request = mock.Mock()
plugin = google.auth.transport.grpc.AuthMetadataPlugin(
diff --git a/tox.ini b/tox.ini
index 0a98945..a2f7362 100644
--- a/tox.ini
+++ b/tox.ini
@@ -33,7 +33,7 @@
deps =
{[testenv]deps}
nox-automation
- gapic-google-pubsub-v1==0.11.1
+ gapic-google-cloud-pubsub-v1==0.15.0
passenv =
SKIP_APP_ENGINE_SYSTEM_TEST
CLOUD_SDK_ROOT
@@ -46,7 +46,7 @@
deps =
{[testenv]deps}
nox-automation
- gapic-google-pubsub-v1==0.11.1
+ gapic-google-cloud-pubsub-v1==0.15.0
passenv =
SKIP_APP_ENGINE_SYSTEM_TEST
CLOUD_SDK_ROOT