diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py index 4cbfd7ad8b1e48..6904572d095d31 100644 --- a/Lib/test/test_struct.py +++ b/Lib/test/test_struct.py @@ -584,8 +584,24 @@ def test_Struct_reinitialization(self): # Issue 9422: there was a memory leak when reinitializing a # Struct instance. This test can be used to detect the leak # when running with regrtest -L. - s = struct.Struct('i') - s.__init__('ii') + s = struct.Struct('>h') + s.__init__('>hh') + self.assertEqual(s.format, '>hh') + packed = b'\x00\x01\x00\x02' + self.assertEqual(s.pack(1, 2), packed) + self.assertEqual(s.unpack(packed), (1, 2)) + + with self.assertRaises(UnicodeEncodeError): + s.__init__('\udc00') + self.assertEqual(s.format, '>hh') + self.assertEqual(s.pack(1, 2), packed) + self.assertEqual(s.unpack(packed), (1, 2)) + + with self.assertRaises(struct.error): + s.__init__('$') + self.assertEqual(s.format, '>hh') + self.assertEqual(s.pack(1, 2), packed) + self.assertEqual(s.unpack(packed), (1, 2)) def check_sizeof(self, format_str, number_of_codes): # The size of 'PyStructObject' diff --git a/Modules/_struct.c b/Modules/_struct.c index dcc3c7ec63e9e0..c2f7b1fe0e800a 100644 --- a/Modules/_struct.c +++ b/Modules/_struct.c @@ -1620,11 +1620,11 @@ align(Py_ssize_t size, char c, const formatdef *e) /* calculate the size of a format string */ static int -prepare_s(PyStructObject *self) +prepare_s(PyStructObject *self, PyObject *format) { const formatdef *f; const formatdef *e; - formatcode *codes; + formatcode *codes, *codes0; const char *s; const char *fmt; @@ -1634,8 +1634,8 @@ prepare_s(PyStructObject *self) _structmodulestate *state = get_struct_state_structinst(self); - fmt = PyBytes_AS_STRING(self->s_format); - if (strlen(fmt) != (size_t)PyBytes_GET_SIZE(self->s_format)) { + fmt = PyBytes_AS_STRING(format); + if (strlen(fmt) != (size_t)PyBytes_GET_SIZE(format)) { PyErr_SetString(state->StructError, "embedded null character"); return -1; @@ -1711,13 +1711,7 @@ prepare_s(PyStructObject *self) PyErr_NoMemory(); return -1; } - /* Free any s_codes value left over from a previous initialization. */ - if (self->s_codes != NULL) - PyMem_Free(self->s_codes); - self->s_codes = codes; - self->s_size = size; - self->s_len = len; - + codes0 = codes; s = fmt; size = 0; while ((c = *s++) != '\0') { @@ -1757,6 +1751,14 @@ prepare_s(PyStructObject *self) codes->size = 0; codes->repeat = 0; + /* Free any s_codes value left over from a previous initialization. */ + if (self->s_codes != NULL) + PyMem_Free(self->s_codes); + self->s_codes = codes0; + self->s_size = size; + self->s_len = len; + Py_XSETREF(self->s_format, Py_NewRef(format)); + return 0; overflow: @@ -1820,9 +1822,8 @@ Struct___init___impl(PyStructObject *self, PyObject *format) return -1; } - Py_SETREF(self->s_format, format); - - ret = prepare_s(self); + ret = prepare_s(self, format); + Py_DECREF(format); return ret; }