[3.9] bpo-41194: The _ast module cannot be loaded more than once (GH-21290) (GH-21292)

* bpo-41194: Pass module state in Python-ast.c (GH-21284)

Rework asdl_c.py to pass the module state to functions in
Python-ast.c, instead of using astmodulestate_global.

Also handle PyState_AddModule() failure in init_types().

(cherry picked from commit 74419f0c64959bb8392fcf3659058410423038e1)

* bpo-41194: The _ast module cannot be loaded more than once (GH-21290)

Fix a crash in the _ast module: it can no longer be loaded more than
once. It now uses a global state rather than a module state.

* Move _ast module state: use a global state instead.
* Set _astmodule.m_size to -1, so the extension cannot be loaded more
  than once.

(cherry picked from commit 91e1bc18bd467a13bceb62e16fbc435b33381c82)
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index ce9724a..f029ca6 100755
--- a/Parser/asdl_c.py
+++ b/Parser/asdl_c.py
@@ -387,7 +387,7 @@
 
 class Obj2ModPrototypeVisitor(PickleVisitor):
     def visitProduct(self, prod, name):
-        code = "static int obj2ast_%s(PyObject* obj, %s* out, PyArena* arena);"
+        code = "static int obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena);"
         self.emit(code % (name, get_c_type(name)), 0)
 
     visitSum = visitProduct
@@ -397,7 +397,7 @@
     def funcHeader(self, name):
         ctype = get_c_type(name)
         self.emit("int", 0)
-        self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
+        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
         self.emit("{", 0)
         self.emit("int isinstance;", 1)
         self.emit("", 0)
@@ -419,7 +419,7 @@
         self.funcHeader(name)
         for t in sum.types:
             line = ("isinstance = PyObject_IsInstance(obj, "
-                    "astmodulestate_global->%s_type);")
+                    "state->%s_type);")
             self.emit(line % (t.name,), 1)
             self.emit("if (isinstance == -1) {", 1)
             self.emit("return 1;", 2)
@@ -448,7 +448,7 @@
         for a in sum.attributes:
             self.visitField(a, name, sum=sum, depth=1)
         for t in sum.types:
-            self.emit("tp = astmodulestate_global->%s_type;" % (t.name,), 1)
+            self.emit("tp = state->%s_type;" % (t.name,), 1)
             self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1)
             self.emit("if (isinstance == -1) {", 1)
             self.emit("return 1;", 2)
@@ -479,7 +479,7 @@
     def visitProduct(self, prod, name):
         ctype = get_c_type(name)
         self.emit("int", 0)
-        self.emit("obj2ast_%s(PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
+        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
         self.emit("{", 0)
         self.emit("PyObject* tmp = NULL;", 1)
         for f in prod.fields:
@@ -525,7 +525,7 @@
 
     def visitField(self, field, name, sum=None, prod=None, depth=0):
         ctype = get_c_type(field.type)
-        line = "if (_PyObject_LookupAttr(obj, astmodulestate_global->%s, &tmp) < 0) {"
+        line = "if (_PyObject_LookupAttr(obj, state->%s, &tmp) < 0) {"
         self.emit(line % field.name, depth)
         self.emit("return 1;", depth+1)
         self.emit("}", depth)
@@ -568,7 +568,7 @@
             self.emit("%s val;" % ctype, depth+2)
             self.emit("PyObject *tmp2 = PyList_GET_ITEM(tmp, i);", depth+2)
             self.emit("Py_INCREF(tmp2);", depth+2)
-            self.emit("res = obj2ast_%s(tmp2, &val, arena);" %
+            self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" %
                       field.type, depth+2, reflow=False)
             self.emit("Py_DECREF(tmp2);", depth+2)
             self.emit("if (res != 0) goto failed;", depth+2)
@@ -582,7 +582,7 @@
             self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2)
             self.emit("}", depth+1)
         else:
-            self.emit("res = obj2ast_%s(tmp, &%s, arena);" %
+            self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" %
                       (field.type, field.name), depth+1)
             self.emit("if (res != 0) goto failed;", depth+1)
 
