bpo-43892: Make match patterns explicit in the AST (GH-25585)

Co-authored-by: Brandt Bucher <brandtbucher@gmail.com>
diff --git a/Python/ast.c b/Python/ast.c
index 2b96543..1fc83f6 100644
--- a/Python/ast.c
+++ b/Python/ast.c
@@ -7,6 +7,7 @@
 #include "pycore_pystate.h"       // _PyThreadState_GET()
 
 #include <assert.h>
+#include <stdbool.h>
 
 struct validator {
     int recursion_depth;            /* current recursion depth */
@@ -18,6 +19,7 @@ static int validate_exprs(struct validator *, asdl_expr_seq*, expr_context_ty, i
 static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
 static int validate_stmt(struct validator *, stmt_ty);
 static int validate_expr(struct validator *, expr_ty, expr_context_ty);
+static int validate_pattern(struct validator *, pattern_ty);
 
 static int
 validate_name(PyObject *name)
@@ -88,9 +90,9 @@ expr_context_name(expr_context_ty ctx)
         return "Store";
     case Del:
         return "Del";
-    default:
-        Py_UNREACHABLE();
+    // No default case so compiler emits warning for unhandled cases
     }
+    Py_UNREACHABLE();
 }
 
 static int
@@ -180,7 +182,7 @@ validate_constant(struct validator *state, PyObject *value)
 static int
 validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
 {
-    int ret;
+    int ret = -1;
     if (++state->recursion_depth > state->recursion_limit) {
         PyErr_SetString(PyExc_RecursionError,
                         "maximum recursion depth exceeded during compilation");
@@ -351,34 +353,216 @@ validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
     case NamedExpr_kind:
         ret = validate_expr(state, exp->v.NamedExpr.value, Load);
         break;
-    case MatchAs_kind:
-        PyErr_SetString(PyExc_ValueError,
-                        "MatchAs is only valid in match_case patterns");
-        return 0;
-    case MatchOr_kind:
-        PyErr_SetString(PyExc_ValueError,
-                        "MatchOr is only valid in match_case patterns");
-        return 0;
     /* This last case doesn't have any checking. */
     case Name_kind:
         ret = 1;
         break;
-    default:
+    // No default case so compiler emits warning for unhandled cases
+    }
+    if (ret < 0) {
         PyErr_SetString(PyExc_SystemError, "unexpected expression");
-        return 0;
+        ret = 0;
     }
     state->recursion_depth--;
     return ret;
 }
 
+
+// Note: the ensure_literal_* functions only check a restricted set of
+//       non-recursive literals that validate_expr has already accepted, so
+//       they take no validator state; they return 1 or 0 and set no exception
 static int
-validate_pattern(expr_ty p)
+ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary)
 {
-    // Coming soon (thanks Batuhan)!
+    assert(exp->kind == Constant_kind);  // callers check the node kind first
+    PyObject *value = exp->v.Constant.value;
+    return (allow_real && PyFloat_CheckExact(value)) ||  // exact float
+           (allow_real && PyLong_CheckExact(value)) ||  // exact int, no subclasses
+           (allow_imaginary && PyComplex_CheckExact(value));  // exact complex
+}
+
+static int
+ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary)
+{
+    assert(exp->kind == UnaryOp_kind);  // callers check the node kind first
+    // Accept only "-<number>": the operator must be unary minus (USub) ...
+    if (exp->v.UnaryOp.op != USub) {
+        return 0;
+    }
+    // ... applied directly to a Constant node (so e.g. -(-1) is rejected) ...
+    expr_ty operand = exp->v.UnaryOp.operand;
+    if (operand->kind != Constant_kind) {
+        return 0;
+    }
+    // ... whose value is a number of one of the kinds the caller allows
+    return ensure_literal_number(operand, allow_real, allow_imaginary);
+}
+
+static int
+ensure_literal_complex(expr_ty exp)
+{
+    assert(exp->kind == BinOp_kind);  // caller checks the node kind first
+    expr_ty left = exp->v.BinOp.left;
+    expr_ty right = exp->v.BinOp.right;
+    // Accept only "<real> + <complex>" or "<real> - <complex>": op must be Add or Sub
+    if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) {
+        return 0;
+    }
+    // LHS must be an int/float constant, optionally negated with unary minus
+    switch (left->kind)
+    {
+        case Constant_kind:
+            if (!ensure_literal_number(left, /*allow_real=*/true, /*allow_imaginary=*/false)) {
+                return 0;
+            }
+            break;
+        case UnaryOp_kind:
+            if (!ensure_literal_negative(left, /*allow_real=*/true, /*allow_imaginary=*/false)) {
+                return 0;
+            }
+            break;
+        default:
+            return 0;
+    }
+    // RHS must be a bare complex constant (no separate sign allowed)
+    switch (right->kind)
+    {
+        case Constant_kind:
+            if (!ensure_literal_number(right, /*allow_real=*/false, /*allow_imaginary=*/true)) {
+                return 0;
+            }
+            break;
+        default:
+            return 0;
+    }
     return 1;
 }
 
 static int
