bpo-41796: Make _ast module state per interpreter (GH-23024)

The internal state of the ast module is now per-interpreter.

* Rename "astmodulestate" to "struct ast_state"
* Add the pycore_ast.h internal header, which now declares the
  ast_state structure.
* Add a PyInterpreterState.ast member of type struct ast_state
* Remove the old get_ast_state(module) function
* Rename get_global_ast_state() to get_ast_state()
* PyAST_obj2mod() now returns NULL if get_ast_state() fails
* asdl_c.py now requires the -C, -H and the new -I (--internal-h-file)
  options
diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py
index 481261c..9a833e8 100755
--- a/Parser/asdl_c.py
+++ b/Parser/asdl_c.py
@@ -3,6 +3,7 @@
 
 import os
 import sys
+import textwrap
 
 from argparse import ArgumentParser
 from pathlib import Path
@@ -11,7 +12,7 @@
 
 TABSIZE = 4
 MAX_COL = 80
-AUTOGEN_MESSAGE = "/* File automatically generated by {}. */\n\n"
+AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"
 
 def get_c_type(name):
     """Return a string for the C name of the type.
@@ -414,7 +415,7 @@ def visitField(self, sum):
 
 class Obj2ModPrototypeVisitor(PickleVisitor):
     def visitProduct(self, prod, name):
-        code = "static int obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena);"
+        code = "static int obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena);"
         self.emit(code % (name, get_c_type(name)), 0)
 
     visitSum = visitProduct
@@ -424,7 +425,7 @@ class Obj2ModVisitor(PickleVisitor):
     def funcHeader(self, name):
         ctype = get_c_type(name)
         self.emit("int", 0)
-        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
+        self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
         self.emit("{", 0)
         self.emit("int isinstance;", 1)
         self.emit("", 0)
@@ -506,7 +507,7 @@ def visitSum(self, sum, name):
     def visitProduct(self, prod, name):
         ctype = get_c_type(name)
         self.emit("int", 0)
-        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
+        self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
         self.emit("{", 0)
         self.emit("PyObject* tmp = NULL;", 1)
         for f in prod.fields:
@@ -640,7 +641,7 @@ class PyTypesDeclareVisitor(PickleVisitor):
 
     def visitProduct(self, prod, name):
         self.emit_type("%s_type" % name)
-        self.emit("static PyObject* ast2obj_%s(astmodulestate *state, void*);" % name, 0)
+        self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
         if prod.attributes:
             for a in prod.attributes:
                 self.emit_identifier(a.name)
@@ -670,7 +671,7 @@ def visitSum(self, sum, name):
             ptype = get_c_type(name)
             for t in sum.types:
                 self.emit_singleton("%s_singleton" % t.name)
-        self.emit("static PyObject* ast2obj_%s(astmodulestate *state, %s);" % (name, ptype), 0)
+        self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
         for t in sum.types:
             self.visitConstructor(t, name)
 
@@ -725,7 +726,7 @@ def visitModule(self, mod):
 static int
 ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
 {
-    astmodulestate *state = get_global_ast_state();
+    struct ast_state *state = get_ast_state();
     if (state == NULL) {
         return -1;
     }
@@ -801,7 +802,7 @@ def visitModule(self, mod):
 static PyObject *
 ast_type_reduce(PyObject *self, PyObject *unused)
 {
-    astmodulestate *state = get_global_ast_state();
+    struct ast_state *state = get_ast_state();
     if (state == NULL) {
         return NULL;
     }
@@ -856,7 +857,7 @@ def visitModule(self, mod):
 };
 
 static PyObject *
-make_type(astmodulestate *state, const char *type, PyObject* base,
+make_type(struct ast_state *state, const char *type, PyObject* base,
           const char* const* fields, int num_fields, const char *doc)
 {
     PyObject *fnames, *result;
@@ -882,7 +883,7 @@ def visitModule(self, mod):
 }
 
 static int
-add_attributes(astmodulestate *state, PyObject *type, const char * const *attrs, int num_fields)
+add_attributes(struct ast_state *state, PyObject *type, const char * const *attrs, int num_fields)
 {
     int i, result;
     PyObject *s, *l = PyTuple_New(num_fields);
@@ -903,7 +904,7 @@ def visitModule(self, mod):
 
 /* Conversion AST -> Python */
 
-static PyObject* ast2obj_list(astmodulestate *state, asdl_seq *seq, PyObject* (*func)(astmodulestate *state, void*))
+static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
 {
     Py_ssize_t i, n = asdl_seq_LEN(seq);
     PyObject *result = PyList_New(n);
@@ -921,7 +922,7 @@ def visitModule(self, mod):
     return result;
 }
 
-static PyObject* ast2obj_object(astmodulestate *Py_UNUSED(state), void *o)
+static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
 {
     if (!o)
         o = Py_None;
@@ -932,14 +933,14 @@ def visitModule(self, mod):
 #define ast2obj_identifier ast2obj_object
 #define ast2obj_string ast2obj_object
 
-static PyObject* ast2obj_int(astmodulestate *Py_UNUSED(state), long b)
+static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
 {
     return PyLong_FromLong(b);
 }
 
 /* Conversion Python -> AST */
 
-static int obj2ast_object(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (obj == Py_None)
         obj = NULL;
@@ -954,7 +955,7 @@ def visitModule(self, mod):
     return 0;
 }
 
-static int obj2ast_constant(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (PyArena_AddPyObject(arena, obj) < 0) {
         *out = NULL;
@@ -965,7 +966,7 @@ def visitModule(self, mod):
     return 0;
 }
 
-static int obj2ast_identifier(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
         PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str");
@@ -974,7 +975,7 @@ def visitModule(self, mod):
     return obj2ast_object(state, obj, out, arena);
 }
 
-static int obj2ast_string(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena)
+static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
 {
     if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
         PyErr_SetString(PyExc_TypeError, "AST string must be of type str");
@@ -983,7 +984,7 @@ def visitModule(self, mod):
     return obj2ast_object(state, obj, out, arena);
 }
 
-static int obj2ast_int(astmodulestate* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena)
+static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena)
 {
     int i;
     if (!PyLong_Check(obj)) {
@@ -998,7 +999,7 @@ def visitModule(self, mod):
     return 0;
 }
 
-static int add_ast_fields(astmodulestate *state)
+static int add_ast_fields(struct ast_state *state)
 {
     PyObject *empty_tuple;
     empty_tuple = PyTuple_New(0);
@@ -1014,7 +1015,7 @@ def visitModule(self, mod):
 
 """, 0, reflow=False)
 
