Merge pull request #6254 from grpc/python_per_rpc_interop

Added Google call credentials and the per_rpc_creds interop test
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
index c793c8f..19a59e0 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pxd.pxi
@@ -68,4 +68,4 @@
     void *state, grpc_auth_metadata_context context,
     grpc_credentials_plugin_metadata_cb cb, void *user_data) with gil
 
-cdef void plugin_destroy_c_plugin_state(void *state)
+cdef void plugin_destroy_c_plugin_state(void *state) with gil
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
index 94d13b5..1ba8645 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx.pxi
@@ -137,7 +137,7 @@
   cy_context.context = context
   self.plugin_callback(cy_context, python_callback)
 
-cdef void plugin_destroy_c_plugin_state(void *state):
+cdef void plugin_destroy_c_plugin_state(void *state) with gil:
   cpython.Py_DECREF(<CredentialsMetadataPlugin>state)
 
 def channel_credentials_google_default():
diff --git a/src/python/grpcio/grpc/beta/_auth.py b/src/python/grpcio/grpc/beta/_auth.py
new file mode 100644
index 0000000..553d4b9
--- /dev/null
+++ b/src/python/grpcio/grpc/beta/_auth.py
@@ -0,0 +1,73 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+#     * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+#     * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""GRPCAuthMetadataPlugins for standard authentication."""
+
+from concurrent import futures
+
+from grpc.beta import interfaces
+
+
+def _sign_request(callback, token, error):
+  metadata = (('authorization', 'Bearer {}'.format(token)),)
+  callback(metadata, error)
+
+
+class GoogleCallCredentials(interfaces.GRPCAuthMetadataPlugin):
+  """Metadata wrapper for GoogleCredentials from the oauth2client library."""
+
+  def __init__(self, credentials):
+    self._credentials = credentials
+    self._pool = futures.ThreadPoolExecutor(max_workers=1)
+
+  def __call__(self, context, callback):
+    # A GRPCAuthMetadataPlugin must not block (see grpc.beta.interfaces).
+    future = self._pool.submit(self._credentials.get_access_token)
+    future.add_done_callback(lambda x: self._get_token_callback(callback, x))
+
+  def _get_token_callback(self, callback, future):
+    try:
+      access_token = future.result().access_token
+    except Exception as e:
+      _sign_request(callback, None, e)
+    else:
+      _sign_request(callback, access_token, None)
+
+  def __del__(self):
+    self._pool.shutdown(wait=False)
+
+
+class AccessTokenCallCredentials(interfaces.GRPCAuthMetadataPlugin):
+  """Metadata wrapper for raw access token credentials."""
+
+  def __init__(self, access_token):
+    self._access_token = access_token
+
+  def __call__(self, context, callback):
+    _sign_request(callback, self._access_token, None)
diff --git a/src/python/grpcio/grpc/beta/implementations.py b/src/python/grpcio/grpc/beta/implementations.py
index 822f593..d8c32dd 100644
--- a/src/python/grpcio/grpc/beta/implementations.py
+++ b/src/python/grpcio/grpc/beta/implementations.py
@@ -38,6 +38,7 @@
 from grpc._adapter import _intermediary_low
 from grpc._adapter import _low
 from grpc._adapter import _types
+from grpc.beta import _auth
 from grpc.beta import _connectivity_channel
 from grpc.beta import _server
 from grpc.beta import _stub
