Merge pull request #8073 from daniel-j-born/handshake_shutdown

Safe server shutdown.
diff --git a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c
index da3e284..563271f 100644
--- a/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c
+++ b/src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c
@@ -61,13 +61,12 @@
   grpc_server_credentials *creds;
   bool is_shutdown;
   gpr_mu mu;
-  gpr_refcount refcount;
-  grpc_closure destroy_closure;
-  grpc_closure *destroy_callback;
+  grpc_closure tcp_server_shutdown_complete;
+  grpc_closure *server_destroy_listener_done;
 } server_secure_state;
 
 typedef struct server_secure_connect {
-  server_secure_state *state;
+  server_secure_state *server_state;
   grpc_pollset *accepting_pollset;
   grpc_tcp_server_acceptor *acceptor;
   grpc_handshake_manager *handshake_mgr;
@@ -77,39 +76,28 @@
   grpc_channel_args *args;
 } server_secure_connect;
 
-static void state_ref(server_secure_state *state) { gpr_ref(&state->refcount); }
-
-static void state_unref(server_secure_state *state) {
-  if (gpr_unref(&state->refcount)) {
-    /* ensure all threads have unlocked */
-    gpr_mu_lock(&state->mu);
-    gpr_mu_unlock(&state->mu);
-    /* clean up */
-    GRPC_SECURITY_CONNECTOR_UNREF(&state->sc->base, "server");
-    grpc_server_credentials_unref(state->creds);
-    gpr_free(state);
-  }
-}
-
 static void on_secure_handshake_done(grpc_exec_ctx *exec_ctx, void *statep,
                                      grpc_security_status status,
                                      grpc_endpoint *secure_endpoint,
                                      grpc_auth_context *auth_context) {
-  server_secure_connect *state = statep;
+  server_secure_connect *connection_state = statep;
   if (status == GRPC_SECURITY_OK) {
     if (secure_endpoint) {
-      gpr_mu_lock(&state->state->mu);
-      if (!state->state->is_shutdown) {
+      gpr_mu_lock(&connection_state->server_state->mu);
+      if (!connection_state->server_state->is_shutdown) {
         grpc_transport *transport = grpc_create_chttp2_transport(
-            exec_ctx, grpc_server_get_channel_args(state->state->server),
+            exec_ctx, grpc_server_get_channel_args(
+                          connection_state->server_state->server),
             secure_endpoint, 0);
         grpc_arg args_to_add[2];
-        args_to_add[0] = grpc_server_credentials_to_arg(state->state->creds);
+        args_to_add[0] = grpc_server_credentials_to_arg(
+            connection_state->server_state->creds);
         args_to_add[1] = grpc_auth_context_to_arg(auth_context);
         grpc_channel_args *args_copy = grpc_channel_args_copy_and_add(
-            state->args, args_to_add, GPR_ARRAY_SIZE(args_to_add));
-        grpc_server_setup_transport(exec_ctx, state->state->server, transport,
-                                    state->accepting_pollset, args_copy);
+            connection_state->args, args_to_add, GPR_ARRAY_SIZE(args_to_add));
+        grpc_server_setup_transport(
+            exec_ctx, connection_state->server_state->server, transport,
+            connection_state->accepting_pollset, args_copy);
         grpc_channel_args_destroy(args_copy);
         grpc_chttp2_transport_start_reading(exec_ctx, transport, NULL);
       } else {
@@ -117,21 +105,21 @@
          * gone away. */
         grpc_endpoint_destroy(exec_ctx, secure_endpoint);
       }
-      gpr_mu_unlock(&state->state->mu);
+      gpr_mu_unlock(&connection_state->server_state->mu);
     }
   } else {
     gpr_log(GPR_ERROR, "Secure transport failed with error %d", status);
   }
-  grpc_channel_args_destroy(state->args);
-  state_unref(state->state);
-  gpr_free(state);
+  grpc_channel_args_destroy(connection_state->args);
+  grpc_tcp_server_unref(exec_ctx, connection_state->server_state->tcp);
+  gpr_free(connection_state);
 }
 
 static void on_handshake_done(grpc_exec_ctx *exec_ctx, grpc_endpoint *endpoint,
                               grpc_channel_args *args,
                               gpr_slice_buffer *read_buffer, void *user_data,
                               grpc_error *error) {
-  server_secure_connect *state = user_data;
+  server_secure_connect *connection_state = user_data;
   if (error != GRPC_ERROR_NONE) {
     const char *error_str = grpc_error_string(error);
     gpr_log(GPR_ERROR, "Handshaking failed: %s", error_str);
@@ -139,81 +127,107 @@
     GRPC_ERROR_UNREF(error);
     grpc_channel_args_destroy(args);
     gpr_free(read_buffer);
-    grpc_handshake_manager_shutdown(exec_ctx, state->handshake_mgr);
-    grpc_handshake_manager_destroy(exec_ctx, state->handshake_mgr);
-    state_unref(state->state);
-    gpr_free(state);
+    grpc_handshake_manager_shutdown(exec_ctx, connection_state->handshake_mgr);
+    grpc_handshake_manager_destroy(exec_ctx, connection_state->handshake_mgr);
+    grpc_tcp_server_unref(exec_ctx, connection_state->server_state->tcp);
+    gpr_free(connection_state);
     return;
   }
-  grpc_handshake_manager_destroy(exec_ctx, state->handshake_mgr);
-  state->handshake_mgr = NULL;
+  grpc_handshake_manager_destroy(exec_ctx, connection_state->handshake_mgr);
+  connection_state->handshake_mgr = NULL;
   // TODO(roth, jboeuf): Convert security connector handshaking to use new
   // handshake API, and then move the code from on_secure_handshake_done()
   // into this function.
-  state->args = args;
+  connection_state->args = args;
   grpc_server_security_connector_do_handshake(
-      exec_ctx, state->state->sc, state->acceptor, endpoint, read_buffer,
-      state->deadline, on_secure_handshake_done, state);
+      exec_ctx, connection_state->server_state->sc, connection_state->acceptor,
+      endpoint, read_buffer, connection_state->deadline,
+      on_secure_handshake_done, connection_state);
 }
 
 static void on_accept(grpc_exec_ctx *exec_ctx, void *statep, grpc_endpoint *tcp,
                       grpc_pollset *accepting_pollset,
                       grpc_tcp_server_acceptor *acceptor) {
-  server_secure_connect *state = gpr_malloc(sizeof(*state));
-  state->state = statep;
-  state_ref(state->state);
-  state->accepting_pollset = accepting_pollset;
-  state->acceptor = acceptor;
-  state->handshake_mgr = grpc_handshake_manager_create();
+  server_secure_state *server_state = statep;
+  server_secure_connect *connection_state = NULL;
+  gpr_mu_lock(&server_state->mu);
+  if (server_state->is_shutdown) {
+    gpr_mu_unlock(&server_state->mu);
+    grpc_endpoint_destroy(exec_ctx, tcp);
+    return;
+  }
+  gpr_mu_unlock(&server_state->mu);
+  grpc_tcp_server_ref(server_state->tcp);
+  connection_state = gpr_malloc(sizeof(*connection_state));
+  connection_state->server_state = server_state;
+  connection_state->accepting_pollset = accepting_pollset;
+  connection_state->acceptor = acceptor;
+  connection_state->handshake_mgr = grpc_handshake_manager_create();
   // TODO(roth): We should really get this timeout value from channel
   // args instead of hard-coding it.
-  state->deadline = gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC),
-                                 gpr_time_from_seconds(120, GPR_TIMESPAN));
+  connection_state->deadline = gpr_time_add(
+      gpr_now(GPR_CLOCK_MONOTONIC), gpr_time_from_seconds(120, GPR_TIMESPAN));
   grpc_handshake_manager_do_handshake(
-      exec_ctx, state->handshake_mgr, tcp,
-      grpc_server_get_channel_args(state->state->server), state->deadline,
-      acceptor, on_handshake_done, state);
+      exec_ctx, connection_state->handshake_mgr, tcp,
+      grpc_server_get_channel_args(connection_state->server_state->server),
+      connection_state->deadline, acceptor, on_handshake_done,
+      connection_state);
 }
 
 /* Server callback: start listening on our ports */
