Issue #11707: Fast C version of functools.cmp_to_key()
diff --git a/Lib/functools.py b/Lib/functools.py
index 3bffbac..098f6b6 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -97,7 +97,7 @@
"""Convert a cmp= function into a key= function"""
class K(object):
__slots__ = ['obj']
- def __init__(self, obj, *args):
+ def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj) < 0
@@ -115,6 +115,11 @@
raise TypeError('hash not implemented')
return K
+try:
+ from _functools import cmp_to_key
+except ImportError:
+ pass
+
_CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize")
def lru_cache(maxsize=100):
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 73a77d6..c50336e 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -435,18 +435,81 @@
self.assertEqual(self.func(add, d), "".join(d.keys()))
class TestCmpToKey(unittest.TestCase):
+
def test_cmp_to_key(self):
+ def cmp1(x, y):
+ return (x > y) - (x < y)
+ key = functools.cmp_to_key(cmp1)
+ self.assertEqual(key(3), key(3))
+ self.assertGreater(key(3), key(1))
+ def cmp2(x, y):
+ return int(x) - int(y)
+ key = functools.cmp_to_key(cmp2)
+ self.assertEqual(key(4.0), key('4'))
+ self.assertLess(key(2), key('35'))
+
+ def test_cmp_to_key_arguments(self):
+ def cmp1(x, y):
+ return (x > y) - (x < y)
+ key = functools.cmp_to_key(mycmp=cmp1)
+ self.assertEqual(key(obj=3), key(obj=3))
+ self.assertGreater(key(obj=3), key(obj=1))
+ with self.assertRaises((TypeError, AttributeError)):
+ key(3) > 1 # rhs is not a K object
+ with self.assertRaises((TypeError, AttributeError)):
+ 1 < key(3) # lhs is not a K object
+ with self.assertRaises(TypeError):
+ key = functools.cmp_to_key() # too few args
+ with self.assertRaises(TypeError):
+ key = functools.cmp_to_key(cmp1, None) # too many args
+ key = functools.cmp_to_key(cmp1)
+ with self.assertRaises(TypeError):
+ key() # too few args
+ with self.assertRaises(TypeError):
+ key(None, None) # too many args
+
+ def test_bad_cmp(self):
+ def cmp1(x, y):
+ raise ZeroDivisionError
+ key = functools.cmp_to_key(cmp1)
+ with self.assertRaises(ZeroDivisionError):
+ key(3) > key(1)
+
+ class BadCmp:
+ def __lt__(self, other):
+ raise ZeroDivisionError
+ def cmp1(x, y):
+ return BadCmp()
+ with self.assertRaises(ZeroDivisionError):
+ key(3) > key(1)
+
+ def test_obj_field(self):
+ def cmp1(x, y):
+ return (x > y) - (x < y)
+ key = functools.cmp_to_key(mycmp=cmp1)
+ self.assertEqual(key(50).obj, 50)
+
+ def test_sort_int(self):
def mycmp(x, y):
return y - x
self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
[4, 3, 2, 1, 0])
+ def test_sort_int_str(self):
+ def mycmp(x, y):
+ x, y = int(x), int(y)
+ return (x > y) - (x < y)
+ values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
+ values = sorted(values, key=functools.cmp_to_key(mycmp))
+ self.assertEqual([int(value) for value in values],
+ [0, 1, 1, 2, 3, 4, 5, 7, 10])
+
def test_hash(self):
def mycmp(x, y):
return y - x
key = functools.cmp_to_key(mycmp)
k = key(10)
- self.assertRaises(TypeError, hash(k))
+ self.assertRaises(TypeError, hash, k)
class TestTotalOrdering(unittest.TestCase):
@@ -655,6 +718,7 @@
def test_main(verbose=None):
test_classes = (
+ TestCmpToKey,
TestPartial,
TestPartialSubclass,
TestPythonPartial,
diff --git a/Misc/NEWS b/Misc/NEWS
index ef274eb..3b59906 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -97,6 +97,9 @@
- Issue #10791: Implement missing method GzipFile.read1(), allowing GzipFile
to be wrapped in a TextIOWrapper. Patch by Nadeem Vawda.
+- Issue #11707: Added a fast C version of functools.cmp_to_key().
+ Patch by Filip GruszczyĆski.
+
- Issue #11688: Add sqlite3.Connection.set_trace_callback(). Patch by
Torsten Landschoff.
diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c
index d8a283b..c657906 100644
--- a/Modules/_functoolsmodule.c
+++ b/Modules/_functoolsmodule.c
@@ -330,6 +330,165 @@
};
+/* cmp_to_key ***************************************************************/
+
+typedef struct {
+ PyObject_HEAD;
+ PyObject *cmp;
+ PyObject *object;
+} keyobject;
+
+static void
+keyobject_dealloc(keyobject *ko)
+{
+ Py_DECREF(ko->cmp);
+ Py_XDECREF(ko->object);
+ PyObject_FREE(ko);
+}
+
+static int
+keyobject_traverse(keyobject *ko, visitproc visit, void *arg)
+{
+ Py_VISIT(ko->cmp);
+ if (ko->object)
+ Py_VISIT(ko->object);
+ return 0;
+}
+
+static PyMemberDef keyobject_members[] = {
+ {"obj", T_OBJECT,
+ offsetof(keyobject, object), 0,
+ PyDoc_STR("Value wrapped by a key function.")},
+ {NULL}
+};
+
+static PyObject *
+keyobject_call(keyobject *ko, PyObject *args, PyObject *kw);
+
+static PyObject *
+keyobject_richcompare(PyObject *ko, PyObject *other, int op);
+
+static PyTypeObject keyobject_type = {
+ PyVarObject_HEAD_INIT(&PyType_Type, 0)
+ "functools.KeyWrapper", /* tp_name */
+ sizeof(keyobject), /* tp_basicsize */
+ 0, /* tp_itemsize */
+ /* methods */
+ (destructor)keyobject_dealloc, /* tp_dealloc */
+ 0, /* tp_print */
+ 0, /* tp_getattr */
+ 0, /* tp_setattr */
+ 0, /* tp_reserved */
+ 0, /* tp_repr */
+ 0, /* tp_as_number */
+ 0, /* tp_as_sequence */
+ 0, /* tp_as_mapping */
+ 0, /* tp_hash */
+ (ternaryfunc)keyobject_call, /* tp_call */
+ 0, /* tp_str */
+ PyObject_GenericGetAttr, /* tp_getattro */
+ 0, /* tp_setattro */
+ 0, /* tp_as_buffer */
+ Py_TPFLAGS_DEFAULT, /* tp_flags */
+ 0, /* tp_doc */
+ (traverseproc)keyobject_traverse, /* tp_traverse */
+ 0, /* tp_clear */
+ keyobject_richcompare, /* tp_richcompare */
+ 0, /* tp_weaklistoffset */
+ 0, /* tp_iter */
+ 0, /* tp_iternext */
+ 0, /* tp_methods */
+ keyobject_members, /* tp_members */
+ 0, /* tp_getset */
+};
+
+static PyObject *
+keyobject_call(keyobject *ko, PyObject *args, PyObject *kwds)
+{
+ PyObject *object;
+ keyobject *result;
+ static char *kwargs[] = {"obj", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:K", kwargs, &object))
+ return NULL;
+ result = PyObject_New(keyobject, &keyobject_type);
+ if (!result)
+ return NULL;
+ Py_INCREF(ko->cmp);
+ result->cmp = ko->cmp;
+ Py_INCREF(object);
+ result->object = object;
+ return (PyObject *)result;
+}
+
+static PyObject *
+keyobject_richcompare(PyObject *ko, PyObject *other, int op)
+{
+ PyObject *res;
+ PyObject *args;
+ PyObject *x;
+ PyObject *y;
+ PyObject *compare;
+ PyObject *answer;
+ static PyObject *zero;
+
+ if (zero == NULL) {
+ zero = PyLong_FromLong(0);
+ if (!zero)
+ return NULL;
+ }
+
+ if (Py_TYPE(other) != &keyobject_type){
+ PyErr_Format(PyExc_TypeError, "other argument must be K instance");
+ return NULL;
+ }
+ compare = ((keyobject *) ko)->cmp;
+ assert(compare != NULL);
+ x = ((keyobject *) ko)->object;
+ y = ((keyobject *) other)->object;
+ if (!x || !y){
+ PyErr_Format(PyExc_AttributeError, "object");
+ return NULL;
+ }
+
+ /* Call the user's comparison function and translate the 3-way
+ * result into true or false (or error).
+ */
+ args = PyTuple_New(2);
+ if (args == NULL)
+ return NULL;
+ Py_INCREF(x);
+ Py_INCREF(y);
+ PyTuple_SET_ITEM(args, 0, x);
+ PyTuple_SET_ITEM(args, 1, y);
+ res = PyObject_Call(compare, args, NULL);
+ Py_DECREF(args);
+ if (res == NULL)
+ return NULL;
+ answer = PyObject_RichCompare(res, zero, op);
+ Py_DECREF(res);
+ return answer;
+}
+
+static PyObject *
+functools_cmp_to_key(PyObject *self, PyObject *args, PyObject *kwds){
+ PyObject *cmp;
+ static char *kwargs[] = {"mycmp", NULL};
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:cmp_to_key", kwargs, &cmp))
+ return NULL;
+ keyobject *object = PyObject_New(keyobject, &keyobject_type);
+ if (!object)
+ return NULL;
+ Py_INCREF(cmp);
+ object->cmp = cmp;
+ object->object = NULL;
+ return (PyObject *)object;
+}
+
+PyDoc_STRVAR(functools_cmp_to_key_doc,
+"Convert a cmp= function into a key= function.");
+
/* reduce (used to be a builtin) ********************************************/
static PyObject *
@@ -413,6 +572,8 @@
static PyMethodDef module_methods[] = {
{"reduce", functools_reduce, METH_VARARGS, functools_reduce_doc},
+ {"cmp_to_key", functools_cmp_to_key, METH_VARARGS | METH_KEYWORDS,
+ functools_cmp_to_key_doc},
{NULL, NULL} /* sentinel */
};