Issue #7105: Make WeakKeyDictionary, WeakValueDictionary and WeakSet robust
against the destruction of weakly referenced objects while iterating.
diff --git a/Lib/_weakrefset.py b/Lib/_weakrefset.py
index addc7af..3de3bda 100644
--- a/Lib/_weakrefset.py
+++ b/Lib/_weakrefset.py
@@ -6,22 +6,61 @@
 
 __all__ = ['WeakSet']
 
+
+class _IterationGuard:
+    # Context manager that registers itself in the weak container's set of
+    # active iterators, so that the container defers removals until the last
+    # active iteration exits.
+    # This technique should be relatively thread-safe (since sets are).
+
+    def __init__(self, weakcontainer):
+        # Hold only a weak reference so the guard doesn't create a cycle
+        self.weakcontainer = ref(weakcontainer)
+
+    def __enter__(self):
+        w = self.weakcontainer()
+        if w is not None:
+            w._iterating.add(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        w = self.weakcontainer()
+        if w is not None:
+            s = w._iterating
+            s.remove(self)
+            if not s:
+                w._commit_removals()
+
+
 class WeakSet:
     def __init__(self, data=None):
         self.data = set()
         def _remove(item, selfref=ref(self)):
             self = selfref()
             if self is not None:
-                self.data.discard(item)
+                if self._iterating:  # an iteration is in progress: defer the removal
+                    self._pending_removals.append(item)
+                else:
+                    self.data.discard(item)
         self._remove = _remove
+        # Dead weakrefs whose removal was deferred while iterating
+        self._pending_removals = []
+        self._iterating = set()  # _IterationGuard objects of active iterations
         if data is not None:
             self.update(data)
 
+    def _commit_removals(self):  # apply removals deferred during iteration
+        pending = self._pending_removals
+        discard = self.data.discard  # bound once, outside the loop
+        while pending:  # re-checked each pass, so late appends are drained too
+            discard(pending.pop())
+
     def __iter__(self):
-        for itemref in self.data:
-            item = itemref()
-            if item is not None:
-                yield item
+        with _IterationGuard(self):  # weakref callbacks defer instead of shrinking self.data
+            for itemref in self.data:
+                item = itemref()
+                if item is not None:  # skip referents that have already died
+                    yield item
 
     def __len__(self):
         return sum(x() is not None for x in self.data)
@@ -34,15 +73,21 @@
                 getattr(self, '__dict__', None))
 
     def add(self, item):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         self.data.add(ref(item, self._remove))
 
     def clear(self):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         self.data.clear()
 
     def copy(self):
         return self.__class__(self)
 
     def pop(self):
+        if self._pending_removals:
+            self._commit_removals()
         while True:
             try:
                 itemref = self.data.pop()
@@ -53,17 +98,24 @@
                 return item
 
     def remove(self, item):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         self.data.remove(ref(item))
 
     def discard(self, item):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         self.data.discard(ref(item))
 
     def update(self, other):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()  # NOTE(review): fast path below shares other's refs, whose callback cleans other only
         if isinstance(other, self.__class__):
             self.data.update(other.data)
         else:
             for element in other:
                 self.add(element)
+
     def __ior__(self, other):
         self.update(other)
         return self
@@ -82,11 +134,15 @@
     __sub__ = difference
 
     def difference_update(self, other):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         if self is other:
             self.data.clear()
         else:
             self.data.difference_update(ref(item) for item in other)
     def __isub__(self, other):
+        if self._pending_removals:
+            self._commit_removals()
         if self is other:
             self.data.clear()
         else:
@@ -98,8 +154,12 @@
     __and__ = intersection
 
     def intersection_update(self, other):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         self.data.intersection_update(ref(item) for item in other)
     def __iand__(self, other):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         self.data.intersection_update(ref(item) for item in other)
         return self
 
@@ -127,11 +187,15 @@
     __xor__ = symmetric_difference
 
     def symmetric_difference_update(self, other):
+        if self._pending_removals:  # apply removals deferred during iteration
+            self._commit_removals()
         if self is other:
             self.data.clear()
         else:
             self.data.symmetric_difference_update(ref(item) for item in other)
     def __ixor__(self, other):
+        if self._pending_removals:
+            self._commit_removals()
         if self is other:
             self.data.clear()
         else: