Extend local auth server to support OAuth login flow

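The local auth server currently supports only the legacy ClientLogin
flow. This CL adds support for the new OAuth-based ServiceLogin flow,
allowing the local auth server to be used in PyAuto tests that require
OAuth so that policy can be fetched.

GoogleAuthServer now takes a single |authenticator| callback instead of
per-endpoint responders, and the DNS redirection logic moves from
cros_ui_test into dns_server.LocalDns.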

BUG=chromium-os:32036
TEST=Login against local auth server without --skip-oauth-login works
TEST=remote_tests.sh suite:smoke succeeds

Change-Id: I89f1154be7f7c52a941726fed8cd010a0ef399c9
Reviewed-on: https://gerrit.chromium.org/gerrit/24501
Reviewed-by: Bartosz Fabianowski <bartfab@chromium.org>
Tested-by: Bartosz Fabianowski <bartfab@chromium.org>
Commit-Ready: Bartosz Fabianowski <bartfab@chromium.org>
diff --git a/client/cros/auth_server.py b/client/cros/auth_server.py
index fc9f214d..089c938 100644
--- a/client/cros/auth_server.py
+++ b/client/cros/auth_server.py
@@ -2,13 +2,21 @@
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
 
-import httplib, logging, os, socket, stat, time
+import httplib, json, logging, os, socket, stat, time, urllib
 
 import common, constants, cryptohome, httpd
 from autotest_lib.client.bin import utils
 from autotest_lib.client.common_lib import error
 
 
+def _value(url_arg_value):
+    """Unwrap a GET arg (a list) or a POST arg (an object with .value)."""
+    try:
+        return url_arg_value[0]
+    except (AttributeError, TypeError):  # Field object is not indexable.
+        return url_arg_value.value
+
+
 class GoogleAuthServer(object):
     """A mock Google accounts server that can be run in a separate thread
     during autotests. By default, it returns happy-signals, accepting any
@@ -37,12 +45,19 @@
   <FORM action=%(form_url)s method=POST onsubmit='submitAndGo()'>
     <INPUT TYPE=text id="Email" name="Email">
     <INPUT TYPE=text id="Passwd" name="Passwd">
+    <P>%(error_message)s</P>
     <INPUT TYPE=hidden id="continue" name="continue" value=%(continue)s>
     <INPUT TYPE=Submit id="signIn">
   </FORM>
 </BODY>
 </HTML>
     """
+    __oauth1_request_token = 'oauth1_request_token'
+    __oauth1_access_token = 'oauth1_access_token'
+    __oauth1_access_token_secret = 'oauth1_access_token_secret'
+    __oauth2_auth_code = 'oauth2_auth_code'
+    __oauth2_refresh_token = 'oauth2_refresh_token'
+    __oauth2_access_token = 'oauth2_access_token'
     __issue_auth_token_miss_count = 0
     __token_auth_miss_count = 0
 
@@ -52,19 +67,27 @@
                  key_path='/etc/fake_root_ca/mock_server.key',
                  ssl_port=443,
                  port=80,
-                 cl_responder=None,
-                 it_responder=None,
-                 pl_responder=None,
-                 ta_responder=None):
+                 authenticator=None):
         self._service_login = constants.SERVICE_LOGIN_URL
         self._service_login_new = constants.SERVICE_LOGIN_NEW_URL
-        self._process_login = constants.PROCESS_LOGIN_URL
-        self._process_login_new = constants.PROCESS_LOGIN_NEW_URL
+        self._service_login_auth = constants.SERVICE_LOGIN_AUTH_URL
+
+        self._oauth1_get_request_token = constants.OAUTH1_GET_REQUEST_TOKEN_URL
+        self._oauth1_get_access_token = constants.OAUTH1_GET_ACCESS_TOKEN_URL
+        self._oauth1_get_access_token_new = \
+            constants.OAUTH1_GET_ACCESS_TOKEN_NEW_URL
+        self._oauth1_login = constants.OAUTH1_LOGIN_URL
+        self._oauth1_login_new = constants.OAUTH1_LOGIN_NEW_URL
+
+        self._oauth2_wrap_bridge = constants.OAUTH2_WRAP_BRIDGE_URL
+        self._oauth2_wrap_bridge_new = constants.OAUTH2_WRAP_BRIDGE_NEW_URL
+        self._oauth2_get_auth_code = constants.OAUTH2_GET_AUTH_CODE_URL
+        self._oauth2_get_token = constants.OAUTH2_GET_TOKEN_URL
 
         self._client_login = constants.CLIENT_LOGIN_URL
         self._client_login_new = constants.CLIENT_LOGIN_NEW_URL
         self._issue_token = constants.ISSUE_AUTH_TOKEN_URL
-        self._issue_token_new = constants.ISSUE_AUTH_TOKEN_URL
+        self._issue_token_new = constants.ISSUE_AUTH_TOKEN_NEW_URL
         self._token_auth = constants.TOKEN_AUTH_URL
         self._token_auth_new = constants.TOKEN_AUTH_NEW_URL
         self._test_over = '/webhp'
@@ -77,29 +100,56 @@
         sa = self._testServer.getsockname()
         logging.info('Serving HTTPS on %s, port %s' % (sa[0], sa[1]))
 
-        if cl_responder is None:
-            cl_responder = self.client_login_responder
-        if it_responder is None:
-            it_responder = self.issue_token_responder
-        if pl_responder is None:
-            pl_responder = self.process_login_responder
-        if ta_responder is None:
-            ta_responder = self.token_auth_responder
+        if authenticator is None:
+            authenticator = self.authenticator
+        self._authenticator = authenticator
 
         self._testServer.add_url_handler(self._service_login,
-                                         self.__service_login_responder)
+                                         self._service_login_responder)
         self._testServer.add_url_handler(self._service_login_new,
-                                         self.__service_login_responder_new)
-        self._testServer.add_url_handler(self._process_login, pl_responder)
-        self._testServer.add_url_handler(self._process_login_new, pl_responder)
+                                         self._service_login_responder)
+        self._testServer.add_url_handler(self._service_login_auth,
+                                         self._service_login_auth_responder)
 
-        self._testServer.add_url_handler(self._client_login, cl_responder)
-        self._testServer.add_url_handler(self._client_login_new, cl_responder)
-        self._testServer.add_url_handler(self._issue_token, it_responder)
-        self._testServer.add_url_handler(self._issue_token_new, it_responder)
-        self._testServer.add_url_handler(self._token_auth, ta_responder)
-        self._testServer.add_url_handler(self._token_auth_new, ta_responder)
+        self._testServer.add_url_handler(
+            self._oauth1_get_request_token,
+            self._oauth1_get_request_token_responder)
+        self._testServer.add_url_handler(
+            self._oauth1_get_access_token,
+            self._oauth1_get_access_token_responder)
+        self._testServer.add_url_handler(
+            self._oauth1_get_access_token_new,
+            self._oauth1_get_access_token_responder)
+        self._testServer.add_url_handler(self._oauth1_login,
+                                         self._oauth1_login_responder)
+        self._testServer.add_url_handler(self._oauth1_login_new,
+                                         self._oauth1_login_responder)
 
+        self._testServer.add_url_handler(self._oauth2_wrap_bridge,
+                                         self._oauth2_wrap_bridge_responder)
+        self._testServer.add_url_handler(self._oauth2_wrap_bridge_new,
+                                         self._oauth2_wrap_bridge_responder)
+        self._testServer.add_url_handler(self._oauth2_get_auth_code,
+                                         self._oauth2_get_auth_code_responder)
+        self._testServer.add_url_handler(self._oauth2_get_token,
+                                         self._oauth2_get_token_responder)
+
+        self._testServer.add_url_handler(self._client_login,
+                                         self._client_login_responder)
+        self._testServer.add_url_handler(self._client_login_new,
+                                         self._client_login_responder)
+        self._testServer.add_url_handler(self._issue_token,
+                                         self._issue_token_responder)
+        self._testServer.add_url_handler(self._issue_token_new,
+                                         self._issue_token_responder)
+        self._testServer.add_url_handler(self._token_auth,
+                                         self._token_auth_responder)
+        self._testServer.add_url_handler(self._token_auth_new,
+                                         self._token_auth_responder)
+
+        self._service_latch = self._testServer.add_wait_url(self._service_login)
+        self._service_new_latch = self._testServer.add_wait_url(
+            self._service_login_new)
         self._client_latch = self._testServer.add_wait_url(self._client_login)
         self._client_new_latch = self._testServer.add_wait_url(
             self._client_login_new)
@@ -111,7 +161,7 @@
         self._testHttpServer.add_url_handler(self._test_over,
                                              self.__test_over_responder)
         self._testHttpServer.add_url_handler(constants.PORTAL_CHECK_URL,
-                                             self.portal_check_responder)
+                                             self._portal_check_responder)
         self._over_latch = self._testHttpServer.add_wait_url(self._test_over)
 
 
@@ -125,6 +175,14 @@
         self._testHttpServer.stop()
 
 
+    def wait_for_service_login(self, timeout=10):
+        self._service_new_latch.wait(timeout)
+        if not self._service_new_latch.is_set():
+            self._service_latch.wait(timeout)
+            if not self._service_latch.is_set():
+                raise error.TestError('Never hit ServiceLogin endpoint.')
+
+
     def wait_for_client_login(self, timeout=10):
         self._client_new_latch.wait(timeout)
         if not self._client_new_latch.is_set():
@@ -158,80 +216,214 @@
         return results
 
 
-    def client_login_responder(self, handler, url_args):
-        logging.debug(url_args)
+    def authenticator(self, email, password):
+        return True  # Default policy: accept any email/password pair.
+
+
+    def _ensure_params_provided(self, handler, url_args, params):
+        """Reply 403 and raise TestError on the first missing param."""
+        for param in (p for p in params if p not in url_args):
+            handler.send_response(httplib.FORBIDDEN)
+            handler.end_headers()
+            raise error.TestError(
+                '%s did not receive a %s param.' % (handler.path, param))
+
+
+    def _return_login_form(self, handler, error_message, continue_url):
+        handler.send_response(httplib.OK)
+        handler.end_headers()
+        handler.wfile.write(self.__service_login_html % {
+            'form_url': self._service_login_auth,
+            'error_message': error_message,
+            'continue': continue_url})
+
+
+    def _log(self, handler, url_args):
+        logging.debug('%s: %s' % (handler.path, url_args))
+
+
+    def _service_login_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_params_provided(handler, url_args, ['continue'])
+        self._return_login_form(handler, '', _value(url_args['continue']))
+
+
+    def _service_login_auth_responder(self, handler, url_args):
+        """Redirect to |continue| on valid credentials, else re-show form."""
+        self._log(handler, url_args)
+        self._ensure_params_provided(handler, url_args,
+                                     ['continue', 'Email', 'Passwd'])
+        continue_url = _value(url_args['continue'])
+        if self._authenticator(_value(url_args['Email']),
+                               _value(url_args['Passwd'])):
+            handler.send_response(httplib.SEE_OTHER)
+            handler.send_header('Location', continue_url)
+            handler.end_headers()
+        else:
+            error_message = constants.SERVICE_LOGIN_AUTH_ERROR
+            self._return_login_form(handler, error_message, continue_url)
+
+
+    def _oauth1_get_request_token_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_params_provided(handler,
+                                     url_args,
+                                     ['scope', 'xoauth_display_name'])
+        handler.send_response(httplib.OK)
+        handler.send_header('Set-Cookie',
+                            'oauth_token=%s; Path=%s; Secure; HttpOnly' %
+                                (self.__oauth1_request_token, handler.path))
+        handler.end_headers()
+
+
+    def _ensure_oauth1_params_valid(self, handler, url_args, expected_token):
+        self._ensure_params_provided(handler,
+                                     url_args,
+                                     ['oauth_consumer_key',
+                                      'oauth_token',
+                                      'oauth_signature_method',
+                                      'oauth_signature',
+                                      'oauth_timestamp',
+                                      'oauth_nonce'])
+        if not ('anonymous' == _value(url_args['oauth_consumer_key']) and
+                expected_token == _value(url_args['oauth_token'])):
+            raise error.TestError(
+                '%s called with incorrect params.' % handler.path)
+
+
+    def _oauth1_get_access_token_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_oauth1_params_valid(handler,
+                                         url_args,
+                                         self.__oauth1_request_token)
+        handler.send_response(httplib.OK)
+        handler.send_header('Content-Type', 'application/x-www-form-urlencoded')
+        handler.end_headers()
+        handler.wfile.write(urllib.urlencode({
+            'oauth_token': self.__oauth1_access_token,
+            'oauth_token_secret': self.__oauth1_access_token_secret}))
+
+
+    def _oauth1_login_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_oauth1_params_valid(handler,
+                                         url_args,
+                                         self.__oauth1_access_token)
         handler.send_response(httplib.OK)
         handler.end_headers()
         handler.wfile.write('SID=%s\n' % self.sid)
         handler.wfile.write('LSID=%s\n' % self.lsid)
+        handler.wfile.write('Auth=%s\n' % self.token)
 
 
-    def issue_token_responder(self, handler, url_args):
-        logging.debug(url_args)
-        if url_args['service'].value != constants.LOGIN_SERVICE:
-            handler.send_response(httplib.FORBIDDEN)
-            handler.end_headers()
-            handler.wfile.write(constants.LOGIN_ERROR)
-            return
+    def _oauth2_wrap_bridge_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_oauth1_params_valid(handler,
+                                         url_args,
+                                         self.__oauth1_access_token)
+        handler.send_response(httplib.OK)
+        handler.send_header('Content-Type', 'application/x-www-form-urlencoded')
+        handler.end_headers()
+        handler.wfile.write(urllib.urlencode({
+            'wrap_access_token': self.__oauth2_access_token,
+            'wrap_access_token_expires_in': '3600'}))
 
-        if not (self.sid == url_args['SID'].value and
-                self.lsid == url_args['LSID'].value):
-            raise error.TestError('IssueAuthToken called with incorrect args')
+
+    def _ensure_oauth2_params_valid(self, handler, url_args):
+        self._ensure_params_provided(handler, url_args, ['scope', 'client_id'])
+        if constants.OAUTH2_CLIENT_ID != _value(url_args['client_id']):
+            raise error.TestError(
+                '%s called with incorrect params.' % handler.path)
+
+
+    def _oauth2_get_auth_code_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_oauth2_params_valid(handler, url_args)
+        handler.send_response(httplib.OK)
+        handler.send_header('Set-Cookie',
+                            'oauth_code=%s; Path=%s; Secure; HttpOnly' %
+                                (self.__oauth2_auth_code, handler.path))
+        handler.end_headers()
+
+
+    def _oauth2_get_token_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_oauth2_params_valid(handler, url_args)
+        self._ensure_params_provided(handler,
+                                     url_args,
+                                     ['grant_type', 'client_secret'])
+        if constants.OAUTH2_CLIENT_SECRET != _value(url_args['client_secret']):
+            raise error.TestError(
+                '%s called with incorrect params.' % handler.path)
+        if 'authorization_code' == _value(url_args['grant_type']):
+            self._ensure_params_provided(handler, url_args, ['code'])
+            if self.__oauth2_auth_code != _value(url_args['code']):
+                raise error.TestError(
+                    '%s called with incorrect params.' % handler.path)
+        elif 'refresh_token' == _value(url_args['grant_type']):
+            self._ensure_params_provided(handler, url_args, ['refresh_token'])
+            if self.__oauth2_refresh_token != _value(url_args['refresh_token']):
+                raise error.TestError(
+                    '%s called with incorrect params.' % handler.path)
+        else:
+            raise error.TestError(
+                '%s called with incorrect params.' % handler.path)
         handler.send_response(httplib.OK)
         handler.end_headers()
-        handler.wfile.write(self.token)
+        handler.wfile.write(json.dumps({
+            'refresh_token': self.__oauth2_refresh_token,
+            'access_token': self.__oauth2_access_token,
+            'expires_in': 3600}))
 
 
-    def process_login_responder(self, handler, url_args):
-        logging.debug(url_args)
-        if not 'continue' in url_args:
+    def _client_login_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_params_provided(handler, url_args, ['Email', 'Passwd'])
+        if self._authenticator(_value(url_args['Email']),
+                               _value(url_args['Passwd'])):
+            handler.send_response(httplib.OK)
+            handler.end_headers()
+            handler.wfile.write('SID=%s\n' % self.sid)
+            handler.wfile.write('LSID=%s\n' % self.lsid)
+        else:  # Credentials rejected by the authenticator callback.
             handler.send_response(httplib.FORBIDDEN)
             handler.end_headers()
-            raise error.TestError('ServiceLogin did not pass a continue param')
+            handler.wfile.write(constants.LOGIN_ERROR)
+
+
+    def _issue_token_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_params_provided(handler,
+                                     url_args,
+                                     ['service', 'SID', 'LSID'])
+        if not (self.sid == _value(url_args['SID']) and
+                self.lsid == _value(url_args['LSID'])):
+            raise error.TestError(
+                '%s called with incorrect params.' % handler.path)
+        # Block Chrome sync as we do not currently mock the server for it.
+        if _value(url_args['service']) in ['chromiumsync', 'mobilesync']:
+            handler.send_response(httplib.FORBIDDEN)
+            handler.end_headers()
+            handler.wfile.write('Error=ServiceUnavailable')
+        else:
+            handler.send_response(httplib.OK)
+            handler.end_headers()
+            handler.wfile.write(self.token)
+
+
+    def _token_auth_responder(self, handler, url_args):
+        self._log(handler, url_args)
+        self._ensure_params_provided(handler, url_args, ['auth', 'continue'])
+        if not self.token == _value(url_args['auth']):
+            raise error.TestError(
+                '%s called with incorrect param.' % handler.path)
         handler.send_response(httplib.SEE_OTHER)
-        handler.send_header('Location', url_args['continue'].value)
+        handler.send_header('Location', _value(url_args['continue']))
         handler.end_headers()
 
 
-    def __service_login_responder(self, handler, url_args):
-        logging.debug(url_args)
-        if not 'continue' in url_args:
-            handler.send_response(httplib.FORBIDDEN)
-            handler.end_headers()
-            raise error.TestError('ServiceLogin called with no continue param')
-        handler.send_response(httplib.OK)
-        handler.end_headers()
-        handler.wfile.write(self.__service_login_html % {
-            'form_url': self._process_login,
-            'continue': url_args['continue'][0] })
-
-
-    def __service_login_responder_new(self, handler, url_args):
-        logging.debug(url_args)
-        if not 'continue' in url_args:
-            handler.send_response(httplib.FORBIDDEN)
-            handler.end_headers()
-            raise error.TestError('ServiceLogin called with no continue param')
-        handler.send_response(httplib.OK)
-        handler.end_headers()
-        handler.wfile.write(self.__service_login_html % {
-            'form_url': self._process_login_new,
-            'continue': url_args['continue'][0] })
-
-
-    def token_auth_responder(self, handler, url_args):
-        logging.debug(url_args)
-        if not self.token == url_args['auth'][0]:
-            raise error.TestError('TokenAuth called with incorrect args')
-        if not 'continue' in url_args:
-            raise error.TestError('TokenAuth called with no continue param')
-        handler.send_response(httplib.SEE_OTHER)
-        handler.send_header('Location', url_args['continue'][0])
-        handler.end_headers()
-
-
-    def portal_check_responder(self, handler, url_args):
-        logging.debug('Handling captive portal check')
+    def _portal_check_responder(self, handler, url_args):
+        logging.debug('Handling captive portal check.')
         handler.send_response(httplib.NO_CONTENT)
         handler.end_headers()
 
diff --git a/client/cros/constants.py b/client/cros/constants.py
index e32599a..0d3467d 100644
--- a/client/cros/constants.py
+++ b/client/cros/constants.py
@@ -63,7 +63,6 @@
 LOGGED_IN_MAGIC_FILE = '/var/run/state/logged-in'
 
 LOGIN_PROFILE = USER_DATA_DIR + '/Default'
-LOGIN_SERVICE = 'gaia'
 LOGIN_ERROR = 'Error=BadAuthentication'
 LOGIN_PROMPT_VISIBLE_MAGIC_FILE = '/tmp/uptime-login-prompt-visible'
 LOGIN_TRUST_ROOTS = '/etc/login_trust_root.pem'
@@ -77,14 +76,28 @@
 ISSUE_AUTH_TOKEN_URL = '/accounts/IssueAuthToken'
 ISSUE_AUTH_TOKEN_NEW_URL = '/IssueAuthToken'
 
+OAUTH1_GET_REQUEST_TOKEN_URL = '/accounts/o8/GetOAuthToken'
+OAUTH1_GET_ACCESS_TOKEN_URL = '/accounts/OAuthGetAccessToken'
+OAUTH1_GET_ACCESS_TOKEN_NEW_URL = '/OAuthGetAccessToken'
+OAUTH1_LOGIN_URL = '/accounts/OAuthLogin'
+OAUTH1_LOGIN_NEW_URL = '/OAuthLogin'
+
+OAUTH2_CLIENT_ID = '77185425430.apps.googleusercontent.com'
+OAUTH2_CLIENT_SECRET = 'OTJgUOQcT7lO7GsGZq2G4IlT'
+OAUTH2_WRAP_BRIDGE_URL = '/accounts/OAuthWrapBridge'
+OAUTH2_WRAP_BRIDGE_NEW_URL = '/OAuthWrapBridge'
+OAUTH2_GET_AUTH_CODE_URL = '/o/oauth2/programmatic_auth'
+OAUTH2_GET_TOKEN_URL = '/o/oauth2/token'
+
 OWNER_KEY_FILE = WHITELIST_DIR + '/owner.key'
 
 PORTAL_CHECK_URL = '/generate_204'
-PROCESS_LOGIN_URL = '/accounts/ProcessServiceLogin'
-PROCESS_LOGIN_NEW_URL = '/ProcessServiceLogin'
 
 SERVICE_LOGIN_URL = '/accounts/ServiceLogin'
 SERVICE_LOGIN_NEW_URL = '/ServiceLogin'
+SERVICE_LOGIN_AUTH_URL = '/ServiceLoginAuth'
+SERVICE_LOGIN_AUTH_ERROR = 'The username or password you entered is incorrect.'
+
 SESSION_MANAGER = 'session_manager'
 SESSION_MANAGER_LOG = '/var/log/session_manager'
 SIGNED_POLICY_FILE = WHITELIST_DIR + '/policy'
diff --git a/client/cros/cros_ui_test.py b/client/cros/cros_ui_test.py
index 1ca61ea..1750ec4 100644
--- a/client/cros/cros_ui_test.py
+++ b/client/cros/cros_ui_test.py
@@ -2,14 +2,12 @@
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
 
-import dbus, glob, logging, os, re, shutil, socket, stat, subprocess, sys, time
-from dbus.mainloop.glib import DBusGMainLoop
+import glob, logging, os, re, shutil, subprocess, sys, time
 
 import auth_server, common, constants, cros_logging, cros_ui, cryptohome
-import dns_server, flimflam_test_path, login, ownership, pyauto_test
+import dns_server, login, ownership, pyauto_test
 from autotest_lib.client.bin import utils
 from autotest_lib.client.common_lib import error
-import flimflam # Requires flimflam_test_path to be imported first.
 
 class UITest(pyauto_test.PyAutoTest):
     """Base class for tests that drive some portion of the user interface.
@@ -52,155 +50,20 @@
         '/sys/kernel/debug/tracing/events/signal/signal_generate/filter'
     _ftrace_trace_file = '/sys/kernel/debug/tracing/trace'
 
-    # This is a symlink.  We look up the real path at runtime by following it.
-    _resolv_test_file = 'resolv.conf.test'
-    _resolv_bak_file = 'resolv.conf.bak'
-
     _last_chrome_log = ''
 
 
-    def listen_to_signal(self, callback, signal, interface):
-        """Listens to the given |signal| that is sent to power manager.
-        """
-        self._system_bus.add_signal_receiver(
-            handler_function=callback,
-            signal_name=signal,
-            dbus_interface=interface,
-            bus_name=None,
-            path='/')
-
-
-    def __connect_to_flimflam(self):
-        """Connect to the network manager via DBus.
-
-        Stores dbus connection in self._flim upon success, throws on failure.
-        """
-        self._bus_loop = DBusGMainLoop(set_as_default=True)
-        self._system_bus = dbus.SystemBus(mainloop=self._bus_loop)
-        self._flim = flimflam.FlimFlam(self._system_bus)
-
-
-    def __get_host_by_name(self, hostname):
-        """Resolve the dotted-quad IPv4 address of |hostname|
-
-        This used to use suave python code, like this:
-            hosts = socket.getaddrinfo(hostname, 80, socket.AF_INET)
-            (fam, socktype, proto, canonname, (host, port)) = hosts[0]
-            return host
-
-        But that hangs sometimes, and we don't understand why.  So, use
-        a subprocess with a timeout.
-        """
-        try:
-            host = utils.system_output('%s -c "import socket; '
-                                       'print socket.gethostbyname(\'%s\')"' % (
-                                       sys.executable, hostname),
-                                       ignore_status=True, timeout=2)
-        except Exception as e:
-            logging.warning(e)
-            return None
-        return host or None
-
-
-    def __attempt_resolve(self, hostname, ip, expected=True):
-        logging.debug('Attempting to resolve %s to %s' % (hostname, ip))
-        try:
-            host = self.__get_host_by_name(hostname)
-            logging.debug('Resolve attempt for %s got %s' % (hostname, host))
-            return host and (host == ip) == expected
-        except socket.gaierror as err:
-            logging.error(err)
-
-
-    def use_local_dns(self, dns_port=53):
-        """Set all devices to use our in-process mock DNS server.
-        """
-        self._dnsServer = dns_server.LocalDns(fake_ip='127.0.0.1',
-                                              local_port=dns_port)
-        self._dnsServer.run()
-        # Turn off captive portal checking, until we fix
-        # http://code.google.com/p/chromium-os/issues/detail?id=19640
-        self.check_portal_list = self._flim.GetCheckPortalList()
-        self._flim.SetCheckPortalList('')
-        # Set all devices to use locally-running DNS server.
-        try:
-            # Follow resolv.conf symlink.
-            resolv = os.path.realpath(constants.RESOLV_CONF_FILE)
-            # Grab path to the real file, do following work in that directory.
-            resolv_dir = os.path.dirname(resolv)
-            resolv_test = os.path.join(resolv_dir, self._resolv_test_file)
-            resolv_bak = os.path.join(resolv_dir, self._resolv_bak_file)
-            resolv_contents = 'nameserver 127.0.0.1'
-            # Test to make sure the current resolv.conf isn't already our
-            # specially modified version.  If this is the case, we have
-            # probably been interrupted while in the middle of this test
-            # in a previous run.  The last thing we want to do at this point
-            # is to overwrite a legitimate backup.
-            if (utils.read_one_line(resolv) == resolv_contents and
-                os.path.exists(resolv_bak)):
-                logging.error('Current resolv.conf is setup for our local '
-                              'server, and a backup already exists!  '
-                              'Skipping the backup step.')
-            else:
-              # Back up the current resolv.conf.
-              os.rename(resolv, resolv_bak)
-            # To stop flimflam from editing resolv.conf while we're working
-            # with it, we want to make the directory -r-x-r-x-r-x.  Open an
-            # fd to the file first, so that we'll retain the ability to
-            # alter it.
-            resolv_fd = open(resolv, 'w')
-            self._resolv_dir_mode = os.stat(resolv_dir).st_mode
-            os.chmod(resolv_dir, (stat.S_IRUSR | stat.S_IXUSR |
-                                  stat.S_IRGRP | stat.S_IXGRP |
-                                  stat.S_IROTH | stat.S_IXOTH))
-            resolv_fd.write(resolv_contents)
-            resolv_fd.close()
-            assert utils.read_one_line(resolv) == resolv_contents
-        except Exception as e:
-            logging.error(str(e))
-            raise e
-
-        utils.poll_for_condition(
-            lambda: self.__attempt_resolve('www.google.com.', '127.0.0.1'),
-            utils.TimeoutError('Timed out waiting for DNS changes.'),
-            timeout=10)
-
-
-    def revert_dns(self):
-        """Clear the custom DNS setting for all devices and force them to use
-        DHCP to pull the network's real settings again.
-        """
-        try:
-            # Follow resolv.conf symlink.
-            resolv = os.path.realpath(constants.RESOLV_CONF_FILE)
-            # Grab path to the real file, do following work in that directory.
-            resolv_dir = os.path.dirname(resolv)
-            resolv_bak = os.path.join(resolv_dir, self._resolv_bak_file)
-            os.chmod(resolv_dir, self._resolv_dir_mode)
-            os.rename(resolv_bak, resolv)
-
-            utils.poll_for_condition(
-                lambda: self.__attempt_resolve('www.google.com.',
-                                               '127.0.0.1',
-                                               expected=False),
-                utils.TimeoutError('Timed out waiting to revert DNS.  '
-                                   'resolv.conf contents are: ' +
-                                   utils.read_one_line(resolv)),
-                timeout=10)
-        finally:
-            # Set captive portal checking to whatever it was at the start.
-            self._flim.SetCheckPortalList(self.check_portal_list)
-
-
-    def start_authserver(self):
+    def start_authserver(self, authenticator=None):
         """Spin up a local mock of the Google Accounts server, then spin up
         a local fake DNS server and tell the networking stack to use it.  This
         will trick Chrome into talking to our mock when we login.
         Subclasses can override this method to change this behavior.
         """
-        self._authServer = auth_server.GoogleAuthServer()
+        self._authServer = auth_server.GoogleAuthServer(
+            authenticator=authenticator)
         self._authServer.run()
-        self.use_local_dns()
+        self._dnsServer = dns_server.LocalDns()
+        self._dnsServer.run()
 
 
     def stop_authserver(self):
@@ -209,8 +72,8 @@
         method as well.
         """
         if hasattr(self, '_authServer'):
-            self.revert_dns()
             self._authServer.stop()
+        if hasattr(self, '_dnsServer'):
             self._dnsServer.stop()
 
 
@@ -376,10 +239,10 @@
 
         Authentication is not performed against live servers.  Instead, we spin
         up a local DNS server that will lie and say that all sites resolve to
-        127.0.0.1.  We use DBus to tell flimflam to use this DNS server to
-        resolve addresses.  We then spin up a local httpd that will respond
-        to queries at the Google Accounts endpoints.  We clear the DNS setting
-        and tear down these servers in cleanup().
+        127.0.0.1.  The DNS server rewrites resolv.conf to query 127.0.0.1 and
+        disables flimflam's captive portal check.  We then spin up a local
+        httpd that will respond to queries at the Google Accounts endpoints.
+        We restore the DNS setting and tear down these servers in cleanup().
 
         Args:
             creds: String specifying the credentials for this test case.  Can
@@ -397,8 +260,6 @@
         self._log_reader = cros_logging.LogReader()
         self._log_reader.set_start_by_current()
 
-        self.__connect_to_flimflam()
-
         if creds:
             self.start_authserver()
 
diff --git a/client/cros/dns_server.py b/client/cros/dns_server.py
index 90c2f9a..2ee1036 100644
--- a/client/cros/dns_server.py
+++ b/client/cros/dns_server.py
@@ -2,15 +2,28 @@
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
 
-import logging, threading, time
+import dbus, logging, os, socket, stat, sys, threading, time
+from dbus.mainloop.glib import DBusGMainLoop
 
-import common
+import common, constants, flimflam_test_path
 from autotest_lib.client.bin import utils
+import flimflam  # Requires flimflam_test_path to be imported first.
 
 class LocalDns(object):
-    """a wrapper around miniFakeDns that handles managing running the server
-    in a separate thread.
+    """A wrapper around miniFakeDns that runs the server in a separate thread
+    and redirects all DNS queries to it.
     """
+    # Backup name, created beside the real (symlink-resolved) resolv.conf.
+    _resolv_bak_file = 'resolv.conf.bak'
+
+    def __connect_to_flimflam(self):
+        """Connect to flimflam, the network manager, via the system DBus.
+
+        Stores a flimflam.FlimFlam wrapper in self._flim; raises on failure.
+        """
+        self._bus_loop = DBusGMainLoop(set_as_default=True)
+        self._system_bus = dbus.SystemBus(mainloop=self._bus_loop)
+        self._flim = flimflam.FlimFlam(self._system_bus)
 
     def __init__(self, fake_ip="127.0.0.1", local_port=53):
         import miniFakeDns  # So we don't need to install it in the chroot.
@@ -18,12 +31,147 @@
         self._stopper = threading.Event()
         self._thread = threading.Thread(target=self._dns.run,
                                         args=(self._stopper,))
+        self.__connect_to_flimflam()
 
+    def __get_host_by_name(self, hostname):
+        """Resolve the dotted-quad IPv4 address of |hostname|.
+
+        This used to use suave python code, like this:
+            hosts = socket.getaddrinfo(hostname, 80, socket.AF_INET)
+            (fam, socktype, proto, canonname, (host, port)) = hosts[0]
+            return host
+
+        But that hangs sometimes, and we don't understand why.  So, use
+        a subprocess with a timeout.
+
+        Returns:
+            The address as a string, or None if the subprocess printed
+            nothing or raised (the call is limited by timeout=2).
+        """
+        try:
+            host = utils.system_output('%s -c "import socket; '
+                                       'print socket.gethostbyname(\'%s\')"' % (
+                                       sys.executable, hostname),
+                                       ignore_status=True, timeout=2)
+        except Exception as e:
+            # Best effort: callers poll, so a failed lookup is not fatal.
+            logging.warning(e)
+            return None
+        return host or None
+
+    def __attempt_resolve(self, hostname, ip, expected=True):
+        """Check whether |hostname| currently resolves to |ip|.
+
+        Returns a true value if the lookup succeeds and "resolves to |ip|"
+        equals |expected|; a failed lookup always yields a false value, so
+        pollers keep waiting until resolution works.
+        """
+        logging.debug('Attempting to resolve %s to %s', hostname, ip)
+        host = self.__get_host_by_name(hostname)
+        logging.debug('Resolve attempt for %s got %s', hostname, host)
+        return host and (host == ip) == expected
 
     def run(self):
+        """Start the mock DNS server and redirect all queries to it.
+
+        Turns off flimflam's captive portal check (saving the old list in
+        self.check_portal_list), backs up resolv.conf, points it at
+        127.0.0.1 and makes its directory read-only so flimflam cannot
+        rewrite it, then polls until www.google.com. resolves to 127.0.0.1.
+
+        Raises:
+            utils.TimeoutError if the redirect does not take effect in time.
+            Errors while rewriting resolv.conf are logged and re-raised;
+            call stop() to undo a partially applied run().
+        """
         self._thread.start()
+        # Turn off captive portal checking, until we fix
+        # http://code.google.com/p/chromium-os/issues/detail?id=19640
+        self.check_portal_list = self._flim.GetCheckPortalList()
+        self._flim.SetCheckPortalList('')
+        # Redirect all DNS queries to the mock DNS server.
+        try:
+            # Follow resolv.conf symlink.
+            resolv = os.path.realpath(constants.RESOLV_CONF_FILE)
+            # Grab path to the real file, do following work in that directory.
+            resolv_dir = os.path.dirname(resolv)
+            resolv_bak = os.path.join(resolv_dir, self._resolv_bak_file)
+            resolv_contents = 'nameserver 127.0.0.1'
+            # Test to make sure the current resolv.conf isn't already our
+            # specially modified version.  If this is the case, we have
+            # probably been interrupted while in the middle of this test
+            # in a previous run.  The last thing we want to do at this point
+            # is to overwrite a legitimate backup.
+            if (utils.read_one_line(resolv) == resolv_contents and
+                os.path.exists(resolv_bak)):
+                logging.error('Current resolv.conf is setup for our local '
+                              'server, and a backup already exists!  '
+                              'Skipping the backup step.')
+            else:
+                # Back up the current resolv.conf.
+                os.rename(resolv, resolv_bak)
+            # To stop flimflam from editing resolv.conf while we're working
+            # with it, we want to make the directory -r-xr-xr-x.  Open an
+            # fd to the file first, so that we'll retain the ability to
+            # alter it.  The with-block closes the fd even if the chmod or
+            # the write fails.
+            with open(resolv, 'w') as resolv_fd:
+                self._resolv_dir_mode = os.stat(resolv_dir).st_mode
+                os.chmod(resolv_dir, (stat.S_IRUSR | stat.S_IXUSR |
+                                      stat.S_IRGRP | stat.S_IXGRP |
+                                      stat.S_IROTH | stat.S_IXOTH))
+                resolv_fd.write(resolv_contents)
+            assert utils.read_one_line(resolv) == resolv_contents
+        except Exception as e:
+            logging.error(str(e))
+            # A bare raise keeps the original traceback; Python 2's
+            # "raise e" would restart it at this line.
+            raise
 
+        utils.poll_for_condition(
+            lambda: self.__attempt_resolve('www.google.com.', '127.0.0.1'),
+            utils.TimeoutError('Timed out waiting for DNS changes.'),
+            timeout=10)
 
     def stop(self):
-        self._stopper.set()
-        self._thread.join()
+        """Restore the backed-up DNS settings and stop the mock DNS server.
+
+        Tolerates a run() that failed part-way: the directory mode and the
+        captive portal list are only restored if run() saved them, and the
+        DNS server thread is stopped even if restoring the portal list
+        fails.
+
+        Raises:
+            utils.TimeoutError if www.google.com. never resolves to a
+            non-local address in time; OSError if no backup exists.
+        """
+        try:
+            # Follow resolv.conf symlink.
+            resolv = os.path.realpath(constants.RESOLV_CONF_FILE)
+            # Grab path to the real file, do following work in that directory.
+            resolv_dir = os.path.dirname(resolv)
+            resolv_bak = os.path.join(resolv_dir, self._resolv_bak_file)
+            # run() records the original mode just before locking the dir.
+            resolv_dir_mode = getattr(self, '_resolv_dir_mode', None)
+            if resolv_dir_mode is not None:
+                os.chmod(resolv_dir, resolv_dir_mode)
+            os.rename(resolv_bak, resolv)
+
+            utils.poll_for_condition(
+                lambda: self.__attempt_resolve('www.google.com.',
+                                               '127.0.0.1',
+                                               expected=False),
+                utils.TimeoutError('Timed out waiting to revert DNS.  '
+                                   'resolv.conf contents are: ' +
+                                   utils.read_one_line(resolv)),
+                timeout=10)
+        finally:
+            try:
+                # Set captive portal checking to whatever it was at the
+                # start, if run() got far enough to save it.
+                if hasattr(self, 'check_portal_list'):
+                    self._flim.SetCheckPortalList(self.check_portal_list)
+            finally:
+                # Always stop the DNS server thread.
+                self._stopper.set()
+                self._thread.join()