@@ -604,7 +604,7 @@
 
     def visitProduct(self, prod, name):
         self.emit_type("%s_type" % name)
-        self.emit("static PyObject* ast2obj_%s(void*);" % name, 0)
+        self.emit("static PyObject* ast2obj_%s(astmodulestate *state, void*);" % name, 0)
         if prod.attributes:
             for a in prod.attributes:
                 self.emit_identifier(a.name)
@@ -634,7 +634,7 @@
             ptype = get_c_type(name)
             for t in sum.types:
                 self.emit_singleton("%s_singleton" % t.name)
-        self.emit("static PyObject* ast2obj_%s(%s);" % (name, ptype), 0)
+        self.emit("static PyObject* ast2obj_%s(astmodulestate *state, %s);" % (name, ptype), 0)
         for t in sum.types:
             self.visitConstructor(t, name)
 
@@ -691,7 +691,8 @@
     Py_ssize_t i, numfields = 0;
     int res = -1;
     PyObject *key, *value, *fields;
-    if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), astmodulestate_global->_fields, &fields) < 0) {
+    astmodulestate *state = get_global_ast_state();
+    if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
         goto cleanup;
     }
     if (fields) {
@@ -759,8 +760,9 @@
 static PyObject *
 ast_type_reduce(PyObject *self, PyObject *unused)
 {
+    astmodulestate *state = get_global_ast_state();
     PyObject *dict;
-    if (_PyObject_LookupAttr(self, astmodulestate_global->__dict__, &dict) < 0) {
+    if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
         return NULL;
     }
     if (dict) {
@@ -809,7 +811,8 @@
 };
 
 static PyObject *
-make_type(const char *type, PyObject* base, const char* const* fields, int num_fields, const char *doc)
+make_type(astmodulestate *state, const char *type, PyObject* base,
+          const char* const* fields, int num_fields, const char *doc)
 {
     PyObject *fnames, *result;
     int i;
@@ -825,16 +828,16 @@
     }
     result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOs}",
                     type, base,
-                    astmodulestate_global->_fields, fnames,
-                    astmodulestate_global->__module__,
-                    astmodulestate_global->ast,
-                    astmodulestate_global->__doc__, doc);
+                    state->_fields, fnames,
+                    state->__module__,
+                    state->ast,
+                    state->__doc__, doc);
     Py_DECREF(fnames);
     return result;
 }
 
 static int
-add_attributes(PyObject *type, const char * const *attrs, int num_fields)
+add_attributes(astmodulestate *state, PyObject *type, const char * const *attrs, int num_fields)
 {
     int i, result;
     PyObject *s, *l = PyTuple_New(num_fields);
@@ -848,14 +851,14 @@
         }
         PyTuple_SET_ITEM(l, i, s);
     }
-    result = PyObject_SetAttr(type, astmodulestate_global->_attributes, l) >= 0;
+    result = PyObject_SetAttr(type, state->_attributes, l) >= 0;
     Py_DECREF(l);
     return result;
 }
 
 /* Conversion AST -> Python */
 