@@ -105,10 +106,40 @@
     A CallCredentials object for use in a GRPCCallOptions object.
   """
   if name is None:
-    name = metadata_plugin.__name__
+    try:
+      name = metadata_plugin.__name__
+    except AttributeError:
+      name = metadata_plugin.__class__.__name__
   return CallCredentials(
       _low.call_credentials_metadata_plugin(metadata_plugin, name))
 
+
+def google_call_credentials(credentials):
+  """Construct CallCredentials from GoogleCredentials.
+
+  Args:
+    credentials: A GoogleCredentials object from the oauth2client library.
+
+  Returns:
+    A CallCredentials object for use in a GRPCCallOptions object.
+  """
+  return metadata_call_credentials(_auth.GoogleCallCredentials(credentials))
+
+
+def access_token_call_credentials(access_token):
+  """Construct CallCredentials from an access token.
+
+  Args:
+    access_token: A string to place directly in the http request
+      authorization header, i.e. "Authorization: Bearer <access_token>".
+
+  Returns:
+    A CallCredentials object for use in a GRPCCallOptions object.
+  """
+  return metadata_call_credentials(
+      _auth.AccessTokenCallCredentials(access_token))
+
+
 def composite_call_credentials(call_credentials, additional_call_credentials):
   """Compose two CallCredentials to make a new one.
 
diff --git a/src/python/grpcio/tests/interop/client.py b/src/python/grpcio/tests/interop/client.py
index db29eb4..e3d5545 100644
--- a/src/python/grpcio/tests/interop/client.py
+++ b/src/python/grpcio/tests/interop/client.py
@@ -65,39 +65,34 @@
       help='email address of the default service account', type=str)
   return parser.parse_args()
 
-def _oauth_access_token(args):
-  credentials = oauth2client_client.GoogleCredentials.get_application_default()
-  scoped_credentials = credentials.create_scoped([args.oauth_scope])
-  return scoped_credentials.get_access_token().access_token
 
 def _stub(args):
-  if args.oauth_scope:
-    if args.test_case == 'oauth2_auth_token':
-      # TODO(jtattermusch): This testcase sets the auth metadata key-value
-      # manually, which also means that the user would need to do the same
-      # thing every time he/she would like to use and out of band oauth token.
-      # The transformer function that produces the metadata key-value from
-      # the access token should be provided by gRPC auth library.
-      access_token = _oauth_access_token(args)
-      metadata_transformer = lambda x: [
-          ('authorization', 'Bearer %s' % access_token)]
-    else:
-      metadata_transformer = lambda x: [
-          ('authorization', 'Bearer %s' % _oauth_access_token(args))]
+  if args.test_case == 'oauth2_auth_token':
+    creds = oauth2client_client.GoogleCredentials.get_application_default()
+    scoped_creds = creds.create_scoped([args.oauth_scope])
+    access_token = scoped_creds.get_access_token().access_token
+    call_creds = implementations.access_token_call_credentials(access_token)
+  elif args.test_case == 'compute_engine_creds':
+    creds = oauth2client_client.GoogleCredentials.get_application_default()
+    scoped_creds = creds.create_scoped([args.oauth_scope])
+    call_creds = implementations.google_call_credentials(scoped_creds)
   else:
-    metadata_transformer = lambda x: []
+    call_creds = None
   if args.use_tls:
     if args.use_test_ca:
       root_certificates = resources.test_root_certificates()
     else:
       root_certificates = None  # will load default roots.
 
+    channel_creds = implementations.ssl_channel_credentials(root_certificates)
+    if call_creds is not None:
+      channel_creds = implementations.composite_channel_credentials(
+          channel_creds, call_creds)
+
     channel = test_utilities.not_really_secure_channel(
-        args.server_host, args.server_port,
-        implementations.ssl_channel_credentials(root_certificates),
+        args.server_host, args.server_port, channel_creds,
         args.server_host_override)
-    stub = test_pb2.beta_create_TestService_stub(
-        channel, metadata_transformer=metadata_transformer)
+    stub = test_pb2.beta_create_TestService_stub(channel)
   else:
     channel = implementations.insecure_channel(
         args.server_host, args.server_port)
diff --git a/src/python/grpcio/tests/interop/methods.py b/src/python/grpcio/tests/interop/methods.py
index 67862ed..d5ef0c6 100644
--- a/src/python/grpcio/tests/interop/methods.py
+++ b/src/python/grpcio/tests/interop/methods.py
@@ -39,6 +39,8 @@
 
 from oauth2client import client as oauth2client_client
 
+from grpc.beta import implementations
+from grpc.beta import interfaces
 from grpc.framework.common import cardinality
 from grpc.framework.interfaces.face import face
 
@@ -88,13 +90,15 @@
     return self.FullDuplexCall(request_iterator, context)
 
 
-def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope):
+def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
+                                 protocol_options=None):
   with stub:
     request = messages_pb2.SimpleRequest(
         response_type=messages_pb2.COMPRESSABLE, response_size=314159,
         payload=messages_pb2.Payload(body=b'\x00' * 271828),
         fill_username=fill_username, fill_oauth_scope=fill_oauth_scope)
-    response_future = stub.UnaryCall.future(request, _TIMEOUT)
+    response_future = stub.UnaryCall.future(request, _TIMEOUT,
+                                            protocol_options=protocol_options)
     response = response_future.result()
     if response.payload.type is not messages_pb2.COMPRESSABLE:
       raise ValueError(
@@ -303,7 +307,24 @@
   if args.oauth_scope.find(response.oauth_scope) == -1:
     raise ValueError(
         'expected to find oauth scope "%s" in received "%s"' %
-            (response.oauth_scope, args.oauth_scope))
+        (response.oauth_scope, args.oauth_scope))
+
+
+def _per_rpc_creds(stub, args):
+  json_key_filename = os.environ[
+      oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+  wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+  credentials = oauth2client_client.GoogleCredentials.get_application_default()
+  scoped_credentials = credentials.create_scoped([args.oauth_scope])
+  call_creds = implementations.google_call_credentials(scoped_credentials)
+  options = interfaces.grpc_call_options(disable_compression=False,
+                                         credentials=call_creds)
+  response = _large_unary_common_behavior(stub, True, False,
+                                          protocol_options=options)
+  if wanted_email != response.username:
+    raise ValueError(
+        'expected username %s, got %s' % (wanted_email, response.username))
+
 
 @enum.unique
 class TestCase(enum.Enum):
@@ -317,6 +338,7 @@
   EMPTY_STREAM = 'empty_stream'
   COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
   OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
+  PER_RPC_CREDS = 'per_rpc_creds'
   TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
 
   def test_interoperability(self, stub, args):
@@ -342,5 +364,7 @@
       _compute_engine_creds(stub, args)
     elif self is TestCase.OAUTH2_AUTH_TOKEN:
       _oauth2_auth_token(stub, args)
+    elif self is TestCase.PER_RPC_CREDS:
+      _per_rpc_creds(stub, args)
     else:
       raise NotImplementedError('Test case "%s" not implemented!' % self.name)
diff --git a/src/python/grpcio/tests/tests.json b/src/python/grpcio/tests/tests.json
index fb357ea..81458b1 100644
--- a/src/python/grpcio/tests/tests.json
+++ b/src/python/grpcio/tests/tests.json
@@ -1,4 +1,6 @@
 [
+  "_auth_test.AccessTokenCallCredentialsTest",
+  "_auth_test.GoogleCallCredentialsTest",
   "_base_interface_test.AsyncEasyTest", 
   "_base_interface_test.AsyncPeasyTest", 
   "_base_interface_test.SyncEasyTest", 
@@ -33,6 +35,7 @@
   "_face_interface_test.MultiCallableInvokerBlockingInvocationInlineServiceTest", 
   "_face_interface_test.MultiCallableInvokerFutureInvocationAsynchronousEventServiceTest",
   "_health_servicer_test.HealthServicerTest",
+  "_implementations_test.CallCredentialsTest",
   "_implementations_test.ChannelCredentialsTest", 
   "_insecure_interop_test.InsecureInteropTest", 
   "_intermediary_low_test.CancellationTest", 
diff --git a/src/python/grpcio/tests/unit/beta/_auth_test.py b/src/python/grpcio/tests/unit/beta/_auth_test.py
new file mode 100644
index 0000000..694928a
--- /dev/null
+++ b/src/python/grpcio/tests/unit/beta/_auth_test.py
@@ -0,0 +1,96 @@
+# Copyright 2016, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+#     * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+#     * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+"""Tests of standard AuthMetadataPlugins."""
+
+import collections
+import threading
+import unittest
+
+from grpc.beta import _auth
+
+
+class MockGoogleCreds(object):
+
+  def get_access_token(self):
+    token = collections.namedtuple('MockAccessTokenInfo',
+                                   ('access_token', 'expires_in'))
+    token.access_token = 'token'
+    return token
+
+
+class MockExceptionGoogleCreds(object):
+
+  def get_access_token(self):
+    raise Exception()
+
+
+class GoogleCallCredentialsTest(unittest.TestCase):
+
+  def test_google_call_credentials_success(self):
+    callback_event = threading.Event()
+
+    def mock_callback(metadata, error):
+      self.assertEqual(metadata, (('authorization', 'Bearer token'),))
+      self.assertIsNone(error)
+      callback_event.set()
+
+    call_creds = _auth.GoogleCallCredentials(MockGoogleCreds())
+    call_creds(None, mock_callback)
+    self.assertTrue(callback_event.wait(1.0))
+
+  def test_google_call_credentials_error(self):
+    callback_event = threading.Event()
+
+    def mock_callback(metadata, error):
+      self.assertIsNotNone(error)
+      callback_event.set()
+
+    call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds())
+    call_creds(None, mock_callback)
+    self.assertTrue(callback_event.wait(1.0))
+
+
+class AccessTokenCallCredentialsTest(unittest.TestCase):
+
+  def test_google_call_credentials_success(self):
+    callback_event = threading.Event()
+
+    def mock_callback(metadata, error):
+      self.assertEqual(metadata, (('authorization', 'Bearer token'),))
+      self.assertIsNone(error)
+      callback_event.set()
+
+    call_creds = _auth.AccessTokenCallCredentials('token')
+    call_creds(None, mock_callback)
+    self.assertTrue(callback_event.wait(1.0))
+
+
+if __name__ == '__main__':
+  unittest.main(verbosity=2)
diff --git a/src/python/grpcio/tests/unit/beta/_implementations_test.py b/src/python/grpcio/tests/unit/beta/_implementations_test.py
index 26be670..127f93e 100644
--- a/src/python/grpcio/tests/unit/beta/_implementations_test.py
+++ b/src/python/grpcio/tests/unit/beta/_implementations_test.py
@@ -29,8 +29,11 @@
 
 """Tests the implementations module of the gRPC Python Beta API."""
 
+import datetime
 import unittest
 
+from oauth2client import client as oauth2client_client
+
 from grpc.beta import implementations
 from tests.unit import resources
 
@@ -49,5 +52,19 @@
         channel_credentials, implementations.ChannelCredentials)
 
 