+validate_pattern_match_value(struct validator *state, expr_ty exp)
+{
+    if (!validate_expr(state, exp, Load)) {  // generic expression checks come first
+        return 0;
+    }
+
+    switch (exp->kind)  // then restrict which expression kinds a value pattern may use
+    {
+        case Constant_kind:
+        case Attribute_kind:
+            // Constants and attribute lookups are always permitted
+            return 1;
+        case UnaryOp_kind:
+            // Negated numbers are permitted (whether real or imaginary)
+            // Compiler will complain if AST folding doesn't create a constant
+            if (ensure_literal_negative(exp, /*allow_real=*/true, /*allow_imaginary=*/true)) {
+                return 1;
+            }
+            break;
+        case BinOp_kind:
+            // Complex literals (<real> +/- <complex constant>) are permitted
+            // Compiler will complain if AST folding doesn't create a constant
+            if (ensure_literal_complex(exp)) {
+                return 1;
+            }
+            break;
+        default:
+            break;  // every other expression kind is rejected below
+    }
+    PyErr_SetString(PyExc_SyntaxError,
+        "patterns may only match literals and attribute lookups");
+    return 0;
+}
+
+static int
+validate_pattern(struct validator *state, pattern_ty p)
+{
+    int ret = -1;  // stays -1 only if no case below matched; reported as SystemError
+    if (++state->recursion_depth > state->recursion_limit) {
+        PyErr_SetString(PyExc_RecursionError,
+                        "maximum recursion depth exceeded during compilation");
+        return 0;
+    }
+    // Coming soon: https://bugs.python.org/issue43897 (thanks Batuhan)!
+    // TODO: Ensure no subnodes use "_" as an ordinary identifier
+    switch (p->kind) {  // each case sets ret: 1 = valid, 0 = invalid
+        case MatchValue_kind:
+            ret = validate_pattern_match_value(state, p->v.MatchValue.value);
+            break;
+        case MatchSingleton_kind:
+            // TODO: Check constant is specifically None, True, or False
+            ret = validate_constant(state, p->v.MatchSingleton.value);
+            break;
+        case MatchSequence_kind:
+            // TODO: Validate all subpatterns
+            // ret = validate_patterns(state, p->v.MatchSequence.patterns);
+            ret = 1;  // subpatterns are accepted unchecked for now
+            break;
+        case MatchMapping_kind:
+            // TODO: check "rest" target name is valid
+            if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
+                PyErr_SetString(PyExc_ValueError,
+                                "MatchMapping doesn't have the same number of keys as patterns");
+                return 0;  // NOTE(review): skips recursion_depth-- below; confirm that is fine on error
+            }
+            // null_ok=0 for key expressions, as rest-of-mapping is captured in "rest"
+            // TODO: replace with more restrictive expression validator, as per MatchValue above
+            if (!validate_exprs(state, p->v.MatchMapping.keys, Load, /*null_ok=*/ 0)) {
+                return 0;
+            }
+            // TODO: Validate all subpatterns
+            // ret = validate_patterns(state, p->v.MatchMapping.patterns);
+            ret = 1;  // subpatterns are accepted unchecked for now
+            break;
+        case MatchClass_kind:
+            if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
+                PyErr_SetString(PyExc_ValueError,
+                                "MatchClass doesn't have the same number of keyword attributes as patterns");
+                return 0;
+            }
+            // TODO: Restrict cls lookup to being a name or attribute
+            if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
+                return 0;
+            }
+            // TODO: Validate all subpatterns
+            // ret = validate_patterns(state, p->v.MatchClass.patterns) &&
+            //       validate_patterns(state, p->v.MatchClass.kwd_patterns);
+            ret = 1;  // positional and keyword subpatterns are accepted unchecked for now
+            break;
+        case MatchStar_kind:
+            // TODO: check target name is valid
+            ret = 1;
+            break;
+        case MatchAs_kind:
+            // TODO: check target name is valid
+            if (p->v.MatchAs.pattern == NULL) {  // no sub-pattern to recurse into
+                ret = 1;
+            }
+            else if (p->v.MatchAs.name == NULL) {  // sub-pattern given without a name
+                PyErr_SetString(PyExc_ValueError,
+                                "MatchAs must specify a target name if a pattern is given");
+                return 0;
+            }
+            else {
+                ret = validate_pattern(state, p->v.MatchAs.pattern);  // recurse into the sub-pattern
+            }
+            break;
+        case MatchOr_kind:
+            // TODO: Validate all subpatterns
+            // ret = validate_patterns(state, p->v.MatchOr.patterns);
+            ret = 1;  // alternatives are accepted unchecked for now
+            break;
+    // No default case, so the compiler will emit a warning if new pattern
+    // kinds are added without being handled here
+    }
+    if (ret < 0) {  // p->kind matched none of the cases above
+        PyErr_SetString(PyExc_SystemError, "unexpected pattern");
+        ret = 0;
+    }
+    state->recursion_depth--;  // balances the increment at function entry
+    return ret;
+}
+
+static int
 _validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner)
 {
     if (asdl_seq_LEN(seq))
@@ -404,7 +588,7 @@ validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner)
 static int
 validate_stmt(struct validator *state, stmt_ty stmt)
 {
-    int ret;
+    int ret = -1;
     Py_ssize_t i;
     if (++state->recursion_depth > state->recursion_limit) {
         PyErr_SetString(PyExc_RecursionError,
@@ -502,7 +686,7 @@ validate_stmt(struct validator *state, stmt_ty stmt)
         }
         for (i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
             match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
-            if (!validate_pattern(m->pattern)
+            if (!validate_pattern(state, m->pattern)
                 || (m->guard && !validate_expr(state, m->guard, Load))
                 || !validate_body(state, m->body, "match_case")) {
                 return 0;
@@ -582,9 +766,11 @@ validate_stmt(struct validator *state, stmt_ty stmt)
     case Continue_kind:
         ret = 1;
         break;
-    default:
+    // No default case so compiler emits warning for unhandled cases
+    }
+    if (ret < 0) {
         PyErr_SetString(PyExc_SystemError, "unexpected statement");
-        return 0;
+        ret = 0;
     }
     state->recursion_depth--;
     return ret;
@@ -635,7 +821,7 @@ validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ct
 int
 _PyAST_Validate(mod_ty mod)
 {
-    int res = 0;
+    int res = -1;
     struct validator state;
     PyThreadState *tstate;
     int recursion_limit = Py_GetRecursionLimit();
@@ -663,10 +849,16 @@ _PyAST_Validate(mod_ty mod)
     case Expression_kind:
         res = validate_expr(&state, mod->v.Expression.body, Load);
         break;
-    default:
-        PyErr_SetString(PyExc_SystemError, "impossible module node");
-        res = 0;
+    case FunctionType_kind:
+        res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) &&
+              validate_expr(&state, mod->v.FunctionType.returns, Load);
         break;
+    // No default case so compiler emits warning for unhandled cases
+    }
+
+    if (res < 0) {
+        PyErr_SetString(PyExc_SystemError, "impossible module node");
+        return 0;
     }
 
     /* Check that the recursion depth counting balanced correctly */