-static PyObject* ast2obj_list(asdl_seq *seq, PyObject* (*func)(void*))
+static PyObject* ast2obj_list(astmodulestate *state, asdl_seq *seq, PyObject* (*func)(astmodulestate *state, void*))
 {
     Py_ssize_t i, n = asdl_seq_LEN(seq);
     PyObject *result = PyList_New(n);
@@ -863,7 +866,7 @@
     if (!result)
         return NULL;
     for (i = 0; i < n; i++) {
-        value = func(asdl_seq_GET(seq, i));
+        value = func(state, asdl_seq_GET(seq, i));
         if (!value) {
             Py_DECREF(result);
             return NULL;
@@ -873,7 +876,7 @@
     return result;
 }
 
-static PyObject* ast2obj_object(void *o)
+static PyObject* ast2obj_object(astmodulestate *Py_UNUSED(state), void *o)
 {
     if (!o)
         o = Py_None;
@@ -884,14 +887,14 @@
 #define ast2obj_identifier ast2obj_object
 #define ast2obj_string ast2obj_object
 
-static PyObject* ast2obj_int(long b)
+static PyObject* ast2obj_int(astmodulestate *Py_UNUSED(state), long b)
 {
     return PyLong_FromLong(b);
 }
 
 /* Conversion Python -> AST */
 
-static int obj2ast_object(PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_object(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (obj == Py_None)
         obj = NULL;
@@ -906,7 +909,7 @@
     return 0;
 }
 
-static int obj2ast_constant(PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_constant(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (PyArena_AddPyObject(arena, obj) < 0) {
         *out = NULL;
@@ -917,25 +920,25 @@
     return 0;
 }
 
-static int obj2ast_identifier(PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_identifier(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
         PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str");
         return 1;
     }
-    return obj2ast_object(obj, out, arena);
+    return obj2ast_object(state, obj, out, arena);
 }
 
-static int obj2ast_string(PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_string(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
         PyErr_SetString(PyExc_TypeError, "AST string must be of type str");
         return 1;
     }
-    return obj2ast_object(obj, out, arena);
+    return obj2ast_object(state, obj, out, arena);
 }
 
-static int obj2ast_int(PyObject* obj, int* out, PyArena* arena)
+static int obj2ast_int(astmodulestate* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena)
 {
     int i;
     if (!PyLong_Check(obj)) {
@@ -950,13 +953,13 @@
     return 0;
 }
 
-static int add_ast_fields(void)
+static int add_ast_fields(astmodulestate *state)
 {
     PyObject *empty_tuple;
     empty_tuple = PyTuple_New(0);
     if (!empty_tuple ||
-        PyObject_SetAttrString(astmodulestate_global->AST_type, "_fields", empty_tuple) < 0 ||
-        PyObject_SetAttrString(astmodulestate_global->AST_type, "_attributes", empty_tuple) < 0) {
+        PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 ||
+        PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) {
         Py_XDECREF(empty_tuple);
         return -1;
     }
@@ -968,18 +971,12 @@
 
         self.emit("static int init_types(void)",0)
         self.emit("{", 0)
-        self.emit("PyObject *m;", 1)
-        self.emit("if (PyState_FindModule(&_astmodule) == NULL) {", 1)
-        self.emit("m = PyModule_Create(&_astmodule);", 2)
-        self.emit("if (!m) return 0;", 2)
-        self.emit("PyState_AddModule(m, &_astmodule);", 2)
-        self.emit("}", 1)
-        self.emit("astmodulestate *state = astmodulestate_global;", 1)
+        self.emit("astmodulestate *state = get_global_ast_state();", 1)
         self.emit("if (state->initialized) return 1;", 1)
-        self.emit("if (init_identifiers() < 0) return 0;", 1)
+        self.emit("if (init_identifiers(state) < 0) return 0;", 1)
         self.emit("state->AST_type = PyType_FromSpec(&AST_type_spec);", 1)
         self.emit("if (!state->AST_type) return 0;", 1)
-        self.emit("if (add_ast_fields() < 0) return 0;", 1)
+        self.emit("if (add_ast_fields(state) < 0) return 0;", 1)
         for dfn in mod.dfns:
             self.visit(dfn)
         self.emit("state->initialized = 1;", 1)
@@ -991,31 +988,31 @@
             fields = name+"_fields"
         else:
             fields = "NULL"
-        self.emit('state->%s_type = make_type("%s", state->AST_type, %s, %d,' %
+        self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' %
                         (name, name, fields, len(prod.fields)), 1)
         self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False)
         self.emit("if (!state->%s_type) return 0;" % name, 1)
         self.emit_type("AST_type")
         self.emit_type("%s_type" % name)
         if prod.attributes:
-            self.emit("if (!add_attributes(state->%s_type, %s_attributes, %d)) return 0;" %
+            self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
                             (name, name, len(prod.attributes)), 1)
         else:
-            self.emit("if (!add_attributes(state->%s_type, NULL, 0)) return 0;" % name, 1)
+            self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
         self.emit_defaults(name, prod.fields, 1)
         self.emit_defaults(name, prod.attributes, 1)
 
     def visitSum(self, sum, name):
-        self.emit('state->%s_type = make_type("%s", state->AST_type, NULL, 0,' %
+        self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' %
                   (name, name), 1)
         self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False)
         self.emit_type("%s_type" % name)
         self.emit("if (!state->%s_type) return 0;" % name, 1)
         if sum.attributes:
-            self.emit("if (!add_attributes(state->%s_type, %s_attributes, %d)) return 0;" %
+            self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
                             (name, name, len(sum.attributes)), 1)
         else:
-            self.emit("if (!add_attributes(state->%s_type, NULL, 0)) return 0;" % name, 1)
+            self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
         self.emit_defaults(name, sum.attributes, 1)
         simple = is_simple(sum)
         for t in sum.types:
@@ -1026,7 +1023,7 @@
             fields = cons.name+"_fields"
         else:
             fields = "NULL"
-        self.emit('state->%s_type = make_type("%s", state->%s_type, %s, %d,' %
+        self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' %
                             (cons.name, cons.name, name, fields, len(cons.fields)), 1)
         self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False)
         self.emit("if (!state->%s_type) return 0;" % cons.name, 1)
@@ -1052,14 +1049,20 @@
         self.emit("PyMODINIT_FUNC", 0)
         self.emit("PyInit__ast(void)", 0)
         self.emit("{", 0)
-        self.emit("PyObject *m;", 1)
-        self.emit("if (!init_types()) return NULL;", 1)
-        self.emit('m = PyState_FindModule(&_astmodule);', 1)
-        self.emit("if (!m) return NULL;", 1)
-        self.emit('if (PyModule_AddObject(m, "AST", astmodulestate_global->AST_type) < 0) {', 1)
+        self.emit("PyObject *m = PyModule_Create(&_astmodule);", 1)
+        self.emit("if (!m) {", 1)
+        self.emit("return NULL;", 2)
+        self.emit("}", 1)
+        self.emit('astmodulestate *state = get_ast_state(m);', 1)
+        self.emit('', 1)
+
+        self.emit("if (!init_types()) {", 1)
+        self.emit("goto error;", 2)
+        self.emit("}", 1)
+        self.emit('if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {', 1)
         self.emit('goto error;', 2)
         self.emit('}', 1)
-        self.emit('Py_INCREF(astmodulestate(m)->AST_type);', 1)
+        self.emit('Py_INCREF(state->AST_type);', 1)
         self.emit('if (PyModule_AddIntMacro(m, PyCF_ALLOW_TOP_LEVEL_AWAIT) < 0) {', 1)
         self.emit("goto error;", 2)
         self.emit('}', 1)
@@ -1072,6 +1075,7 @@
         for dfn in mod.dfns:
             self.visit(dfn)
         self.emit("return m;", 1)
+        self.emit("", 0)
         self.emit("error:", 0)
         self.emit("Py_DECREF(m);", 1)
         self.emit("return NULL;", 1)
@@ -1090,10 +1094,10 @@
 
     def addObj(self, name):
         self.emit("if (PyModule_AddObject(m, \"%s\", "
-                  "astmodulestate_global->%s_type) < 0) {" % (name, name), 1)
+                  "state->%s_type) < 0) {" % (name, name), 1)
         self.emit("goto error;", 2)
         self.emit('}', 1)
-        self.emit("Py_INCREF(astmodulestate(m)->%s_type);" % name, 1)
+        self.emit("Py_INCREF(state->%s_type);" % name, 1)
 
 
 _SPECIALIZED_SEQUENCES = ('stmt', 'expr')
@@ -1127,7 +1131,7 @@
     def func_begin(self, name):
         ctype = get_c_type(name)
         self.emit("PyObject*", 0)
-        self.emit("ast2obj_%s(void* _o)" % (name), 0)
+        self.emit("ast2obj_%s(astmodulestate *state, void* _o)" % (name), 0)
         self.emit("{", 0)
         self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
         self.emit("PyObject *result = NULL, *value = NULL;", 1)
@@ -1135,7 +1139,6 @@
         self.emit('if (!o) {', 1)
         self.emit("Py_RETURN_NONE;", 2)
         self.emit("}", 1)
-        self.emit('', 0)
 
     def func_end(self):
         self.emit("return result;", 1)
@@ -1157,43 +1160,43 @@
             self.visitConstructor(t, i + 1, name)
         self.emit("}", 1)
         for a in sum.attributes:
-            self.emit("value = ast2obj_%s(o->%s);" % (a.type, a.name), 1)
+            self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
             self.emit("if (!value) goto failed;", 1)
-            self.emit('if (PyObject_SetAttr(result, astmodulestate_global->%s, value) < 0)' % a.name, 1)
+            self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
             self.emit('goto failed;', 2)
             self.emit('Py_DECREF(value);', 1)
         self.func_end()
 
     def simpleSum(self, sum, name):
-        self.emit("PyObject* ast2obj_%s(%s_ty o)" % (name, name), 0)
+        self.emit("PyObject* ast2obj_%s(astmodulestate *state, %s_ty o)" % (name, name), 0)
         self.emit("{", 0)
         self.emit("switch(o) {", 1)
         for t in sum.types:
             self.emit("case %s:" % t.name, 2)
-            self.emit("Py_INCREF(astmodulestate_global->%s_singleton);" % t.name, 3)
-            self.emit("return astmodulestate_global->%s_singleton;" % t.name, 3)
+            self.emit("Py_INCREF(state->%s_singleton);" % t.name, 3)
+            self.emit("return state->%s_singleton;" % t.name, 3)
         self.emit("}", 1)
         self.emit("Py_UNREACHABLE();", 1);
         self.emit("}", 0)
 
     def visitProduct(self, prod, name):
         self.func_begin(name)
-        self.emit("tp = (PyTypeObject *)astmodulestate_global->%s_type;" % name, 1)
+        self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1)
         self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1);
         self.emit("if (!result) return NULL;", 1)
         for field in prod.fields:
             self.visitField(field, name, 1, True)
         for a in prod.attributes:
-            self.emit("value = ast2obj_%s(o->%s);" % (a.type, a.name), 1)
+            self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
             self.emit("if (!value) goto failed;", 1)
-            self.emit("if (PyObject_SetAttr(result, astmodulestate_global->%s, value) < 0)" % a.name, 1)
+            self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
             self.emit('goto failed;', 2)
             self.emit('Py_DECREF(value);', 1)
         self.func_end()
 
     def visitConstructor(self, cons, enum, name):
         self.emit("case %s_kind:" % cons.name, 1)
-        self.emit("tp = (PyTypeObject *)astmodulestate_global->%s_type;" % cons.name, 2)
+        self.emit("tp = (PyTypeObject *)state->%s_type;" % cons.name, 2)
         self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2);
         self.emit("if (!result) goto failed;", 2)
         for f in cons.fields:
@@ -1209,7 +1212,7 @@
             value = "o->v.%s.%s" % (name, field.name)
         self.set(field, value, depth)
         emit("if (!value) goto failed;", 0)
-        emit("if (PyObject_SetAttr(result, astmodulestate_global->%s, value) == -1)" % field.name, 0)
+        emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, 0)
         emit("goto failed;", 1)
         emit("Py_DECREF(value);", 0)
 
