Issue #12803: SSLContext.load_cert_chain() now accepts a password argument
to be used if the private key is encrypted.  Patch by Adam Simpkins.
diff --git a/Modules/_ssl.c b/Modules/_ssl.c
index 1a367f2..b203ce4 100644
--- a/Modules/_ssl.c
+++ b/Modules/_ssl.c
@@ -18,16 +18,21 @@
 
 #ifdef WITH_THREAD
 #include "pythread.h"
+#define PySSL_BEGIN_ALLOW_THREADS_S(save) /* release GIL; tstate -> save */ \
+    do { if (_ssl_locks_count>0) { (save) = PyEval_SaveThread(); } } while (0)
+#define PySSL_END_ALLOW_THREADS_S(save) /* re-acquire GIL from save */ \
+    do { if (_ssl_locks_count>0) { PyEval_RestoreThread(save); } } while (0)
 #define PySSL_BEGIN_ALLOW_THREADS { \
             PyThreadState *_save = NULL;  \
-            if (_ssl_locks_count>0) {_save = PyEval_SaveThread();}
-#define PySSL_BLOCK_THREADS     if (_ssl_locks_count>0){PyEval_RestoreThread(_save)};
-#define PySSL_UNBLOCK_THREADS   if (_ssl_locks_count>0){_save = PyEval_SaveThread()};
-#define PySSL_END_ALLOW_THREADS if (_ssl_locks_count>0){PyEval_RestoreThread(_save);} \
-         }
+            PySSL_BEGIN_ALLOW_THREADS_S(_save);
+#define PySSL_BLOCK_THREADS     PySSL_END_ALLOW_THREADS_S(_save);
+#define PySSL_UNBLOCK_THREADS   PySSL_BEGIN_ALLOW_THREADS_S(_save);
+#define PySSL_END_ALLOW_THREADS PySSL_END_ALLOW_THREADS_S(_save); }
 
 #else   /* no WITH_THREAD */
 
+#define PySSL_BEGIN_ALLOW_THREADS_S(save) /* no-op: no GIL to release */
+#define PySSL_END_ALLOW_THREADS_S(save)   /* no-op: no GIL to re-acquire */
 #define PySSL_BEGIN_ALLOW_THREADS
 #define PySSL_BLOCK_THREADS
 #define PySSL_UNBLOCK_THREADS
@@ -1635,19 +1640,146 @@
     return 0;
 }

