bpo-16500: Don't use string constants for os.register_at_fork() behavior (#1834)

Instead use keyword only arguments to os.register_at_fork for each of the scenarios.
Updates the documentation for clarity.
diff --git a/Doc/c-api/sys.rst b/Doc/c-api/sys.rst
index c6777d6..95d9d65 100644
--- a/Doc/c-api/sys.rst
+++ b/Doc/c-api/sys.rst
@@ -49,9 +49,10 @@
 
 .. c:function:: void PyOS_AfterFork_Child()
 
-   Function to update some internal state after a process fork.  This
-   should be called from the child process after calling :c:func:`fork`
-   or any similar function that clones the current process.
+   Function to update internal interpreter state after a process fork.
+   This must be called from the child process after calling :c:func:`fork`,
+   or any similar function that clones the current process, if there is
+   any chance the process will call back into the Python interpreter.
    Only available on systems where :c:func:`fork` is defined.
 
    .. versionadded:: 3.7
diff --git a/Doc/library/os.rst b/Doc/library/os.rst
index 28921ad..86add0c 100644
--- a/Doc/library/os.rst
+++ b/Doc/library/os.rst
@@ -3280,16 +3280,22 @@
    subprocesses.
 
 
-.. function:: register_at_fork(func, when)
+.. function:: register_at_fork(*, before=None, after_in_parent=None, \
+                               after_in_child=None)
 
-   Register *func* as a function to be executed when a new child process
-   is forked.  *when* is a string specifying at which point the function is
-   called and can take the following values:
+   Register callables to be executed when a new child process is forked
+   using :func:`os.fork` or similar process cloning APIs.
+   The parameters are optional and keyword-only.
+   Each specifies a different call point.
 
-   * *"before"* means the function is called before forking a child process;
-   * *"parent"* means the function is called from the parent process after
-     forking a child process;
-   * *"child"* means the function is called from the child process.
+   * *before* is a function called before forking a child process.
+   * *after_in_parent* is a function called from the parent process
+     after forking a child process.
+   * *after_in_child* is a function called from the child process.
+
+   These calls are only made if control is expected to return to the
+   Python interpreter.  A typical :mod:`subprocess` launch will not
+   trigger them as the child is not going to re-enter the interpreter.
 
    Functions registered for execution before forking are called in
    reverse registration order.  Functions registered for execution
@@ -3300,6 +3306,8 @@
    call those functions, unless it explicitly calls :c:func:`PyOS_BeforeFork`,
    :c:func:`PyOS_AfterFork_Parent` and :c:func:`PyOS_AfterFork_Child`.
 
+   There is no way to unregister a function.
+
    Availability: Unix.
 
    .. versionadded:: 3.7
diff --git a/Lib/random.py b/Lib/random.py
index 52df7d8..b54d524 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -765,7 +765,7 @@
 getrandbits = _inst.getrandbits
 
 if hasattr(_os, "fork"):
-    _os.register_at_fork(_inst.seed, when='child')
+    _os.register_at_fork(after_in_child=_inst.seed)
 
 
 if __name__ == '__main__':
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index a72f83c..412b079 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -189,19 +189,41 @@
             self.assertEqual(pid, res.si_pid)
 
     @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