+class CallCredentialsTest(unittest.TestCase):
+
+  def test_google_call_credentials(self):
+    creds = oauth2client_client.GoogleCredentials(
+        'token', 'client_id', 'secret', 'refresh_token',
+        datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/',
+        'user_agent')
+    call_creds = implementations.google_call_credentials(creds)
+    self.assertIsInstance(call_creds, implementations.CallCredentials)
+
+  def test_access_token_call_credentials(self):
+    call_creds = implementations.access_token_call_credentials('token')
+    self.assertIsInstance(call_creds, implementations.CallCredentials)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tools/run_tests/run_interop_tests.py b/tools/run_tests/run_interop_tests.py
index edbdf05..053aabc 100755
--- a/tools/run_tests/run_interop_tests.py
+++ b/tools/run_tests/run_interop_tests.py
@@ -317,8 +317,7 @@
             'PYTHONPATH': '{}/src/python/gens'.format(DOCKER_WORKDIR_ROOT)}
 
   def unimplemented_test_cases(self):
-    return _SKIP_ADVANCED + _SKIP_COMPRESSION + ['jwt_token_creds',
-                                                 'per_rpc_creds']
+    return _SKIP_ADVANCED + _SKIP_COMPRESSION + ['jwt_token_creds']
 
   def unimplemented_test_cases_server(self):
     return _SKIP_ADVANCED + _SKIP_COMPRESSION