bpo-30604: clean up co_extra support (#2144)

bpo-30604: port fix from 3.6, dropping the binary compatibility tweaks
diff --git a/Include/pystate.h b/Include/pystate.h
index a5bbb25..edfb08b 100644
--- a/Include/pystate.h
+++ b/Include/pystate.h
@@ -74,6 +74,10 @@
     PyObject *import_func;
     /* Initialized to PyEval_EvalFrameDefault(). */
     _PyFrameEvalFunction eval_frame;
+
+    Py_ssize_t co_extra_user_count;
+    freefunc co_extra_freefuncs[MAX_CO_EXTRA_USERS];
+
 #ifdef HAVE_FORK
     PyObject *before_forkers;
     PyObject *after_forkers_parent;
@@ -173,9 +177,6 @@
     PyObject *coroutine_wrapper;
     int in_coroutine_wrapper;
 
-    Py_ssize_t co_extra_user_count;
-    freefunc co_extra_freefuncs[MAX_CO_EXTRA_USERS];
-
     PyObject *async_gen_firstiter;
     PyObject *async_gen_finalizer;
 
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
index 7975ea0..891f5e6 100644
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -103,9 +103,11 @@
 """
 
 import sys
+import threading
 import unittest
 import weakref
-from test.support import run_doctest, run_unittest, cpython_only
+from test.support import (run_doctest, run_unittest, cpython_only,
+                          check_impl_detail)
 
 
 def consts(t):
@@ -212,11 +214,106 @@
         self.assertTrue(self.called)
 
 
+if check_impl_detail(cpython=True):
+    import ctypes
+    py = ctypes.pythonapi
+    freefunc = ctypes.CFUNCTYPE(None,ctypes.c_voidp)
+
+    RequestCodeExtraIndex = py._PyEval_RequestCodeExtraIndex
+    RequestCodeExtraIndex.argtypes = (freefunc,)
+    RequestCodeExtraIndex.restype = ctypes.c_ssize_t
+
+    SetExtra = py._PyCode_SetExtra
+    SetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t, ctypes.c_voidp)
+    SetExtra.restype = ctypes.c_int
+
+    GetExtra = py._PyCode_GetExtra
+    GetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t,
+                         ctypes.POINTER(ctypes.c_voidp))
+    GetExtra.restype = ctypes.c_int
+
+    LAST_FREED = None
+    def myfree(ptr):
+        global LAST_FREED
+        LAST_FREED = ptr
+
+    FREE_FUNC = freefunc(myfree)
+    FREE_INDEX = RequestCodeExtraIndex(FREE_FUNC)
+
+    class CoExtra(unittest.TestCase):
+        def get_func(self):
+            # Defining a function causes the containing function to have a
+            # reference to the code object.  We need the code objects to go
+            # away, so we eval a lambda.
+            return eval('lambda:42')
+
+        def test_get_non_code(self):
+            f = self.get_func()
+
+            self.assertRaises(SystemError, SetExtra, 42, FREE_INDEX,
+                              ctypes.c_voidp(100))
+            self.assertRaises(SystemError, GetExtra, 42, FREE_INDEX,
+                              ctypes.c_voidp(100))
+
+        def test_bad_index(self):
+            f = self.get_func()
+            self.assertRaises(SystemError, SetExtra, f.__code__,
+                              FREE_INDEX+100, ctypes.c_voidp(100))
+            self.assertEqual(GetExtra(f.__code__, FREE_INDEX+100,
+                              ctypes.c_voidp(100)), 0)
+
+        def test_free_called(self):
+            # Verify that the provided free function gets invoked
+            # when the code object is cleaned up.
+            f = self.get_func()
+
+            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(100))
+            del f
+            self.assertEqual(LAST_FREED, 100)
+
+        def test_get_set(self):
+            # Test basic get/set round tripping.
+            f = self.get_func()
+
+            extra = ctypes.c_voidp()
+
+            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(200))
+            # Setting the same slot again should free the previous value.
+            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(300))
+            self.assertEqual(LAST_FREED, 200)
+
+            extra = ctypes.c_voidp()
+            GetExtra(f.__code__, FREE_INDEX, extra)
+            self.assertEqual(extra.value, 300)
+            del f
+
+        def test_free_different_thread(self):
+            # Freeing a code object on a different thread than
+            # where the co_extra was set should be safe.
+            f = self.get_func()
+            class ThreadTest(threading.Thread):
+                def __init__(self, f, test):
+                    super().__init__()
+                    self.f = f
+                    self.test = test
+                def run(self):
+                    del self.f
+                    self.test.assertEqual(LAST_FREED, 500)
+
+            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(500))
+            tt = ThreadTest(f, self)
+            del f
+            tt.start()
+            tt.join()
+            self.assertEqual(LAST_FREED, 500)
+
 def test_main(verbose=None):
     from test import test_code
     run_doctest(test_code, verbose)
-    run_unittest(CodeTest, CodeConstsTest, CodeWeakRefTest)
-
+    tests = [CodeTest, CodeConstsTest, CodeWeakRefTest]
+    if check_impl_detail(cpython=True):
+        tests.append(CoExtra)
+    run_unittest(*tests)
 
 if __name__ == "__main__":
     test_main()
diff --git a/Misc/NEWS b/Misc/NEWS
index 299d9c2..d25d191 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -10,6 +10,9 @@
 Core and Builtins
 -----------------
 
+- bpo-30604: Move co_extra_freefuncs from per-thread to per-interpreter state
+  to avoid crashes when a code object is freed on a different thread.
+
 - bpo-30597: ``print`` now shows expected input in custom error message when
   used as a Python 2 statement. Patch by Sanyam Khurana.
 
diff --git a/Objects/codeobject.c b/Objects/codeobject.c
index 46bc45d..eb860c8 100644
--- a/Objects/codeobject.c
+++ b/Objects/codeobject.c
@@ -416,11 +416,11 @@
 code_dealloc(PyCodeObject *co)
 {
     if (co->co_extra != NULL) {
-        PyThreadState *tstate = PyThreadState_Get();
+        PyInterpreterState *interp = PyThreadState_Get()->interp;
         _PyCodeObjectExtra *co_extra = co->co_extra;
 
         for (Py_ssize_t i = 0; i < co_extra->ce_size; i++) {
-            freefunc free_extra = tstate->co_extra_freefuncs[i];
+            freefunc free_extra = interp->co_extra_freefuncs[i];
 
             if (free_extra != NULL) {
                 free_extra(co_extra->ce_extras[i]);
@@ -830,8 +830,6 @@
 int
 _PyCode_GetExtra(PyObject *code, Py_ssize_t index, void **extra)
 {
-    assert(*extra == NULL);
-
     if (!PyCode_Check(code)) {
         PyErr_BadInternalCall();
         return -1;
@@ -840,8 +838,8 @@
     PyCodeObject *o = (PyCodeObject*) code;
     _PyCodeObjectExtra *co_extra = (_PyCodeObjectExtra*) o->co_extra;
 
-
     if (co_extra == NULL || co_extra->ce_size <= index) {
+        *extra = NULL;
         return 0;
     }
 
@@ -853,10 +851,10 @@
 int
 _PyCode_SetExtra(PyObject *code, Py_ssize_t index, void *extra)
 {
-    PyThreadState *tstate = PyThreadState_Get();
+    PyInterpreterState *interp = PyThreadState_Get()->interp;
 
     if (!PyCode_Check(code) || index < 0 ||
-            index >= tstate->co_extra_user_count) {
+            index >= interp->co_extra_user_count) {
         PyErr_BadInternalCall();
         return -1;
     }
@@ -871,13 +869,13 @@
         }
 
         co_extra->ce_extras = PyMem_Malloc(
-            tstate->co_extra_user_count * sizeof(void*));
+            interp->co_extra_user_count * sizeof(void*));
         if (co_extra->ce_extras == NULL) {
             PyMem_Free(co_extra);
             return -1;
         }
 
-        co_extra->ce_size = tstate->co_extra_user_count;
+        co_extra->ce_size = interp->co_extra_user_count;
 
         for (Py_ssize_t i = 0; i < co_extra->ce_size; i++) {
             co_extra->ce_extras[i] = NULL;
@@ -887,20 +885,28 @@
     }
     else if (co_extra->ce_size <= index) {
         void** ce_extras = PyMem_Realloc(
-            co_extra->ce_extras, tstate->co_extra_user_count * sizeof(void*));
+            co_extra->ce_extras, interp->co_extra_user_count * sizeof(void*));
 
         if (ce_extras == NULL) {
             return -1;
         }
 
         for (Py_ssize_t i = co_extra->ce_size;
-             i < tstate->co_extra_user_count;
+             i < interp->co_extra_user_count;
              i++) {
             ce_extras[i] = NULL;
         }
 
         co_extra->ce_extras = ce_extras;
-        co_extra->ce_size = tstate->co_extra_user_count;
+        co_extra->ce_size = interp->co_extra_user_count;
+    }
+
+    if (co_extra->ce_extras[index] != NULL) {
+        freefunc free = interp->co_extra_freefuncs[index];
+
+        if (free != NULL) {
+            free(co_extra->ce_extras[index]);
+        }
     }
 
     co_extra->ce_extras[index] = extra;
diff --git a/Python/ceval.c b/Python/ceval.c
index 6140815..4e6911a 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -5287,14 +5287,14 @@
 Py_ssize_t
 _PyEval_RequestCodeExtraIndex(freefunc free)
 {
-    PyThreadState *tstate = PyThreadState_Get();
+    PyInterpreterState *interp = PyThreadState_Get()->interp;
     Py_ssize_t new_index;
 
-    if (tstate->co_extra_user_count == MAX_CO_EXTRA_USERS - 1) {
+    if (interp->co_extra_user_count == MAX_CO_EXTRA_USERS - 1) {
         return -1;
     }
-    new_index = tstate->co_extra_user_count++;
-    tstate->co_extra_freefuncs[new_index] = free;
+    new_index = interp->co_extra_user_count++;
+    interp->co_extra_freefuncs[new_index] = free;
     return new_index;
 }
 
diff --git a/Python/pystate.c b/Python/pystate.c
index 0e62ee9..24a08eb 100644
--- a/Python/pystate.c
+++ b/Python/pystate.c
@@ -111,6 +111,7 @@
         interp->importlib = NULL;
         interp->import_func = NULL;
         interp->eval_frame = _PyEval_EvalFrameDefault;
+        interp->co_extra_user_count = 0;
 #ifdef HAVE_DLOPEN
 #if HAVE_DECL_RTLD_NOW
         interp->dlopenflags = RTLD_NOW;
@@ -281,7 +282,6 @@
 
         tstate->coroutine_wrapper = NULL;
         tstate->in_coroutine_wrapper = 0;
-        tstate->co_extra_user_count = 0;
 
         tstate->async_gen_firstiter = NULL;
         tstate->async_gen_finalizer = NULL;