+/* State shared between load_cert_chain() and _password_callback() while
+   OpenSSL may be asking for the private key password. */
+typedef struct {
+    PyThreadState *thread_state; /* set by PySSL_BEGIN_ALLOW_THREADS_S */
+    PyObject *callable;          /* borrowed; NULL if not a callable */
+    char *password;              /* malloc()ed password copy, or NULL */
+    Py_ssize_t size;             /* length of password in bytes */
+    int error;                   /* set once the password callback fails */
+} _PySSLPasswordInfo;
+
+static int
+_pwinfo_set(_PySSLPasswordInfo *pw_info, PyObject* password,
+            const char *bad_type_error)
+{
+    /* Set the password and size fields of a _PySSLPasswordInfo struct
+       from a str, bytes, or bytearray object (str is encoded with the
+       default encoding).  Any previously stored password is wiped and
+       freed first.  The new password field is allocated with malloc()
+       and must be freed by the caller.
+       Returns 1 on success, or 0 with a Python exception set (TypeError
+       carrying bad_type_error for unsupported types, MemoryError, or an
+       encoding error). */
+    PyObject *password_bytes = NULL;
+    const char *data = NULL;
+    Py_ssize_t size;
+
+    if (PyUnicode_Check(password)) {
+        password_bytes = PyUnicode_AsEncodedString(password, NULL, NULL);
+        if (!password_bytes) {
+            goto error;
+        }
+        data = PyBytes_AS_STRING(password_bytes);
+        size = PyBytes_GET_SIZE(password_bytes);
+    } else if (PyBytes_Check(password)) {
+        data = PyBytes_AS_STRING(password);
+        size = PyBytes_GET_SIZE(password);
+    } else if (PyByteArray_Check(password)) {
+        data = PyByteArray_AS_STRING(password);
+        size = PyByteArray_GET_SIZE(password);
+    } else {
+        PyErr_SetString(PyExc_TypeError, bad_type_error);
+        goto error;
+    }
+
+    /* Don't leave an earlier password (e.g. from a previous callback
+       invocation) lying around in freed memory. */
+    if (pw_info->password) {
+        OPENSSL_cleanse(pw_info->password, (size_t)pw_info->size);
+    }
+    free(pw_info->password);
+    pw_info->size = 0;
+    /* malloc(0) may return NULL; always ask for at least one byte so that
+       an empty password is not misreported as a MemoryError. */
+    pw_info->password = malloc(size > 0 ? (size_t)size : 1);
+    if (!pw_info->password) {
+        PyErr_SetString(PyExc_MemoryError,
+                        "unable to allocate password buffer");
+        goto error;
+    }
+    memcpy(pw_info->password, data, size);
+    pw_info->size = size;
+
+    Py_XDECREF(password_bytes);
+    return 1;
+
+error:
+    Py_XDECREF(password_bytes);
+    return 0;
+}
+
+/* pem_password_cb installed by load_cert_chain().  OpenSSL may call it with
+   the GIL released, so it re-acquires the GIL through the thread state
+   saved in pw_info, obtains the password (calling pw_info->callable if
+   set), copies it into buf, and releases the GIL again before returning.
+   Returns the password length, or -1 with a Python exception set and
+   pw_info->error set to 1. */
+static int
+_password_callback(char *buf, int size, int rwflag, void *userdata)
+{
+    _PySSLPasswordInfo *pw_info = (_PySSLPasswordInfo*) userdata;
+    PyObject *fn_ret = NULL;
+
+    PySSL_END_ALLOW_THREADS_S(pw_info->thread_state);
+
+    if (pw_info->error) {
+        /* An earlier invocation already failed and left its exception set.
+           Some OpenSSL versions (e.g. 3.0) call the callback again after a
+           failure; running Python code with an exception pending is not
+           allowed, so fail immediately and keep the original exception. */
+        goto error;
+    }
+
+    if (pw_info->callable) {
+        fn_ret = PyObject_CallFunctionObjArgs(pw_info->callable, NULL);
+        if (!fn_ret) {
+            /* TODO: It would be nice to move _ctypes_add_traceback() into the
+               core python API, so we could use it to add a frame here */
+            goto error;
+        }
+
+        if (!_pwinfo_set(pw_info, fn_ret,
+                         "password callback must return a string")) {
+            goto error;
+        }
+        Py_CLEAR(fn_ret);
+    }
+
+    if (pw_info->size > size) {
+        PyErr_Format(PyExc_ValueError,
+                     "password cannot be longer than %d bytes", size);
+        goto error;
+    }
+
+    PySSL_BEGIN_ALLOW_THREADS_S(pw_info->thread_state);
+    memcpy(buf, pw_info->password, pw_info->size);
+    return (int)pw_info->size; /* <= size, checked above */
+
+error:
+    Py_XDECREF(fn_ret);
+    PySSL_BEGIN_ALLOW_THREADS_S(pw_info->thread_state);
+    pw_info->error = 1;
+    return -1;
+}
+
 static PyObject *
 load_cert_chain(PySSLContext *self, PyObject *args, PyObject *kwds)
 {
-    char *kwlist[] = {"certfile", "keyfile", NULL};
-    PyObject *certfile, *keyfile = NULL;
+    char *kwlist[] = {"certfile", "keyfile", "password", NULL};
+    PyObject *certfile, *keyfile = NULL, *password = NULL;
     PyObject *certfile_bytes = NULL, *keyfile_bytes = NULL;
+    pem_password_cb *orig_passwd_cb = self->ctx->default_passwd_callback;
+    void *orig_passwd_userdata = self->ctx->default_passwd_callback_userdata;
+    _PySSLPasswordInfo pw_info = { NULL, NULL, NULL, 0, 0 };
     int r;
 
     errno = 0;
     ERR_clear_error();
     if (!PyArg_ParseTupleAndKeywords(args, kwds,
-        "O|O:load_cert_chain", kwlist,
-        &certfile, &keyfile))
+        "O|OO:load_cert_chain", kwlist,
+        &certfile, &keyfile, &password))
         return NULL;
     if (keyfile == Py_None)
         keyfile = NULL;
@@ -1661,12 +1765,26 @@
                         "keyfile should be a valid filesystem path");
         goto error;
     }
-    PySSL_BEGIN_ALLOW_THREADS
+    if (password && password != Py_None) {
+        if (PyCallable_Check(password)) { /* invoked from _password_callback */
+            pw_info.callable = password; /* borrowed from args */
+        } else if (!_pwinfo_set(&pw_info, password,
+                                "password should be a string or callable")) {
+            goto error;
+        }
+        SSL_CTX_set_default_passwd_cb(self->ctx, _password_callback);
+        SSL_CTX_set_default_passwd_cb_userdata(self->ctx, &pw_info);
+    }
+    PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state); /* read by callback */
     r = SSL_CTX_use_certificate_chain_file(self->ctx,
         PyBytes_AS_STRING(certfile_bytes));
-    PySSL_END_ALLOW_THREADS
+    PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
     if (r != 1) {
-        if (errno != 0) {
+        if (pw_info.error) {
+            ERR_clear_error();
+            /* the password callback already set a Python exception */
+        }
+        else if (errno != 0) {
             ERR_clear_error();
             PyErr_SetFromErrno(PyExc_IOError);
         }
@@ -1675,33 +1793,53 @@
         }
         goto error;
     }
-    PySSL_BEGIN_ALLOW_THREADS
+    PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
     r = SSL_CTX_use_PrivateKey_file(self->ctx,
         PyBytes_AS_STRING(keyfile ? keyfile_bytes : certfile_bytes),
         SSL_FILETYPE_PEM);
-    PySSL_END_ALLOW_THREADS
-    Py_XDECREF(keyfile_bytes);
-    Py_XDECREF(certfile_bytes);
+    PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
+    Py_CLEAR(keyfile_bytes);
+    Py_CLEAR(certfile_bytes);
     if (r != 1) {
-        if (errno != 0) {
+        if (pw_info.error) {
+            ERR_clear_error();
+            /* the password callback already set a Python exception */
+        }
+        else if (errno != 0) {
             ERR_clear_error();
             PyErr_SetFromErrno(PyExc_IOError);
         }
         else {
             _setSSLError(NULL, 0, __FILE__, __LINE__);
         }
-        return NULL;
+        goto error;
     }
-    PySSL_BEGIN_ALLOW_THREADS
+    PySSL_BEGIN_ALLOW_THREADS_S(pw_info.thread_state);
     r = SSL_CTX_check_private_key(self->ctx);
-    PySSL_END_ALLOW_THREADS
+    PySSL_END_ALLOW_THREADS_S(pw_info.thread_state);
     if (r != 1) {
         _setSSLError(NULL, 0, __FILE__, __LINE__);
-        return NULL;
+        goto error;
     }
+    /* Success: restore the previous password callback (the context must
+       not keep pointing at the stack-allocated pw_info) and wipe the
+       password before freeing it. */
+    SSL_CTX_set_default_passwd_cb(self->ctx, orig_passwd_cb);
+    SSL_CTX_set_default_passwd_cb_userdata(self->ctx, orig_passwd_userdata);
+    if (pw_info.password) {
+        OPENSSL_cleanse(pw_info.password, (size_t)pw_info.size);
+    }
+    free(pw_info.password);
     Py_RETURN_NONE;
 
 error:
+    /* Undo the callback installation and wipe/free the password. */
+    SSL_CTX_set_default_passwd_cb(self->ctx, orig_passwd_cb);
+    SSL_CTX_set_default_passwd_cb_userdata(self->ctx, orig_passwd_userdata);
+    if (pw_info.password) {
+        OPENSSL_cleanse(pw_info.password, (size_t)pw_info.size);
+    }
+    free(pw_info.password);
     Py_XDECREF(keyfile_bytes);
     Py_XDECREF(certfile_bytes);
     return NULL;