@@ -1237,14 +1240,14 @@
                 self.emit("if (!value) goto failed;", depth+1)
                 self.emit("for(i = 0; i < n; i++)", depth+1)
                 # This cannot fail, so no need for error handling
-                self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop((cmpop_ty)asdl_seq_GET(%s, i)));" % value,
+                self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop(state, (cmpop_ty)asdl_seq_GET(%s, i)));" % value,
                           depth+2, reflow=False)
                 self.emit("}", depth)
             else:
-                self.emit("value = ast2obj_list(%s, ast2obj_%s);" % (value, field.type), depth)
+                self.emit("value = ast2obj_list(state, %s, ast2obj_%s);" % (value, field.type), depth)
         else:
             ctype = get_c_type(field.type)
-            self.emit("value = ast2obj_%s(%s);" % (field.type, value), depth, reflow=False)
+            self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
 
 
 class PartingShots(StaticVisitor):
@@ -1252,15 +1255,17 @@
     CODE = """
 PyObject* PyAST_mod2obj(mod_ty t)
 {
-    if (!init_types())
+    if (!init_types()) {
         return NULL;
-    return ast2obj_mod(t);
+    }
+
+    astmodulestate *state = get_global_ast_state();
+    return ast2obj_mod(state, t);
 }
 
 /* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
 mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
 {
-    PyObject *req_type[3];
     const char * const req_name[] = {"Module", "Expression", "Interactive"};
     int isinstance;
 
@@ -1268,14 +1273,17 @@
         return NULL;
     }
 
-    req_type[0] = astmodulestate_global->Module_type;
-    req_type[1] = astmodulestate_global->Expression_type;
-    req_type[2] = astmodulestate_global->Interactive_type;
+    astmodulestate *state = get_global_ast_state();
+    PyObject *req_type[3];
+    req_type[0] = state->Module_type;
+    req_type[1] = state->Expression_type;
+    req_type[2] = state->Interactive_type;
 
     assert(0 <= mode && mode <= 2);
 
-    if (!init_types())
+    if (!init_types()) {
         return NULL;
+    }
 
     isinstance = PyObject_IsInstance(ast, req_type[mode]);
     if (isinstance == -1)
@@ -1287,7 +1295,7 @@
     }
 
     mod_ty res = NULL;
-    if (obj2ast_mod(ast, &res, arena) != 0)
+    if (obj2ast_mod(state, ast, &res, arena) != 0)
         return NULL;
     else
         return res;
@@ -1295,9 +1303,12 @@
 
 int PyAST_Check(PyObject* obj)
 {
-    if (!init_types())
+    if (!init_types()) {
         return -1;
-    return PyObject_IsInstance(obj, astmodulestate_global->AST_type);
+    }
+
+    astmodulestate *state = get_global_ast_state();
+    return PyObject_IsInstance(obj, state->AST_type);
 }
 """
 
