Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions msgpack/_unpacker.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ cdef class Unpacker:
self.buf = NULL

def __dealloc__(self):
    # Release the Python object references that the unpack context may still
    # hold on its stack (unpack_clear calls Py_CLEAR on them).
    unpack_clear(&self.ctx)
    # Free the buffer and null the pointer so no stale address is left behind.
    PyMem_Free(self.buf)
    self.buf = NULL

Expand Down
8 changes: 8 additions & 0 deletions msgpack/unpack_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ static inline PyObject* unpack_data(unpack_context* ctx)

static inline void unpack_clear(unpack_context *ctx)
{
unsigned int i;
for (i = 1; i < ctx->top; i++) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using i = 0 here and removing Py_CLEAR(ctx->stack[0].obj) at the bottom?
Is map_key at stack[0] safe?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I obviously have less knowledge about this than you so let me know if I'm wrong, but my understanding is that unpack_init sets top to 0 but still pushes something into stack[0]:

ctx->stack[0].obj = unpack_callback_root(&ctx->user);

As a result, if we looped from [0; top) and top == 0 then we wouldn't free stack[0].obj as far as I can tell.


Regarding map_key at stack[0]: I think we may need to free it depending on the case. If top is 0 then we never need to free it; if top >= 1 then we would need to check for CT_MAP_VALUE like we do in the loop.


So, bottom line: we can probably iterate from 0 to (ctx->top > 0 ? ctx->top : 1) to capture both cases. On top of that, we would need to check whether ctx->top == 0 before the CT_MAP_VALUE check (since, as far as I can tell, ctx->stack[0].ct would be uninitialised, so if we're unlucky we could accidentally call Py_CLEAR(ctx->stack[0].map_key), which would not be safe).

But if we're adding all that logic, what may actually be simpler is a mix of all the things:

/*
 * Drop the Python references held by an unpack context.
 *
 * Clears `obj` for every stack entry below `ctx->top`, plus `map_key` for
 * entries whose state is CT_MAP_VALUE, and finally clears `stack[0].obj`
 * unconditionally.
 */
static inline void unpack_clear(unpack_context *ctx)
{
    unsigned int i;
    // Handles the case where at least one entry was pushed to the stack.
    for (i = 0; i < ctx->top; i++) {
        Py_CLEAR(ctx->stack[i].obj);
        /* map_key holds a live reference only while waiting for the value */
        if (ctx->stack[i].ct == CT_MAP_VALUE) {
            Py_CLEAR(ctx->stack[i].map_key);
        }
    }

    // Handles the case where nothing was pushed (top == 0), which the loop skips.
    // Clearing slot 0 a second time is safe: Py_CLEAR leaves the pointer NULL,
    // so the repeated call is a no-op.
    Py_CLEAR(ctx->stack[0].obj);
}

Py_CLEAR(ctx->stack[i].obj);
/* map_key holds a live reference only while waiting for the value */
if (ctx->stack[i].ct == CT_MAP_VALUE) {
Py_CLEAR(ctx->stack[i].map_key);
}
}
Py_CLEAR(ctx->stack[0].obj);
}
Comment thread
KowalskiThomas marked this conversation as resolved.
Comment thread
KowalskiThomas marked this conversation as resolved.

Expand Down
35 changes: 35 additions & 0 deletions test/test_except.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python

import datetime
import gc
import tracemalloc

from pytest import raises

Expand Down Expand Up @@ -80,6 +82,39 @@ def test_invalidvalue():
unpackb(b"\x91" * 3000) # nested fixarray(len=1)


def test_no_memory_leak_on_nested_invalid_tag() -> None:
    """Check that a failed unpack of nested arrays does not leak objects.

    For each nesting depth the payload is ``depth`` one-element fixarray
    headers (0x91) followed by the invalid tag 0xC1.  The payload is unpacked
    repeatedly between two tracemalloc snapshots, and the number of
    allocations that survive must stay below one per call.
    """
    options: dict = dict(
        raw=False,
        strict_map_key=False,
        max_array_len=1 << 20,
        max_map_len=1 << 20,
    )
    iterations = 1000

    for depth in range(1, 15):
        payload = b"\x91" * depth + b"\xc1"

        # Collect garbage first so the baseline snapshot is not polluted by
        # objects that are merely awaiting collection.
        gc.collect()
        tracemalloc.start()
        before = tracemalloc.take_snapshot()

        for _ in range(iterations):
            try:
                unpackb(payload, **options)
            except Exception:
                # The payload is expected to fail; only the allocations
                # left behind by the failure are of interest here.
                pass

        gc.collect()
        after = tracemalloc.take_snapshot()
        tracemalloc.stop()

        # Sum only positive per-line differences: blocks that were allocated
        # during the loop and are still alive afterwards.
        surviving = 0
        for stat in after.compare_to(before, "lineno"):
            if stat.count_diff > 0:
                surviving += stat.count_diff

        per_call = surviving / iterations
        assert per_call < 1.0, f"depth={depth}: {per_call:.2f} leaked objects/call (expected < 1)"


def test_strict_map_key():
valid = {"unicode": 1, b"bytes": 2}
packed = packb(valid, use_bin_type=True)
Expand Down