-        self.emit("static int init_types(astmodulestate *state)",0)
+        self.emit("static int init_types(struct ast_state *state)",0)
         self.emit("{", 0)
         self.emit("if (state->initialized) return 1;", 1)
         self.emit("if (init_identifiers(state) < 0) return 0;", 1)
@@ -1093,12 +1094,10 @@ def visitModule(self, mod):
         self.emit("static int", 0)
         self.emit("astmodule_exec(PyObject *m)", 0)
         self.emit("{", 0)
-        self.emit('astmodulestate *state = get_ast_state(m);', 1)
-        self.emit("", 0)
-
-        self.emit("if (!init_types(state)) {", 1)
-        self.emit("return -1;", 2)
-        self.emit("}", 1)
+        self.emit('struct ast_state *state = get_ast_state();', 1)
+        self.emit('if (state == NULL) {', 1)
+        self.emit('return -1;', 2)
+        self.emit('}', 1)
         self.emit('if (PyModule_AddObject(m, "AST", state->AST_type) < 0) {', 1)
         self.emit('return -1;', 2)
         self.emit('}', 1)
@@ -1126,7 +1125,7 @@ def visitModule(self, mod):
 static struct PyModuleDef _astmodule = {
     PyModuleDef_HEAD_INIT,
     .m_name = "_ast",
-    // The _ast module uses a global state (global_ast_state).
+    // The _ast module uses a per-interpreter state (PyInterpreterState.ast)
     .m_size = 0,
     .m_slots = astmodule_slots,
 };
@@ -1169,7 +1168,7 @@ class ObjVisitor(PickleVisitor):
     def func_begin(self, name):
         ctype = get_c_type(name)
         self.emit("PyObject*", 0)
-        self.emit("ast2obj_%s(astmodulestate *state, void* _o)" % (name), 0)
+        self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
         self.emit("{", 0)
         self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
         self.emit("PyObject *result = NULL, *value = NULL;", 1)
@@ -1206,7 +1205,7 @@ def visitSum(self, sum, name):
         self.func_end()
 
     def simpleSum(self, sum, name):
-        self.emit("PyObject* ast2obj_%s(astmodulestate *state, %s_ty o)" % (name, name), 0)
+        self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
         self.emit("{", 0)
         self.emit("switch(o) {", 1)
         for t in sum.types:
@@ -1280,7 +1279,7 @@ class PartingShots(StaticVisitor):
     CODE = """
 PyObject* PyAST_mod2obj(mod_ty t)
 {
-    astmodulestate *state = get_global_ast_state();
+    struct ast_state *state = get_ast_state();
     if (state == NULL) {
         return NULL;
     }
@@ -1297,7 +1296,11 @@ class PartingShots(StaticVisitor):
         return NULL;
     }
 
-    astmodulestate *state = get_global_ast_state();
+    struct ast_state *state = get_ast_state();
+    if (state == NULL) {
+        return NULL;
+    }
+
     PyObject *req_type[3];
     req_type[0] = state->Module_type;
     req_type[1] = state->Expression_type;
@@ -1323,7 +1326,7 @@ class PartingShots(StaticVisitor):
 
 int PyAST_Check(PyObject* obj)
 {
-    astmodulestate *state = get_global_ast_state();
+    struct ast_state *state = get_ast_state();
     if (state == NULL) {
         return -1;
     }
@@ -1341,7 +1344,35 @@ def visit(self, object):
             v.emit("", 0)
 
 
-def generate_module_def(f, mod):
+def generate_ast_state(module_state, f):
+    f.write('struct ast_state {\n')
+    f.write('    int initialized;\n')
+    for s in module_state:
+        f.write('    PyObject *' + s + ';\n')
+    f.write('};')
+
+
+def generate_ast_fini(module_state, f):
+    f.write("""
+void _PyAST_Fini(PyThreadState *tstate)
+{
+#ifdef Py_BUILD_CORE
+    struct ast_state *state = &tstate->interp->ast;
+#else
+    struct ast_state *state = &global_ast_state;
+#endif
+
+""")
+    for s in module_state:
+        f.write("    Py_CLEAR(state->" + s + ');\n')
+    f.write("""
+    state->initialized = 0;
+}
+
+""")
+
+
+def generate_module_def(mod, f, internal_h):
     # Gather all the data needed for ModuleSpec
     visitor_list = set()
     with open(os.devnull, "w") as devnull:
@@ -1371,50 +1402,64 @@ def generate_module_def(f, mod):
             module_state.add(tp)
     state_strings = sorted(state_strings)
     module_state = sorted(module_state)
-    f.write('typedef struct {\n')
-    f.write('    int initialized;\n')
-    for s in module_state:
-        f.write('    PyObject *' + s + ';\n')
-    f.write('} astmodulestate;\n\n')
+
+    generate_ast_state(module_state, internal_h)
+
+    print(textwrap.dedent(f"""
+        #ifdef Py_BUILD_CORE
+        #  include "pycore_ast.h"           // struct ast_state
+        #  include "pycore_interp.h"        // _PyInterpreterState.ast
+        #  include "pycore_pystate.h"       // _PyInterpreterState_GET()
+        #else
+    """).strip(), file=f)
+
+    generate_ast_state(module_state, f)
+
+    print(textwrap.dedent(f"""
+        #endif   // Py_BUILD_CORE
+    """).rstrip(), file=f)
+
     f.write("""
 // Forward declaration
-static int init_types(astmodulestate *state);
+static int init_types(struct ast_state *state);
 
-// bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
-static astmodulestate global_ast_state = {0};
-
-static astmodulestate*
-get_global_ast_state(void)
+#ifdef Py_BUILD_CORE
+static struct ast_state*
+get_ast_state(void)
 {
-    astmodulestate* state = &global_ast_state;
+    PyInterpreterState *interp = _PyInterpreterState_GET();
+    struct ast_state *state = &interp->ast;
     if (!init_types(state)) {
         return NULL;
     }
     return state;
 }
+#else
+static struct ast_state global_ast_state;
 
-static astmodulestate*
-get_ast_state(PyObject* Py_UNUSED(module))
+static struct ast_state*
+get_ast_state(void)
 {
-    astmodulestate* state = get_global_ast_state();
-    // get_ast_state() must only be called after _ast module is imported,
-    // and astmodule_exec() calls init_types()
-    assert(state != NULL);
+    struct ast_state *state = &global_ast_state;
+    if (!init_types(state)) {
+        return NULL;
+    }
     return state;
 }
-
-void _PyAST_Fini(PyThreadState *tstate)
-{
-    astmodulestate* state = &global_ast_state;
+#endif   // Py_BUILD_CORE
 """)
-    for s in module_state:
-        f.write("    Py_CLEAR(state->" + s + ');\n')
-    f.write("""
-    state->initialized = 0;
-}
 
