bpo-42609: Check recursion depth in the AST validator and optimizer (GH-23744)
diff --git a/Python/ast_opt.c b/Python/ast_opt.c
index dea20da..6eb514e 100644
--- a/Python/ast_opt.c
+++ b/Python/ast_opt.c
@@ -2,6 +2,7 @@
#include "Python.h"
#include "pycore_ast.h" // _PyAST_GetDocString()
#include "pycore_compile.h" // _PyASTOptimizeState
+#include "pycore_pystate.h" // _PyThreadState_GET()
static int
@@ -488,6 +489,11 @@ astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
static int
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
+ if (++state->recursion_depth > state->recursion_limit) {
+ PyErr_SetString(PyExc_RecursionError,
+ "maximum recursion depth exceeded during compilation");
+ return 0;
+ }
switch (node_->kind) {
case BoolOp_kind:
CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
@@ -586,6 +592,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
case Name_kind:
if (node_->v.Name.ctx == Load &&
_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
+ state->recursion_depth--;
return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
}
break;
@@ -602,6 +609,7 @@ astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new expression
// kinds are added without being handled here
}
+ state->recursion_depth--;
return 1;
}
@@ -648,6 +656,11 @@ astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
static int
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
{
+ if (++state->recursion_depth > state->recursion_limit) {
+ PyErr_SetString(PyExc_RecursionError,
+ "maximum recursion depth exceeded during compilation");
+ return 0;
+ }
switch (node_->kind) {
case FunctionDef_kind:
CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
@@ -757,6 +770,7 @@ astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
// No default case, so the compiler will emit a warning if new statement
// kinds are added without being handled here
}
+ state->recursion_depth--;
return 1;
}
@@ -906,10 +920,38 @@ astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat
#undef CALL_SEQ
#undef CALL_INT_SEQ
+/* Factor by which the thread's recursion depth and Py_GetRecursionLimit() are multiplied below to form the AST optimizer's own depth counter and limit (scaling is skipped when the product would overflow int). See comments in symtable.c. */
+#define COMPILER_STACK_FRAME_SCALE 3
+
int
_PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
{
+ PyThreadState *tstate;
+ int recursion_limit = Py_GetRecursionLimit();
+ int starting_recursion_depth;
+
+ /* Setup recursion depth check counters: astfold_expr() and astfold_stmt() increment state->recursion_depth on entry, decrement it on successful return, and raise RecursionError once it exceeds state->recursion_limit. */
+ tstate = _PyThreadState_GET();
+ if (!tstate) { /* no thread state: give up; note that no exception is set on this path */
+ return 0;
+ }
+ /* Be careful here to prevent overflow: only scale by COMPILER_STACK_FRAME_SCALE when the product fits in an int; otherwise keep the unscaled value. */
+ starting_recursion_depth = (tstate->recursion_depth < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+ tstate->recursion_depth * COMPILER_STACK_FRAME_SCALE : tstate->recursion_depth;
+ state->recursion_depth = starting_recursion_depth;
+ state->recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
+ recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
+
int ret = astfold_mod(mod, arena, state);
assert(ret || PyErr_Occurred());
+
+ /* Check that the recursion depth counting balanced correctly: on success every increment must have been undone. Skipped when ret == 0, because error paths (e.g. the RecursionError bail-out in astfold_expr/astfold_stmt) return without decrementing. */
+ if (ret && state->recursion_depth != starting_recursion_depth) {
+ PyErr_Format(PyExc_SystemError,
+ "AST optimizer recursion depth mismatch (before=%d, after=%d)",
+ starting_recursion_depth, state->recursion_depth);
+ return 0;
+ }
+
return ret;
}