Significantly refactor the polling island locking and refcounting code
diff --git a/src/core/lib/iomgr/ev_epoll_linux.c b/src/core/lib/iomgr/ev_epoll_linux.c
index ed2c494..7228888 100644
--- a/src/core/lib/iomgr/ev_epoll_linux.c
+++ b/src/core/lib/iomgr/ev_epoll_linux.c
@@ -140,18 +140,40 @@
 #define CLOSURE_READY ((grpc_closure *)1)
 
 /*******************************************************************************
- * Polling-island Declarations
+ * Polling island Declarations
  */
-/* TODO: sree: Consider making ref_cnt and merged_to to gpr_atm - This would
- * significantly reduce the number of mutex acquisition calls. */
+
+// #define GRPC_PI_REF_COUNT_DEBUG
+#ifdef GRPC_PI_REF_COUNT_DEBUG
+
+#define PI_ADD_REF(p, r) pi_add_ref_dbg((p), 1, (r), __FILE__, __LINE__)
+#define PI_UNREF(p, r) pi_unref_dbg((p), 1, (r), __FILE__, __LINE__)
+
+#else /* defined(GRPC_PI_REF_COUNT_DEBUG) */
+
+#define PI_ADD_REF(p, r) pi_add_ref((p), 1)
+#define PI_UNREF(p, r) pi_unref((p), 1)
+
+#endif /* !defined(GRPC_PI_REF_COUNT_DEBUG) */
+
 typedef struct polling_island {
   gpr_mu mu;
-  int ref_cnt;
+  /* Ref count. Use PI_ADD_REF() and PI_UNREF() macros to increment/decrement
+     the refcount.
+     Once the ref count becomes zero, this structure is destroyed which means
+     we should ensure that there is never a scenario where a PI_ADD_REF() is
+     racing with a PI_UNREF() that just made the ref_count zero. */
+  gpr_atm ref_count;
 
-  /* Points to the polling_island this merged into.
-   * If merged_to is not NULL, all the remaining fields (except mu and ref_cnt)
-   * are invalid and must be ignored */
-  struct polling_island *merged_to;
+  /* Pointer to the polling_island this merged into.
+   * merged_to value is only set once in polling_island's lifetime (and that too
+   * only if the island is merged with another island). Because of this, we can
+   * use gpr_atm type here so that we can do atomic access on this and reduce
+   * lock contention on 'mu' mutex.
+   *
+   * Note that if this field is not NULL (i.e not 0), all the remaining fields
+   * (except mu and ref_count) are invalid and must be ignored. */
+  gpr_atm merged_to;
 
   /* The fd of the underlying epoll set */
   int epoll_fd;
@@ -236,6 +258,8 @@
 static gpr_mu g_pi_freelist_mu;
 static polling_island *g_pi_freelist = NULL;
 
+static void polling_island_delete(polling_island *pi); /* Forward declaration */
+
 #ifdef GRPC_TSAN
 /* Currently TSAN may incorrectly flag data races between epoll_ctl and
    epoll_wait for any grpc_fd structs that are added to the epoll set via
@@ -247,6 +271,51 @@
 gpr_atm g_epoll_sync;
 #endif /* defined(GRPC_TSAN) */
 
+#ifdef GRPC_PI_REF_COUNT_DEBUG
+long pi_add_ref(polling_island *pi, int ref_cnt);
+long pi_unref(polling_island *pi, int ref_cnt);
+/* Debug wrapper over pi_add_ref(): also logs the old/new count and caller */
+void pi_add_ref_dbg(polling_island *pi, int ref_cnt, const char *reason,
+                    const char *file, int line) {
+  long old_cnt = pi_add_ref(pi, ref_cnt);
+  gpr_log(GPR_DEBUG, "Add ref pi: %p, old:%ld -> new:%ld (%s) - (%s, %d)",
+          (void *)pi, old_cnt, (old_cnt + ref_cnt), reason, file, line);
+}
+/* Debug wrapper over pi_unref(): also logs the old/new count and caller */
+void pi_unref_dbg(polling_island *pi, int ref_cnt, const char *reason,
+                  const char *file, int line) {
+  long old_cnt = pi_unref(pi, ref_cnt);
+  gpr_log(GPR_DEBUG, "Unref pi: %p, old:%ld -> new:%ld (%s) - (%s, %d)",
+          (void *)pi, old_cnt, (old_cnt - ref_cnt), reason, file, line);
+}
+#endif
+
+/* Adds 'ref_cnt' refs to pi. Returns the ref count before the addition */
+long pi_add_ref(polling_island *pi, int ref_cnt) {
+  return gpr_atm_no_barrier_fetch_add(&pi->ref_count, ref_cnt);
+}
+
+/* Drops 'ref_cnt' refs from pi. Returns the ref count before the drop */
+long pi_unref(polling_island *pi, int ref_cnt) {
+  /* Full barrier: the thread that drops the last ref (and deletes the island)
+     must observe every write made by the threads that dropped their refs
+     earlier. A no-barrier decrement does not guarantee that */
+  long old_cnt = gpr_atm_full_fetch_add(&pi->ref_count, -ref_cnt);
+
+  /* If the ref count went to zero, delete the polling island. This need not be
+     done under a lock: once the count is zero no one else holds a reference
+     (and there is no racing pi_add_ref() call either). If merged_to is set,
+     also drop the ref this island held on the island it was merged into */
+  if (old_cnt == ref_cnt) {
+    polling_island *next = (polling_island *)gpr_atm_acq_load(&pi->merged_to);
+    polling_island_delete(pi);
+    if (next != NULL) {
+      PI_UNREF(next, "pi_delete"); /* Recursive call */
+    }
+  }
+  return old_cnt;
+}
+
 /* The caller is expected to hold pi->mu lock before calling this function */
 static void polling_island_add_fds_locked(polling_island *pi, grpc_fd **fds,
                                           size_t fd_count, bool add_fd_refs) {
@@ -355,8 +424,7 @@
   }
 }
 