@@ -1347,22 +1358,30 @@
         f.write('    PyObject *' + s + ';\n')
     f.write('} astmodulestate;\n\n')
     f.write("""
-#define astmodulestate(o) ((astmodulestate *)PyModule_GetState(o))
+static astmodulestate global_ast_state;
+
+static astmodulestate *
+get_ast_state(PyObject *Py_UNUSED(module))
+{
+    return &global_ast_state;
+}
 
 static int astmodule_clear(PyObject *module)
 {
+    astmodulestate *state = get_ast_state(module);
 """)
     for s in module_state:
-        f.write("    Py_CLEAR(astmodulestate(module)->" + s + ');\n')
+        f.write("    Py_CLEAR(state->" + s + ');\n')
     f.write("""
     return 0;
 }
 
 static int astmodule_traverse(PyObject *module, visitproc visit, void* arg)
 {
+    astmodulestate *state = get_ast_state(module);
 """)
     for s in module_state:
-        f.write("    Py_VISIT(astmodulestate(module)->" + s + ');\n')
+        f.write("    Py_VISIT(state->" + s + ');\n')
     f.write("""
     return 0;
 }
@@ -1373,22 +1392,18 @@
 
 static struct PyModuleDef _astmodule = {
     PyModuleDef_HEAD_INIT,
-    "_ast",
-    NULL,
-    sizeof(astmodulestate),
-    NULL,
-    NULL,
-    astmodule_traverse,
-    astmodule_clear,
-    astmodule_free,
+    .m_name = "_ast",
+    .m_size = -1,
+    .m_traverse = astmodule_traverse,
+    .m_clear = astmodule_clear,
+    .m_free = astmodule_free,
 };
 
-#define astmodulestate_global ((astmodulestate *)PyModule_GetState(PyState_FindModule(&_astmodule)))
+#define get_global_ast_state() (&global_ast_state)
 
 """)
-    f.write('static int init_identifiers(void)\n')
+    f.write('static int init_identifiers(astmodulestate *state)\n')
     f.write('{\n')
-    f.write('    astmodulestate *state = astmodulestate_global;\n')
     for identifier in state_strings:
         f.write('    if ((state->' + identifier)
         f.write(' = PyUnicode_InternFromString("')