diff --git a/cjson.c b/cjson.c index 860162e..3099ae3 100644 --- a/cjson.c +++ b/cjson.c @@ -15,12 +15,12 @@ typedef struct JSONData { int all_unicode; // make all output strings unicode if true } JSONData; -static PyObject* encode_object(PyObject *object); +static PyObject* encode_object(PyObject *object, PyObject *fallback); static PyObject* encode_string(PyObject *object); static PyObject* encode_unicode(PyObject *object); -static PyObject* encode_tuple(PyObject *object); -static PyObject* encode_list(PyObject *object); -static PyObject* encode_dict(PyObject *object); +static PyObject* encode_tuple(PyObject *object, PyObject *fallback); +static PyObject* encode_list(PyObject *object, PyObject *fallback); +static PyObject* encode_dict(PyObject *object, PyObject *fallback); static PyObject* decode_json(JSONData *jsondata); static PyObject* decode_null(JSONData *jsondata); @@ -799,7 +799,7 @@ encode_unicode(PyObject *unicode) */ static PyObject* -encode_tuple(PyObject *tuple) +encode_tuple(PyObject *tuple, PyObject *fallback) { Py_ssize_t i, n; PyObject *s, *temp; @@ -816,7 +816,7 @@ encode_tuple(PyObject *tuple) /* Do repr() on each element. */ for (i = 0; i < n; ++i) { - s = encode_object(v->ob_item[i]); + s = encode_object(v->ob_item[i], fallback); if (s == NULL) goto Done; PyTuple_SET_ITEM(pieces, i, s); @@ -864,7 +864,7 @@ encode_tuple(PyObject *tuple) * represented in JSON. */ static PyObject* -encode_list(PyObject *list) +encode_list(PyObject *list, PyObject *fallback) { Py_ssize_t i; PyObject *s, *temp; @@ -893,7 +893,7 @@ encode_list(PyObject *list) * so must refetch the list size on each iteration. */ for (i = 0; i < v->ob_size; ++i) { int status; - s = encode_object(v->ob_item[i]); + s = encode_object(v->ob_item[i], fallback); if (s == NULL) goto Done; status = PyList_Append(pieces, s); @@ -947,7 +947,7 @@ encode_list(PyObject *list) * be represented in JSON. */ static PyObject* -encode_dict(PyObject *dict) +encode_dict(PyObject *dict, PyObject *fallback) { Py_ssize_t i; PyObject *s, *temp, *colon = NULL; @@ -991,9 +991,9 @@ encode_dict(PyObject *dict) /* Prevent repr from deleting value during key format. */ Py_INCREF(value); - s = encode_object(key); + s = encode_object(key, fallback); PyString_Concat(&s, colon); - PyString_ConcatAndDel(&s, encode_object(value)); + PyString_ConcatAndDel(&s, encode_object(value, fallback)); Py_DECREF(value); if (s == NULL) goto Done; @@ -1039,7 +1039,7 @@ encode_dict(PyObject *dict) static PyObject* -encode_object(PyObject *object) +encode_object(PyObject *object, PyObject *fallback) { if (object == Py_True) { return PyString_FromString("true"); @@ -1070,25 +1070,39 @@ encode_object(PyObject *object) PyObject *result; if (Py_EnterRecursiveCall(" while encoding a JSON array from a Python list")) return NULL; - result = encode_list(object); + result = encode_list(object, fallback); Py_LeaveRecursiveCall(); return result; } else if (PyTuple_Check(object)) { PyObject *result; if (Py_EnterRecursiveCall(" while encoding a JSON array from a Python tuple")) return NULL; - result = encode_tuple(object); + result = encode_tuple(object, fallback); Py_LeaveRecursiveCall(); return result; } else if (PyDict_Check(object)) { // use PyMapping_Check(object) instead? -Dan PyObject *result; if (Py_EnterRecursiveCall(" while encoding a JSON object")) return NULL; - result = encode_dict(object); + result = encode_dict(object, fallback); + Py_LeaveRecursiveCall(); + return result; + } else if (fallback) { + PyObject *args, *resolve, *result; + if (Py_EnterRecursiveCall(" while encoding a non-primitive Python object")) + return NULL; + args = PyTuple_Pack(1, object); + resolve = PyObject_CallObject(fallback, args); + Py_DECREF(args); + result = PyErr_Occurred() ? NULL : encode_object(resolve, fallback); + Py_XDECREF(resolve); Py_LeaveRecursiveCall(); return result; } else { - PyErr_SetString(JSON_EncodeError, "object is not JSON encodable"); + PyObject *repr = PyObject_Repr(object); + PyErr_Format(JSON_EncodeError, "object %s is not JSON encodable", + PyString_AsString(repr)); + Py_DECREF(repr); return NULL; } } @@ -1097,9 +1111,21 @@ encode_object(PyObject *object) /* Encode object into its JSON representation */ static PyObject* -JSON_encode(PyObject *self, PyObject *object) +JSON_encode(PyObject *self, PyObject *args, PyObject *kwargs) { - return encode_object(object); + static char *kwlist[] = {"obj", "default", NULL}; + PyObject *object, *fallback = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O:decode", + kwlist, &object, &fallback + )) return NULL; + if (fallback && PyObject_Not(fallback)) fallback = NULL; + if (fallback && !PyCallable_Check(fallback)) { + PyErr_Format(PyExc_ValueError, + "The 'default' argument %s is not callable", + PyObject_Repr(fallback)); + return NULL; + }; + return encode_object(object, fallback); } @@ -1159,8 +1185,11 @@ JSON_decode(PyObject *self, PyObject *args, PyObject *kwargs) /* List of functions defined in the module */ static PyMethodDef cjson_methods[] = { - {"encode", (PyCFunction)JSON_encode, METH_O, - PyDoc_STR("encode(object) -> generate the JSON representation for object.")}, + {"encode", (PyCFunction)JSON_encode, METH_VARARGS|METH_KEYWORDS, + PyDoc_STR("encode(object, default=Null) -> generate the JSON representation for object.\n" + "The optional argument `default' is function that gets called for objects\n" + "that can’t otherwise be serialized. It should return a JSON encodable\n" + "version of the object or raise cjson.EncodeError.")}, {"decode", (PyCFunction)JSON_decode, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("decode(string, all_unicode=False) -> parse the JSON representation into\n" diff --git a/jsontest.py b/jsontest.py index ad9a750..0607220 100755 --- a/jsontest.py +++ b/jsontest.py @@ -20,6 +20,9 @@ ## License along with this library; if not, write to the Free Software ## Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +from datetime import datetime +from decimal import Decimal + import unittest import cjson @@ -141,7 +144,7 @@ def testReadBadArray(self): def doReadBadArray(self): cjson.decode('[1,2,3,,]') - + def testReadBadObjectSyntax(self): self.assertRaises(_exception, self.doReadBadObjectSyntax) @@ -159,7 +162,7 @@ def testReadIntegerValue(self): def testReadNegativeIntegerValue(self): obj = cjson.decode('{ "key" : -44 }') self.assertEqual({ "key" : -44 }, obj) - + def testReadFloatValue(self): obj = cjson.decode('{ "age" : 44.5 }') self.assertEqual({ "age" : 44.5 }, obj) @@ -176,7 +179,7 @@ def doReadBadNumber(self): def testReadSmallObject(self): obj = cjson.decode('{ "name" : "Patrick", "age":44} ') - self.assertEqual({ "age" : 44, "name" : "Patrick" }, obj) + self.assertEqual({ "age" : 44, "name" : "Patrick" }, obj) def testReadEmptyArray(self): obj = cjson.decode('[]') @@ -331,7 +334,18 @@ def testWriteLongUnicode(self): u'\u1234\u1234\u1234\u1234\u1234\u1234') self.assertEqual(r'"\ud834\udd1e\ud834\udd1e\ud834\udd1e\ud834\udd1e' r'\u1234\u1234\u1234\u1234\u1234\u1234"', s) - + + def testWriteCustomObject(self): + def fallback(obj): + if isinstance(obj, Decimal): + return float(obj) + raise cjson.EncodeError(obj) + with self.assertRaises(cjson.EncodeError): + cjson.encode(Decimal(1.23)) + self.assertEqual(cjson.encode(Decimal(1.23), fallback), '1.23') + with self.assertRaises(cjson.EncodeError): + cjson.encode(datetime.now(), fallback) + def main(): unittest.main()