-static polling_island *polling_island_create(grpc_fd *initial_fd,
-                                             int initial_ref_cnt) {
+static polling_island *polling_island_create(grpc_fd *initial_fd) {
   polling_island *pi = NULL;
 
   /* Try to get one from the polling island freelist */
@@ -377,6 +445,9 @@
     pi->fds = NULL;
   }
 
+  gpr_atm_no_barrier_store(&pi->ref_count, 0);
+  gpr_atm_no_barrier_store(&pi->merged_to, NULL);
+
   pi->epoll_fd = epoll_create1(EPOLL_CLOEXEC);
 
   if (pi->epoll_fd < 0) {
@@ -387,14 +458,12 @@
 
   polling_island_add_wakeup_fd_locked(pi, &grpc_global_wakeup_fd);
 
-  pi->ref_cnt = initial_ref_cnt;
-  pi->merged_to = NULL;
   pi->next_free = NULL;
 
   if (initial_fd != NULL) {
-    /* It is not really needed to get the pi->mu lock here. If this is a newly
-       created polling island (or one that we got from the freelist), no one
-       else would be holding a lock to it anyway */
+    /* Lock the polling island here just in case we got this structure from the
+       freelist and the polling island lock was not released yet (by the code
+       that adds the polling island to the freelist) */
     gpr_mu_lock(&pi->mu);
     polling_island_add_fds_locked(pi, &initial_fd, 1, true);
     gpr_mu_unlock(&pi->mu);
@@ -404,140 +473,136 @@
 }
 
 static void polling_island_delete(polling_island *pi) {
-  GPR_ASSERT(pi->ref_cnt == 0);
   GPR_ASSERT(pi->fd_cnt == 0);
 
+  gpr_atm_rel_store(&pi->merged_to, NULL);
+
   close(pi->epoll_fd);
   pi->epoll_fd = -1;
 
-  pi->merged_to = NULL;
-
   gpr_mu_lock(&g_pi_freelist_mu);
   pi->next_free = g_pi_freelist;
   g_pi_freelist = pi;
   gpr_mu_unlock(&g_pi_freelist_mu);
 }
 
-void polling_island_unref_and_unlock(polling_island *pi, int unref_by) {
-  pi->ref_cnt -= unref_by;
-  int ref_cnt = pi->ref_cnt;
-  GPR_ASSERT(ref_cnt >= 0);
-
-  gpr_mu_unlock(&pi->mu);
-
-  if (ref_cnt == 0) {
-    polling_island_delete(pi);
-  }
-}
-
-polling_island *polling_island_update_and_lock(polling_island *pi, int unref_by,
-                                               int add_ref_by) {
+/* Gets the lock on the *latest* polling island i.e the last polling island in
+   the linked list (linked by 'merged_to' link). Call gpr_mu_unlock on the
+   returned polling island's mu.
+   Usage: To lock/unlock polling island "pi", do the following:
+      polling_island *pi_latest = polling_island_lock(pi);
+      ...
+      ... critical section ..
+      ...
+      gpr_mu_unlock(&pi_latest->mu); //NOTE: use pi_latest->mu. NOT pi->mu */
+polling_island *polling_island_lock(polling_island *pi) {
   polling_island *next = NULL;
-  gpr_mu_lock(&pi->mu);
-  while (pi->merged_to != NULL) {
-    next = pi->merged_to;
-    polling_island_unref_and_unlock(pi, unref_by);
+  while (true) {
+    next = (polling_island *)gpr_atm_acq_load(&pi->merged_to);
+    if (next == NULL) {
+      /* pi is the last node in the linked list. Get the lock and check again
+         (under the pi->mu lock) that pi is still the last node (because a merge
+         may have happened after the (next == NULL) check above and before
+         getting the pi->mu lock).
+         If pi is the last node, we are done. If not, unlock and continue
+         traversing the list */
+      gpr_mu_lock(&pi->mu);
+      next = (polling_island *)gpr_atm_acq_load(&pi->merged_to);
+      if (next == NULL) {
+        break;
+      }
+      gpr_mu_unlock(&pi->mu);
+    }
+
     pi = next;
-    gpr_mu_lock(&pi->mu);
   }
 
-  pi->ref_cnt += add_ref_by;
   return pi;
 }
 
-void polling_island_pair_update_and_lock(polling_island **p,
-                                         polling_island **q) {
+/* Gets the lock on the *latest* polling islands pointed by *p and *q.
+   This function is needed because calling the following block of code to obtain
+   locks on polling islands (*p and *q) is prone to deadlocks.
+     {
+       polling_island_lock(*p);
+       polling_island_lock(*q);
+     }
+
+   Usage/example:
+     polling_island *p1;
+     polling_island *p2;
+     ..
+     polling_island_lock_pair(&p1, &p2);
+     ..
+     .. Critical section with both p1 and p2 locked
+     ..
+     // Release locks
+     // **IMPORTANT**: Make sure you check p1 == p2 AFTER the function
+     // polling_island_lock_pair() was called and if so, release the lock only
+     // once. Note: Even if p1 != p2 before calling polling_island_lock_pair(),
+     // they might be equal after the function returns:
+     if (p1 == p2) {
+       gpr_mu_unlock(&p1->mu);
+     } else {
+       gpr_mu_unlock(&p1->mu);
+       gpr_mu_unlock(&p2->mu);
+     }
+
+*/
+void polling_island_lock_pair(polling_island **p, polling_island **q) {
   polling_island *pi_1 = *p;
   polling_island *pi_2 = *q;
-  polling_island *temp = NULL;
-  bool pi_1_locked = false;
-  bool pi_2_locked = false;
-  int num_swaps = 0;
+  polling_island *next_1 = NULL;
+  polling_island *next_2 = NULL;
 
-  /* Loop until either pi_1 == pi_2 or until we acquired locks on both pi_1
-     and pi_2 */
-  while (pi_1 != pi_2 && !(pi_1_locked && pi_2_locked)) {
-    /* The following assertions are true at this point:
-       - pi_1 != pi_2  (else, the while loop would have exited)
-       - pi_1 MAY be locked
-       - pi_2 is NOT locked */
-
-    /* To maintain lock order consistency, always lock polling_island node with
-       lower address first.
-       First, make sure pi_1 < pi_2 before proceeding any further. If it turns
-       out that pi_1 > pi_2, unlock pi_1 if locked (because pi_2 is not locked
-       at this point and having pi_1 locked would violate the lock order) and
-       swap pi_1 and pi_2 so that pi_1 becomes less than pi_2 */
-    if (pi_1 > pi_2) {
-      if (pi_1_locked) {
-        gpr_mu_unlock(&pi_1->mu);
-        pi_1_locked = false;
-      }
-
-      GPR_SWAP(polling_island *, pi_1, pi_2);
-      num_swaps++;
+  /* The algorithm is simple:
+      - Go to the last polling islands in the linked lists *pi_1 and *pi_2 (and
+        keep updating pi_1 and pi_2)
+      - Then obtain locks on the islands by following a lock order rule of
+        locking polling_island with lower address first
+           Special case: Before obtaining the locks, check if pi_1 and pi_2 are
+           pointing to the same island. If that is the case, we can just call
+           polling_island_lock()
+      - After obtaining both the locks, double check that the polling islands
+        are still the last polling islands in their respective linked lists
+        (this is because there might have been polling island merges before
+        we got the lock)
+      - If the polling islands are the last islands, we are done. If not,
+        release the locks and continue the process from the first step */
+  while (true) {
+    next_1 = (polling_island *)gpr_atm_acq_load(&pi_1->merged_to);
+    while (next_1 != NULL) {
+      pi_1 = next_1;
+      next_1 = (polling_island *)gpr_atm_acq_load(&pi_1->merged_to);
     }
 
-    /* The following assertions are true at this point:
-       - pi_1 != pi_2
-       - pi_1 < pi_2  (address of pi_1 is less than that of pi_2)
-       - pi_1 MAYBE locked
-       - pi_2 is NOT locked */
+    next_2 = (polling_island *)gpr_atm_acq_load(&pi_2->merged_to);
+    while (next_2 != NULL) {
+      pi_2 = next_2;
+      next_2 = (polling_island *)gpr_atm_acq_load(&pi_2->merged_to);
+    }
 
-    /* Lock pi_1 (if pi_1 is pointing to the terminal node in the list) */
-    if (!pi_1_locked) {
+    if (pi_1 == pi_2) {
+      pi_1 = pi_2 = polling_island_lock(pi_1);
+      break;
+    }
+
+    if (pi_1 < pi_2) {
       gpr_mu_lock(&pi_1->mu);
-      pi_1_locked = true;
-
-      /* If pi_1 is not terminal node (i.e pi_1->merged_to != NULL), we are not
-         done locking this polling_island yet. Release the lock on this node and
-         advance pi_1 to the next node in the list; and go to the beginning of
-         the loop (we can't proceed to locking pi_2 unless we locked pi_1 first)
-         */
-      if (pi_1->merged_to != NULL) {
-        temp = pi_1->merged_to;
-        polling_island_unref_and_unlock(pi_1, 1);
-        pi_1 = temp;
-        pi_1_locked = false;
-
-        continue;
-      }
+      gpr_mu_lock(&pi_2->mu);
+    } else {
+      gpr_mu_lock(&pi_2->mu);
+      gpr_mu_lock(&pi_1->mu);
     }
 
-    /* The following assertions are true at this point:
-       - pi_1 is locked
-       - pi_2 is unlocked
-       - pi_1 != pi_2 */
-
-    gpr_mu_lock(&pi_2->mu);
-    pi_2_locked = true;
-
-    /* If pi_2 is not terminal node, we are not done locking this polling_island
-       yet. Release the lock and update pi_2 to the next node in the list */
-    if (pi_2->merged_to != NULL) {
-      temp = pi_2->merged_to;
-      polling_island_unref_and_unlock(pi_2, 1);
-      pi_2 = temp;
-      pi_2_locked = false;
+    next_1 = (polling_island *)gpr_atm_acq_load(&pi_1->merged_to);
+    next_2 = (polling_island *)gpr_atm_acq_load(&pi_2->merged_to);
+    if (next_1 == NULL && next_2 == NULL) {
+      break;
     }
-  }
 
-  /* At this point, either pi_1 == pi_2 AND/OR we got both locks */
-  if (pi_1 == pi_2) {
-    /* We may or may not have gotten the lock. If we didn't, walk the rest of
-      the polling_island list and get the lock */
-    GPR_ASSERT(pi_1_locked || (!pi_1_locked && !pi_2_locked));
-    if (!pi_1_locked) {
-      pi_1 = pi_2 = polling_island_update_and_lock(pi_1, 2, 0);
-    }
-  } else {
-    GPR_ASSERT(pi_1_locked && pi_2_locked);
-    /* If we swapped pi_1 and pi_2 odd number of times, do one more swap so that
-       pi_1 and pi_2 point to the same polling_island lists they started off
-       with at the beginning of this function (i.e *p and *q respectively) */
-    if (num_swaps % 2 > 0) {
-      GPR_SWAP(polling_island *, pi_1, pi_2);
-    }
+    gpr_mu_unlock(&pi_1->mu);
+    gpr_mu_unlock(&pi_2->mu);
   }
 
   *p = pi_1;
@@ -546,7 +611,7 @@
 
 polling_island *polling_island_merge(polling_island *p, polling_island *q) {
   /* Get locks on both the polling islands */
-  polling_island_pair_update_and_lock(&p, &q);
+  polling_island_lock_pair(&p, &q);
 
   if (p == q) {
     /* Nothing needs to be done here */
@@ -568,15 +633,14 @@
   /* Wakeup all the pollers (if any) on p so that they can pickup this change */
   polling_island_add_wakeup_fd_locked(p, &polling_island_wakeup_fd);
 
-  p->merged_to = q;
+  /* Add the 'merged_to' link from p --> q */
+  gpr_atm_rel_store(&p->merged_to, q);
+  PI_ADD_REF(q, "pi_merge"); /* To account for the new incoming ref from p */
 
-  /* - The merged polling island (i.e q) inherits all the ref counts of the
-       island merging with it (i.e p)
-     - The island p will lose a ref count */
-  q->ref_cnt += p->ref_cnt;
-  polling_island_unref_and_unlock(p, 1); /* Decrement refcount */
-  polling_island_unref_and_unlock(q, 0); /* Just Unlock. Don't decrement ref */
+  gpr_mu_unlock(&p->mu);
+  gpr_mu_unlock(&q->mu);
 
+  /* Return the merged polling island */
   return q;
 }
 
@@ -667,6 +731,7 @@
     fd->freelist_next = fd_freelist;
     fd_freelist = fd;
     grpc_iomgr_unregister_object(&fd->iomgr_object);
+
     gpr_mu_unlock(&fd_freelist_mu);
   } else {
     GPR_ASSERT(old > n);
@@ -785,16 +850,20 @@
   REF_BY(fd, 1, reason);
 
   /* Remove the fd from the polling island:
-     - Update the fd->polling_island to point to the latest polling island
-     - Remove the fd from the polling island.
-     - Remove a ref to the polling island and set fd->polling_island to NULL */
+     - Get a lock on the latest polling island (i.e the last island in the
+       linked list pointed by fd->polling_island). This is the island that
+       would actually contain the fd
+     - Remove the fd from the latest polling island
+     - Unlock the latest polling island
+     - Set fd->polling_island to NULL (but remove the ref on the polling island
+       before doing this.) */
   gpr_mu_lock(&fd->pi_mu);
   if (fd->polling_island != NULL) {
-    fd->polling_island =
-        polling_island_update_and_lock(fd->polling_island, 1, 0);
-    polling_island_remove_fd_locked(fd->polling_island, fd, is_fd_closed);
+    polling_island *pi_latest = polling_island_lock(fd->polling_island);
+    polling_island_remove_fd_locked(pi_latest, fd, is_fd_closed);
+    gpr_mu_unlock(&pi_latest->mu);
 
-    polling_island_unref_and_unlock(fd->polling_island, 1);
+    PI_UNREF(fd->polling_island, "fd_orphan");
     fd->polling_island = NULL;
   }
   gpr_mu_unlock(&fd->pi_mu);
@@ -1050,17 +1119,13 @@
   gpr_mu_unlock(&fd->mu);
 }
 
-/* Release the reference to pollset->polling_island and set it to NULL.
-   pollset->mu must be held */
-static void pollset_release_polling_island_locked(grpc_pollset *pollset) {
-  gpr_mu_lock(&pollset->pi_mu);
-  if (pollset->polling_island) {
-    pollset->polling_island =
-        polling_island_update_and_lock(pollset->polling_island, 1, 0);
-    polling_island_unref_and_unlock(pollset->polling_island, 1);
-    pollset->polling_island = NULL;
+static void pollset_release_polling_island(grpc_pollset *ps, char *reason) {
+  gpr_mu_lock(&ps->pi_mu);
+  if (ps->polling_island != NULL) {
+    PI_UNREF(ps->polling_island, reason);
   }
-  gpr_mu_unlock(&pollset->pi_mu);
+  ps->polling_island = NULL;
+  gpr_mu_unlock(&ps->pi_mu);
 }
 
 static void finish_shutdown_locked(grpc_exec_ctx *exec_ctx,
@@ -1069,8 +1134,9 @@
   GPR_ASSERT(!pollset_has_workers(pollset));
 
   pollset->finish_shutdown_called = true;
-  pollset_release_polling_island_locked(pollset);
 
+  /* Release the ref and set pollset->polling_island to NULL */
+  pollset_release_polling_island(pollset, "ps_shutdown");
   grpc_exec_ctx_enqueue(exec_ctx, pollset->shutdown_done, true, NULL);
 }
 
@@ -1110,7 +1176,7 @@
   pollset->finish_shutdown_called = false;
   pollset->kicked_without_pollers = false;
   pollset->shutdown_done = NULL;
-  pollset_release_polling_island_locked(pollset);
+  pollset_release_polling_island(pollset, "ps_reset");
 }
 
 #define GRPC_EPOLL_MAX_EVENTS 1000
@@ -1124,28 +1190,37 @@
   GPR_TIMER_BEGIN("pollset_work_and_unlock", 0);
 
   /* We need to get the epoll_fd to wait on. The epoll_fd is in inside the
-     polling island pointed by pollset->polling_island.
+     latest polling island pointed by pollset->polling_island.
      Acquire the following locks:
      - pollset->mu (which we already have)
      - pollset->pi_mu
-     - pollset->polling_island->mu (call polling_island_update_and_lock())*/
+     - pollset->polling_island lock  */
   gpr_mu_lock(&pollset->pi_mu);
 
-  pi = pollset->polling_island;
-  if (pi == NULL) {
-    pi = polling_island_create(NULL, 1);
+  if (pollset->polling_island == NULL) {
+    pollset->polling_island = polling_island_create(NULL);
+    PI_ADD_REF(pollset->polling_island, "ps");
   }
 
-  /* In addition to locking the polling island, add a ref so that the island
-     does not get destroyed (which means the epoll_fd won't be closed) while
-     we are are doing an epoll_wait() on the epoll_fd */
-  pi = polling_island_update_and_lock(pi, 1, 1);
+  pi = polling_island_lock(pollset->polling_island);
   epoll_fd = pi->epoll_fd;
 
-  /* Update the pollset->polling_island */
-  pollset->polling_island = pi;
+  /* Update the pollset->polling_island since the island being pointed by
+     pollset->polling_island may not be the latest (i.e pi) */
+  if (pollset->polling_island != pi) {
+    /* Always do PI_ADD_REF before PI_UNREF because PI_UNREF may cause the
+       polling island to be deleted */
+    PI_ADD_REF(pi, "ps");
+    PI_UNREF(pollset->polling_island, "ps");
+    pollset->polling_island = pi;
+  }
 
-  polling_island_unref_and_unlock(pollset->polling_island, 0); /* Keep the ref*/
+  /* Add an extra ref so that the island does not get destroyed (which means
+     the epoll_fd won't be closed) while we are doing an epoll_wait() on the
+     epoll_fd */
+  PI_ADD_REF(pi, "ps_work");
+
+  gpr_mu_unlock(&pi->mu);
   gpr_mu_unlock(&pollset->pi_mu);
   gpr_mu_unlock(&pollset->mu);
 
@@ -1193,14 +1268,12 @@
 
   GPR_ASSERT(pi != NULL);
 
-  /* Before leaving, release the extra ref we added to the polling island */
-  /* It is important to note that at this point 'pi' may not be the same as
-   * pollset->polling_island. This is because pollset->polling_island pointer
-   * gets updated whenever the underlying polling island is merged with another
-   * island and while we are doing epoll_wait() above, the polling island may
-   * have been merged */
-  pi = polling_island_update_and_lock(pi, 1, 0); /* No new ref added */
-  polling_island_unref_and_unlock(pi, 1);
+  /* Before leaving, release the extra ref we added to the polling island. It
+     is important to use "pi" here (i.e our old copy of pollset->polling_island
+     that we got before releasing the polling island lock). This is because
+     pollset->polling_island pointer might get updated in other parts of the
+     code when there is an island merge while we are doing epoll_wait() above */
+  PI_UNREF(pi, "ps_work");
 
   GPR_TIMER_END("pollset_work_and_unlock", 0);
 }
@@ -1297,20 +1370,34 @@
   if (fd->polling_island == pollset->polling_island) {
     pi_new = fd->polling_island;
     if (pi_new == NULL) {
-      pi_new = polling_island_create(fd, 2);
+      pi_new = polling_island_create(fd);
     }
   } else if (fd->polling_island == NULL) {
-    pi_new = polling_island_update_and_lock(pollset->polling_island, 1, 1);
-    polling_island_add_fds_locked(pollset->polling_island, &fd, 1, true);
+    pi_new = polling_island_lock(pollset->polling_island);
+    polling_island_add_fds_locked(pi_new, &fd, 1, true);
     gpr_mu_unlock(&pi_new->mu);
   } else if (pollset->polling_island == NULL) {
-    pi_new = polling_island_update_and_lock(fd->polling_island, 1, 1);
+    pi_new = polling_island_lock(fd->polling_island);
     gpr_mu_unlock(&pi_new->mu);
   } else {
     pi_new = polling_island_merge(fd->polling_island, pollset->polling_island);
   }
 
-  fd->polling_island = pollset->polling_island = pi_new;
+  if (fd->polling_island != pi_new) {
+    PI_ADD_REF(pi_new, "fd");
+    if (fd->polling_island != NULL) {
+      PI_UNREF(fd->polling_island, "fd");
+    }
+    fd->polling_island = pi_new;
+  }
+
+  if (pollset->polling_island != pi_new) {
+    PI_ADD_REF(pi_new, "ps");
+    if (pollset->polling_island != NULL) {
+      PI_UNREF(pollset->polling_island, "ps");
+    }
+    pollset->polling_island = pi_new;
+  }
 
   gpr_mu_unlock(&fd->pi_mu);
   gpr_mu_unlock(&pollset->pi_mu);
@@ -1481,28 +1568,19 @@
   return pi;
 }
 
-static polling_island *get_polling_island(polling_island *p) {
-  if (p == NULL) {
-    return NULL;
-  }
-
-  polling_island *next;
-  gpr_mu_lock(&p->mu);
-  while (p->merged_to != NULL) {
-    next = p->merged_to;
-    gpr_mu_unlock(&p->mu);
-    p = next;
-    gpr_mu_lock(&p->mu);
-  }
-  gpr_mu_unlock(&p->mu);
-
-  return p;
-}
-
 bool grpc_are_polling_islands_equal(void *p, void *q) {
-  p = get_polling_island(p);
-  q = get_polling_island(q);
-  return p == q;
+  /* True iff p and q lead to the same latest island. NULL is not supported */
+  polling_island *p1 = p;
+  polling_island *p2 = q;
+  polling_island_lock_pair(&p1, &p2);
+  if (p1 == p2) {
+    gpr_mu_unlock(&p1->mu);
+  } else {
+    gpr_mu_unlock(&p1->mu);
+    gpr_mu_unlock(&p2->mu);
+  }
+  /* lock_pair updated p1/p2 to the latest islands; compare those */
+  return p1 == p2;
 }
 
 /*******************************************************************************