-    def test_register_after_fork(self):
+    def test_register_at_fork(self):
+        with self.assertRaises(TypeError, msg="Positional args not allowed"):
+            os.register_at_fork(lambda: None)
+        with self.assertRaises(TypeError, msg="Args must be callable"):
+            os.register_at_fork(before=2)
+        with self.assertRaises(TypeError, msg="Args must be callable"):
+            os.register_at_fork(after_in_child="three")
+        with self.assertRaises(TypeError, msg="Args must be callable"):
+            os.register_at_fork(after_in_parent=b"Five")
+        with self.assertRaises(TypeError, msg="Args must not be None"):
+            os.register_at_fork(before=None)
+        with self.assertRaises(TypeError, msg="Args must not be None"):
+            os.register_at_fork(after_in_child=None)
+        with self.assertRaises(TypeError, msg="Args must not be None"):
+            os.register_at_fork(after_in_parent=None)
+        with self.assertRaises(TypeError, msg="Invalid arg was allowed"):
+            # Ensure a combination of valid and invalid is an error.
+            os.register_at_fork(before=None, after_in_parent=lambda: 3)
+        with self.assertRaises(TypeError, msg="Invalid arg was allowed"):
+            # Ensure a combination of valid and invalid is an error.
+            os.register_at_fork(before=lambda: None, after_in_child='')
+        # We test actual registrations in their own process so as not to
+        # pollute this one.  There is no way to unregister for cleanup.
         code = """if 1:
             import os
 
             r, w = os.pipe()
             fin_r, fin_w = os.pipe()
 
-            os.register_at_fork(lambda: os.write(w, b'A'), when='before')
-            os.register_at_fork(lambda: os.write(w, b'B'), when='before')
-            os.register_at_fork(lambda: os.write(w, b'C'), when='parent')
-            os.register_at_fork(lambda: os.write(w, b'D'), when='parent')
-            os.register_at_fork(lambda: os.write(w, b'E'), when='child')
-            os.register_at_fork(lambda: os.write(w, b'F'), when='child')
+            os.register_at_fork(before=lambda: os.write(w, b'A'))
+            os.register_at_fork(after_in_parent=lambda: os.write(w, b'C'))
+            os.register_at_fork(after_in_child=lambda: os.write(w, b'E'))
+            os.register_at_fork(before=lambda: os.write(w, b'B'),
+                                after_in_parent=lambda: os.write(w, b'D'),
+                                after_in_child=lambda: os.write(w, b'F'))
 
             pid = os.fork()
             if pid == 0:
diff --git a/Lib/threading.py b/Lib/threading.py
index 2eaf49a..92c2ab3 100644
--- a/Lib/threading.py
+++ b/Lib/threading.py
@@ -1359,5 +1359,5 @@
         assert len(_active) == 1
 
 
-if hasattr(_os, "fork"):
-    _os.register_at_fork(_after_fork, when="child")
+if hasattr(_os, "register_at_fork"):
+    _os.register_at_fork(after_in_child=_after_fork)
diff --git a/Modules/_posixsubprocess.c b/Modules/_posixsubprocess.c
index 5228fec..8c8777c 100644
--- a/Modules/_posixsubprocess.c
+++ b/Modules/_posixsubprocess.c
@@ -651,14 +651,6 @@
             goto cleanup;
     }
 
-    if (preexec_fn != Py_None) {
-        preexec_fn_args_tuple = PyTuple_New(0);
-        if (!preexec_fn_args_tuple)
-            goto cleanup;
-        PyOS_BeforeFork();
-        need_after_fork = 1;
-    }
-
     if (cwd_obj != Py_None) {
         if (PyUnicode_FSConverter(cwd_obj, &cwd_obj2) == 0)
             goto cleanup;
@@ -668,6 +660,17 @@
         cwd_obj2 = NULL;
     }
 
+    /* This must be the last thing done before fork() because we do not
+     * want to call PyOS_BeforeFork() if there is any chance of another
+     * error leading to the cleanup: code without calling fork(). */
+    if (preexec_fn != Py_None) {
+        preexec_fn_args_tuple = PyTuple_New(0);
+        if (!preexec_fn_args_tuple)
+            goto cleanup;
+        PyOS_BeforeFork();
+        need_after_fork = 1;
+    }
+
     pid = fork();
     if (pid == 0) {
         /* Child process */
@@ -722,8 +725,6 @@
     return PyLong_FromPid(pid);
 
 cleanup:
-    if (need_after_fork)
-        PyOS_AfterFork_Parent();
     if (envp)
         _Py_FreeCharPArray(envp);
     if (argv)
diff --git a/Modules/clinic/posixmodule.c.h b/Modules/clinic/posixmodule.c.h
index 2c919e1..8e1b55a 100644
--- a/Modules/clinic/posixmodule.c.h
+++ b/Modules/clinic/posixmodule.c.h
@@ -1828,40 +1828,44 @@
 #if defined(HAVE_FORK)
 
 PyDoc_STRVAR(os_register_at_fork__doc__,
-"register_at_fork($module, func, /, when)\n"
+"register_at_fork($module, /, *, before=None, after_in_child=None,\n"
+"                 after_in_parent=None)\n"
 "--\n"
 "\n"
-"Register a callable object to be called when forking.\n"
+"Register callables to be called when forking a new process.\n"
 "\n"
-"  func\n"
-"    Function or callable\n"
-"  when\n"
-"    \'before\', \'child\' or \'parent\'\n"
+"  before\n"
+"    A callable to be called in the parent before the fork() syscall.\n"
+"  after_in_child\n"
+"    A callable to be called in the child after fork().\n"
+"  after_in_parent\n"
+"    A callable to be called in the parent after fork().\n"
 "\n"
-"\'before\' callbacks are called in reverse order before forking.\n"
-"\'child\' callbacks are called in order after forking, in the child process.\n"
-"\'parent\' callbacks are called in order after forking, in the parent process.");
+"\'before\' callbacks are called in reverse order.\n"
+"\'after_in_child\' and \'after_in_parent\' callbacks are called in order.");
 
 #define OS_REGISTER_AT_FORK_METHODDEF    \
     {"register_at_fork", (PyCFunction)os_register_at_fork, METH_FASTCALL, os_register_at_fork__doc__},
 
 static PyObject *
-os_register_at_fork_impl(PyObject *module, PyObject *func, const char *when);
+os_register_at_fork_impl(PyObject *module, PyObject *before,
+                         PyObject *after_in_child, PyObject *after_in_parent);
 
 static PyObject *
 os_register_at_fork(PyObject *module, PyObject **args, Py_ssize_t nargs, PyObject *kwnames)
 {
     PyObject *return_value = NULL;
-    static const char * const _keywords[] = {"", "when", NULL};
-    static _PyArg_Parser _parser = {"Os:register_at_fork", _keywords, 0};
-    PyObject *func;
-    const char *when;
+    static const char * const _keywords[] = {"before", "after_in_child", "after_in_parent", NULL};
+    static _PyArg_Parser _parser = {"|$OOO:register_at_fork", _keywords, 0};
+    PyObject *before = NULL;
+    PyObject *after_in_child = NULL;
+    PyObject *after_in_parent = NULL;
 
     if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &_parser,
-        &func, &when)) {
+        &before, &after_in_child, &after_in_parent)) {
         goto exit;
     }
-    return_value = os_register_at_fork_impl(module, func, when);
+    return_value = os_register_at_fork_impl(module, before, after_in_child, after_in_parent);
 
 exit:
     return return_value;
@@ -6541,4 +6545,4 @@
 #ifndef OS_GETRANDOM_METHODDEF
     #define OS_GETRANDOM_METHODDEF
 #endif /* !defined(OS_GETRANDOM_METHODDEF) */
-/*[clinic end generated code: output=699e11c5579a104e input=a9049054013a1b77]*/
+/*[clinic end generated code: output=dce741f527ddbfa4 input=a9049054013a1b77]*/
diff --git a/Modules/posixmodule.c b/Modules/posixmodule.c
index be8a66d..f4a2167 100644
--- a/Modules/posixmodule.c
+++ b/Modules/posixmodule.c
@@ -465,6 +465,8 @@
 static int
 register_at_forker(PyObject **lst, PyObject *func)
 {
+    if (func == NULL)  /* nothing to register? do nothing. */
+        return 0;
     if (*lst == NULL) {
         *lst = PyList_New(0);
         if (*lst == NULL)
@@ -5309,52 +5311,67 @@
 
 
 #ifdef HAVE_FORK
+
+/* Helper function to validate arguments.
+   Returns 0 on success.  non-zero on failure with a TypeError raised.
+   If obj is non-NULL it must be callable.  */
+static int
+check_null_or_callable(PyObject *obj, const char* obj_name)
+{
+    if (obj && !PyCallable_Check(obj)) {
+        PyErr_Format(PyExc_TypeError, "'%s' must be callable, not %s",
+                     obj_name, Py_TYPE(obj)->tp_name);
+        return -1;
+    }
+    return 0;
+}
+
 /*[clinic input]
 os.register_at_fork
 
-    func: object
-        Function or callable
-    /
-    when: str
-        'before', 'child' or 'parent'
+    *
+    before: object=NULL
+        A callable to be called in the parent before the fork() syscall.
+    after_in_child: object=NULL
+        A callable to be called in the child after fork().
+    after_in_parent: object=NULL
+        A callable to be called in the parent after fork().
 
-Register a callable object to be called when forking.
+Register callables to be called when forking a new process.
 
-'before' callbacks are called in reverse order before forking.
-'child' callbacks are called in order after forking, in the child process.
-'parent' callbacks are called in order after forking, in the parent process.
+'before' callbacks are called in reverse order.
+'after_in_child' and 'after_in_parent' callbacks are called in order.
 
 [clinic start generated code]*/
 
 static PyObject *
-os_register_at_fork_impl(PyObject *module, PyObject *func, const char *when)
-/*[clinic end generated code: output=8943be81a644750c input=5fc05efa4d42eb84]*/
+os_register_at_fork_impl(PyObject *module, PyObject *before,
+                         PyObject *after_in_child, PyObject *after_in_parent)
+/*[clinic end generated code: output=5398ac75e8e97625 input=cd1187aa85d2312e]*/
 {
     PyInterpreterState *interp;
-    PyObject **lst;
 
-    if (!PyCallable_Check(func)) {
-        PyErr_Format(PyExc_TypeError,
-                     "expected callable object, got %R", Py_TYPE(func));
+    if (!before && !after_in_child && !after_in_parent) {
+        PyErr_SetString(PyExc_TypeError, "At least one argument is required.");
+        return NULL;
+    }
+    if (check_null_or_callable(before, "before") ||
+        check_null_or_callable(after_in_child, "after_in_child") ||
+        check_null_or_callable(after_in_parent, "after_in_parent")) {
         return NULL;
     }
     interp = PyThreadState_Get()->interp;
 
-    if (!strcmp(when, "before"))
-        lst = &interp->before_forkers;
-    else if (!strcmp(when, "child"))
-        lst = &interp->after_forkers_child;
-    else if (!strcmp(when, "parent"))
-        lst = &interp->after_forkers_parent;
-    else {
-        PyErr_Format(PyExc_ValueError, "unexpected value for `when`: '%s'",
-                     when);
+    if (register_at_forker(&interp->before_forkers, before)) {
         return NULL;
     }
-    if (register_at_forker(lst, func))
+    if (register_at_forker(&interp->after_forkers_child, after_in_child)) {
         return NULL;
-    else
-        Py_RETURN_NONE;
+    }
+    if (register_at_forker(&interp->after_forkers_parent, after_in_parent)) {
+        return NULL;
+    }
+    Py_RETURN_NONE;
 }
 #endif /* HAVE_FORK */