+    # f-string for {mod.name}
+    f.write(f"""
+// Include {mod.name}-ast.h after pycore_interp.h to avoid conflicts
+// with the Yield macro redefined by <winbase.h>
+#include "{mod.name}-ast.h"
+#include "structmember.h"
 """)
-    f.write('static int init_identifiers(astmodulestate *state)\n')
+
+    generate_ast_fini(module_state, f)
+
+    f.write('static int init_identifiers(struct ast_state *state)\n')
     f.write('{\n')
     for identifier in state_strings:
         f.write('    if ((state->' + identifier)
@@ -1423,7 +1468,7 @@ def generate_module_def(f, mod):
     f.write('    return 1;\n')
     f.write('};\n\n')
 
-def write_header(f, mod):
+def write_header(mod, f):
     f.write('#ifndef Py_PYTHON_AST_H\n')
     f.write('#define Py_PYTHON_AST_H\n')
     f.write('#ifdef __cplusplus\n')
@@ -1452,15 +1497,39 @@ def write_header(f, mod):
     f.write('#endif\n')
     f.write('#endif /* !Py_PYTHON_AST_H */\n')
 
-def write_source(f, mod):
-    f.write('#include <stddef.h>\n')
-    f.write('\n')
-    f.write('#include "Python.h"\n')
-    f.write('#include "%s-ast.h"\n' % mod.name)
-    f.write('#include "structmember.h"         // PyMemberDef\n')
-    f.write('\n')
 
-    generate_module_def(f, mod)
+def write_internal_h_header(mod, f):
+    print(textwrap.dedent("""
+        #ifndef Py_INTERNAL_AST_H
+        #define Py_INTERNAL_AST_H
+        #ifdef __cplusplus
+        extern "C" {
+        #endif
+
+        #ifndef Py_BUILD_CORE
+        #  error "this header requires Py_BUILD_CORE define"
+        #endif
+    """).lstrip(), file=f)
+
+
+def write_internal_h_footer(mod, f):
+    print(textwrap.dedent("""
+
+        #ifdef __cplusplus
+        }
+        #endif
+        #endif /* !Py_INTERNAL_AST_H */
+    """), file=f)
+
+
+def write_source(mod, f, internal_h_file):
+    print(textwrap.dedent(f"""
+        #include <stddef.h>
+
+        #include "Python.h"
+    """), file=f)
+
+    generate_module_def(mod, f, internal_h_file)
 
     v = ChainOfVisitors(
         SequenceConstructorVisitor(f),
@@ -1475,27 +1544,37 @@ def write_source(f, mod):
     )
     v.visit(mod)
 
-def main(input_file, c_file, h_file, dump_module=False):
+def main(input_filename, c_filename, h_filename, internal_h_filename, dump_module=False):
     auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
-    mod = asdl.parse(input_file)
+    mod = asdl.parse(input_filename)
     if dump_module:
         print('Parsed Module:')
         print(mod)
     if not asdl.check(mod):
         sys.exit(1)
-    for file, writer in (c_file, write_source), (h_file, write_header):
-        if file is not None:
-            with file.open("w") as f:
-                f.write(auto_gen_msg)
-                writer(f, mod)
-            print(file, "regenerated.")
+
+    with c_filename.open("w") as c_file, \
+         h_filename.open("w") as h_file, \
+         internal_h_filename.open("w") as internal_h_file:
+        c_file.write(auto_gen_msg)
+        h_file.write(auto_gen_msg)
+        internal_h_file.write(auto_gen_msg)
+
+        write_internal_h_header(mod, internal_h_file)
+        write_source(mod, c_file, internal_h_file)
+        write_header(mod, h_file)
+        write_internal_h_footer(mod, internal_h_file)
+
+    print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.")
 
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("input_file", type=Path)
-    parser.add_argument("-C", "--c-file", type=Path, default=None)
-    parser.add_argument("-H", "--h-file", type=Path, default=None)
+    parser.add_argument("-C", "--c-file", type=Path, required=True)
+    parser.add_argument("-H", "--h-file", type=Path, required=True)
+    parser.add_argument("-I", "--internal-h-file", type=Path, required=True)
     parser.add_argument("-d", "--dump-module", action="store_true")
 
-    options = parser.parse_args()
-    main(**vars(options))
+    args = parser.parse_args()
+    main(args.input_file, args.c_file, args.h_file,
+         args.internal_h_file, args.dump_module)