bpo-42609: Check recursion depth in the AST validator and optimizer (GH-23744)

diff --git a/Python/ast_opt.c b/Python/ast_opt.c
index dea20da..6eb514e 100644
--- a/Python/ast_opt.c
+++ b/Python/ast_opt.c
@@ -2,6 +2,7 @@
 #include "Python.h"
 #include "pycore_ast.h"           // _PyAST_GetDocString()
 #include "pycore_compile.h"       // _PyASTOptimizeState
+#include "pycore_pystate.h"       // _PyThreadState_GET()
 
 
 static int
@@ -488,6 +489,11 @@ astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
 static int
 astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
 {
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+                        "maximum recursion depth exceeded during compilation");
+        return 0;
+    }
     switch (node_->kind) {
     case BoolOp_kind:
         CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
@@ -586,6 +592,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     case Name_kind:
         if (node_->v.Name.ctx == Load &&
                 _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
+            state->recursion_depth--;
             return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
         }
         break;
@@ -602,6 +609,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     // No default case, so the compiler will emit a warning if new expression
     // kinds are added without being handled here
     }
+    state->recursion_depth--;
     return 1;
 }
 
@@ -648,6 +656,11 @@ astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
 static int
 astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
 {
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+                        "maximum recursion depth exceeded during compilation");
+        return 0;
+    }
     switch (node_->kind) {
     case FunctionDef_kind:
         CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
@@ -757,6 +770,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
     // No default case, so the compiler will emit a warning if new statement
     // kinds are added without being handled here
     }
+    state->recursion_depth--;
     return 1;
 }
 
@@ -906,10 +920,38 @@ astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat
 #undef CALL_SEQ
 #undef CALL_INT_SEQ
 
+/* See comments in symtable.c. */
+#define COMPILER_STACK_FRAME_SCALE 3
+
 int
 _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
 {
+    PyThreadState *tstate;
+    int recursion_limit = Py_GetRecursionLimit();
+    int starting_recursion_depth;
+
+    /* Seed state's depth/limit counters from the thread state and the recursion limit. */
+    tstate = _PyThreadState_GET();
+    if (!tstate) {
+        return 0;  /* NOTE(review): no exception is set on this path -- confirm callers cope */
+    }
+    /* Scale by COMPILER_STACK_FRAME_SCALE only if the product fits in int; else keep raw value. */
+    starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+        tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth;
+    state->recursion_depth = starting_recursion_depth;
+    state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+        recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
+
     int ret = astfold_mod(mod, arena, state);
     assert(ret || PyErr_Occurred());
+
+    /* On success the depth must be back at its starting value; otherwise report SystemError. */
+    if (ret && state->recursion_depth != starting_recursion_depth) {
+        PyErr_Format(PyExc_SystemError,
+            "AST optimizer recursion depth mismatch (before=%d, after=%d)",
+            starting_recursion_depth, state->recursion_depth);
+        return 0;
+    }
+
     return ret;
 }