-static void start(grpc_exec_ctx *exec_ctx, grpc_server *server, void *statep,
-                  grpc_pollset **pollsets, size_t pollset_count) {
-  server_secure_state *state = statep;
-  grpc_tcp_server_start(exec_ctx, state->tcp, pollsets, pollset_count,
-                        on_accept, state);
+static void server_start_listener(grpc_exec_ctx *exec_ctx, grpc_server *server,
+                                  void *statep, grpc_pollset **pollsets,
+                                  size_t pollset_count) {
+  server_secure_state *server_state = statep;
+  gpr_mu_lock(&server_state->mu);
+  server_state->is_shutdown = false;
+  gpr_mu_unlock(&server_state->mu);
+  grpc_tcp_server_start(exec_ctx, server_state->tcp, pollsets, pollset_count,
+                        on_accept, server_state);
 }
 
-static void destroy_done(grpc_exec_ctx *exec_ctx, void *statep,
-                         grpc_error *error) {
-  server_secure_state *state = statep;
-  if (state->destroy_callback != NULL) {
-    state->destroy_callback->cb(exec_ctx, state->destroy_callback->cb_arg,
-                                GRPC_ERROR_REF(error));
+static void tcp_server_shutdown_complete(grpc_exec_ctx *exec_ctx, void *statep,
+                                         grpc_error *error) {
+  server_secure_state *server_state = statep;
+  /* ensure all threads have unlocked */
+  gpr_mu_lock(&server_state->mu);
+  grpc_closure *destroy_done = server_state->server_destroy_listener_done;
+  GPR_ASSERT(server_state->is_shutdown);
+  gpr_mu_unlock(&server_state->mu);
+  /* clean up */
+  grpc_server_security_connector_shutdown(exec_ctx, server_state->sc);
+
+  /* Flush queued work before a synchronous unref. */
+  grpc_exec_ctx_flush(exec_ctx);
+  GRPC_SECURITY_CONNECTOR_UNREF(&server_state->sc->base, "server");
+  grpc_server_credentials_unref(server_state->creds);
+
+  if (destroy_done != NULL) {
+    destroy_done->cb(exec_ctx, destroy_done->cb_arg, GRPC_ERROR_REF(error));
+    grpc_exec_ctx_flush(exec_ctx);
   }
-  grpc_server_security_connector_shutdown(exec_ctx, state->sc);
-  state_unref(state);
+  gpr_free(server_state);
 }
 
-/* Server callback: destroy the tcp listener (so we don't generate further
-   callbacks) */
-static void destroy(grpc_exec_ctx *exec_ctx, grpc_server *server, void *statep,
-                    grpc_closure *callback) {
-  server_secure_state *state = statep;
+static void server_destroy_listener(grpc_exec_ctx *exec_ctx,
+                                    grpc_server *server, void *statep,
+                                    grpc_closure *callback) {
+  server_secure_state *server_state = statep;
   grpc_tcp_server *tcp;
-  gpr_mu_lock(&state->mu);
-  state->is_shutdown = true;
-  state->destroy_callback = callback;
-  tcp = state->tcp;
-  gpr_mu_unlock(&state->mu);
+  gpr_mu_lock(&server_state->mu);
+  server_state->is_shutdown = true;
+  server_state->server_destroy_listener_done = callback;
+  tcp = server_state->tcp;
+  gpr_mu_unlock(&server_state->mu);
   grpc_tcp_server_shutdown_listeners(exec_ctx, tcp);
-  grpc_tcp_server_unref(exec_ctx, tcp);
+  grpc_tcp_server_unref(exec_ctx, server_state->tcp);
 }
 
 int grpc_server_add_secure_http2_port(grpc_server *server, const char *addr,
                                       grpc_server_credentials *creds) {
   grpc_resolved_addresses *resolved = NULL;
   grpc_tcp_server *tcp = NULL;
-  server_secure_state *state = NULL;
+  server_secure_state *server_state = NULL;
   size_t i;
   size_t count = 0;
   int port_num = -1;
@@ -253,22 +267,22 @@
   if (err != GRPC_ERROR_NONE) {
     goto error;
   }
-  state = gpr_malloc(sizeof(*state));
-  memset(state, 0, sizeof(*state));
-  grpc_closure_init(&state->destroy_closure, destroy_done, state);
-  err = grpc_tcp_server_create(&state->destroy_closure,
+  server_state = gpr_malloc(sizeof(*server_state));
+  memset(server_state, 0, sizeof(*server_state));
+  grpc_closure_init(&server_state->tcp_server_shutdown_complete,
+                    tcp_server_shutdown_complete, server_state);
+  err = grpc_tcp_server_create(&server_state->tcp_server_shutdown_complete,
                                grpc_server_get_channel_args(server), &tcp);
   if (err != GRPC_ERROR_NONE) {
     goto error;
   }
 
-  state->server = server;
-  state->tcp = tcp;
-  state->sc = sc;
-  state->creds = grpc_server_credentials_ref(creds);
-  state->is_shutdown = false;
-  gpr_mu_init(&state->mu);
-  gpr_ref_init(&state->refcount, 1);
+  server_state->server = server;
+  server_state->tcp = tcp;
+  server_state->sc = sc;
+  server_state->creds = grpc_server_credentials_ref(creds);
+  server_state->is_shutdown = true;
+  gpr_mu_init(&server_state->mu);
 
   errors = gpr_malloc(sizeof(*errors) * resolved->naddrs);
   for (i = 0; i < resolved->naddrs; i++) {
@@ -313,7 +327,8 @@
   grpc_resolved_addresses_destroy(resolved);
 
   /* Register with the server only upon success */
-  grpc_server_add_listener(&exec_ctx, server, state, start, destroy);
+  grpc_server_add_listener(&exec_ctx, server, server_state,
+                           server_start_listener, server_destroy_listener);
 
   grpc_exec_ctx_finish(&exec_ctx);
   return port_num;
@@ -334,10 +349,11 @@
     grpc_tcp_server_unref(&exec_ctx, tcp);
   } else {
     if (sc) {
+      grpc_exec_ctx_flush(&exec_ctx);
       GRPC_SECURITY_CONNECTOR_UNREF(&sc->base, "server");
     }
-    if (state) {
-      gpr_free(state);
+    if (server_state) {
+      gpr_free(server_state);
     }
   }
   grpc_exec_ctx_finish(&exec_ctx);
diff --git a/src/core/lib/iomgr/tcp_server.h b/src/core/lib/iomgr/tcp_server.h
index 5a25d39..9a39069 100644
--- a/src/core/lib/iomgr/tcp_server.h
+++ b/src/core/lib/iomgr/tcp_server.h
@@ -101,8 +101,8 @@
 void grpc_tcp_server_shutdown_starting_add(grpc_tcp_server *s,
                                            grpc_closure *shutdown_starting);
 
-/* If the refcount drops to zero, delete s, and call (exec_ctx==NULL) or enqueue
-   a call (exec_ctx!=NULL) to shutdown_complete. */
+/* If the refcount drops to zero, shut down the listeners, enqueue the
+   shutdown_starting closures on exec_ctx, and begin destroying s. */
 void grpc_tcp_server_unref(grpc_exec_ctx *exec_ctx, grpc_tcp_server *s);
 
 /* Shutdown the fds of listeners. */
diff --git a/src/core/lib/iomgr/tcp_server_posix.c b/src/core/lib/iomgr/tcp_server_posix.c
index 2d3f6cf..73df547 100644
--- a/src/core/lib/iomgr/tcp_server_posix.c
+++ b/src/core/lib/iomgr/tcp_server_posix.c
@@ -191,6 +191,9 @@
 }
 
 static void finish_shutdown(grpc_exec_ctx *exec_ctx, grpc_tcp_server *s) {
+  gpr_mu_lock(&s->mu);
+  GPR_ASSERT(s->shutdown);
+  gpr_mu_unlock(&s->mu);
   if (s->shutdown_complete != NULL) {
     grpc_exec_ctx_sched(exec_ctx, s->shutdown_complete, GRPC_ERROR_NONE, NULL);
   }
@@ -652,6 +655,7 @@
                                        unsigned port_index) {
   unsigned num_fds = 0;
   grpc_tcp_listener *sp;
+  gpr_mu_lock(&s->mu);
   for (sp = s->head; sp && port_index != 0; sp = sp->next) {
     if (!sp->is_sibling) {
       --port_index;
@@ -659,12 +663,15 @@
   }
   for (; sp; sp = sp->sibling, ++num_fds)
     ;
+  gpr_mu_unlock(&s->mu);
   return num_fds;
 }
 
 int grpc_tcp_server_port_fd(grpc_tcp_server *s, unsigned port_index,
                             unsigned fd_index) {
   grpc_tcp_listener *sp;
+  int fd;
+  gpr_mu_lock(&s->mu);
   for (sp = s->head; sp && port_index != 0; sp = sp->next) {
     if (!sp->is_sibling) {
       --port_index;
@@ -673,10 +680,12 @@
   for (; sp && fd_index != 0; sp = sp->sibling, --fd_index)
     ;
   if (sp) {
-    return sp->fd;
+    fd = sp->fd;
   } else {
-    return -1;
+    fd = -1;
   }
+  gpr_mu_unlock(&s->mu);
+  return fd;
 }
 
 void grpc_tcp_server_start(grpc_exec_ctx *exec_ctx, grpc_tcp_server *s,
@@ -722,7 +731,7 @@
 }
 
 grpc_tcp_server *grpc_tcp_server_ref(grpc_tcp_server *s) {
-  gpr_ref(&s->refs);
+  gpr_ref_non_zero(&s->refs);
   return s;
 }
 
@@ -736,19 +745,11 @@
 
 void grpc_tcp_server_unref(grpc_exec_ctx *exec_ctx, grpc_tcp_server *s) {
   if (gpr_unref(&s->refs)) {
-    /* Complete shutdown_starting work before destroying. */
-    grpc_exec_ctx local_exec_ctx = GRPC_EXEC_CTX_INIT;
+    grpc_tcp_server_shutdown_listeners(exec_ctx, s);
     gpr_mu_lock(&s->mu);
-    grpc_exec_ctx_enqueue_list(&local_exec_ctx, &s->shutdown_starting, NULL);
+    grpc_exec_ctx_enqueue_list(exec_ctx, &s->shutdown_starting, NULL);
     gpr_mu_unlock(&s->mu);
-    if (exec_ctx == NULL) {
-      grpc_exec_ctx_flush(&local_exec_ctx);
-      tcp_server_destroy(&local_exec_ctx, s);
-      grpc_exec_ctx_finish(&local_exec_ctx);
-    } else {
-      grpc_exec_ctx_finish(&local_exec_ctx);
-      tcp_server_destroy(exec_ctx, s);
-    }
+    tcp_server_destroy(exec_ctx, s);
   }
 }
 
diff --git a/src/core/lib/iomgr/tcp_server_windows.c b/src/core/lib/iomgr/tcp_server_windows.c
index 1b125e7..4ff0560 100644
--- a/src/core/lib/iomgr/tcp_server_windows.c
+++ b/src/core/lib/iomgr/tcp_server_windows.c
@@ -139,7 +139,7 @@
 }
 
 grpc_tcp_server *grpc_tcp_server_ref(grpc_tcp_server *s) {
-  gpr_ref(&s->refs);
+  gpr_ref_non_zero(&s->refs);
   return s;
 }
 
@@ -174,19 +174,11 @@
 
 void grpc_tcp_server_unref(grpc_exec_ctx *exec_ctx, grpc_tcp_server *s) {
   if (gpr_unref(&s->refs)) {
-    /* Complete shutdown_starting work before destroying. */
-    grpc_exec_ctx local_exec_ctx = GRPC_EXEC_CTX_INIT;
+    grpc_tcp_server_shutdown_listeners(exec_ctx, s);
     gpr_mu_lock(&s->mu);
-    grpc_exec_ctx_enqueue_list(&local_exec_ctx, &s->shutdown_starting, NULL);
+    grpc_exec_ctx_enqueue_list(exec_ctx, &s->shutdown_starting, NULL);
     gpr_mu_unlock(&s->mu);
-    if (exec_ctx == NULL) {
-      grpc_exec_ctx_flush(&local_exec_ctx);
-      tcp_server_destroy(&local_exec_ctx, s);
-      grpc_exec_ctx_finish(&local_exec_ctx);
-    } else {
-      grpc_exec_ctx_finish(&local_exec_ctx);
-      tcp_server_destroy(exec_ctx, s);
-    }
+    tcp_server_destroy(exec_ctx, s);
   }
 }
 
diff --git a/test/core/iomgr/tcp_server_posix_test.c b/test/core/iomgr/tcp_server_posix_test.c
index 6e2d1d0..6b1dd42 100644
--- a/test/core/iomgr/tcp_server_posix_test.c
+++ b/test/core/iomgr/tcp_server_posix_test.c
@@ -314,11 +314,10 @@
   GPR_ASSERT(grpc_tcp_server_port_fd(s, 0, 0) >= 0);
 
   grpc_tcp_server_unref(&exec_ctx, s);
+  grpc_exec_ctx_finish(&exec_ctx);
 
   /* Weak ref lost. */
   GPR_ASSERT(weak_ref.server == NULL);
-
-  grpc_exec_ctx_finish(&exec_ctx);
 }
 
 static void destroy_pollset(grpc_exec_ctx *exec_ctx, void *p,