Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions ddtrace/internal/wrapping/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from bytecode import Bytecode

from ddtrace.internal.assembly import Assembly
from ddtrace.internal.forksafe import Lock
from ddtrace.internal.utils.inspection import link_function_to_code
from ddtrace.internal.wrapping import WrappedFunction
from ddtrace.internal.wrapping import Wrapper
from ddtrace.internal.wrapping import is_wrapped_with
from ddtrace.internal.wrapping import unwrap
from ddtrace.internal.wrapping import wrap


T = t.TypeVar("T")
Expand Down Expand Up @@ -406,6 +412,44 @@ def unwrap(self) -> None:
_UniversalWrappingContext.extract(f).unregister(self)


class LazyWrappingContext(WrappingContext):
def __init__(self, f: FunctionType):
super().__init__(f)

self._trampoline: t.Optional[Wrapper] = None
self._trampoline_lock = Lock()

def wrap(self) -> None:
"""Perform the bytecode wrapping on first invocation."""
with (tl := self._trampoline_lock):
if self._trampoline is not None:
return

def trampoline(_, args, kwargs):
with tl:
f = t.cast(WrappedFunction, self.__wrapped__)
if is_wrapped_with(self.__wrapped__, trampoline):
f = unwrap(f, trampoline)
super(LazyWrappingContext, self).wrap()
return f(*args, **kwargs)

wrap(self.__wrapped__, trampoline)

self._trampoline = trampoline

def unwrap(self) -> None:
with self._trampoline_lock:
if self._trampoline is None:
return

if self.is_wrapped(self.__wrapped__):
super().unwrap()
else:
unwrap(t.cast(WrappedFunction, self.__wrapped__), self._trampoline)

self._trampoline = None


class ContextWrappedFunction(Protocol):
"""A wrapped function."""

Expand Down
40 changes: 40 additions & 0 deletions tests/internal/test_wrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ddtrace.internal.wrapping import is_wrapped_with
from ddtrace.internal.wrapping import unwrap
from ddtrace.internal.wrapping import wrap
from ddtrace.internal.wrapping.context import LazyWrappingContext
from ddtrace.internal.wrapping.context import WrappingContext
from ddtrace.internal.wrapping.context import _UniversalWrappingContext

Expand Down Expand Up @@ -926,3 +927,42 @@ def foo():

new_method_count = len([_ for _ in gc.get_objects() if type(_).__name__ == "method"])
assert new_method_count <= method_count + 1


def test_wrapping_context_lazy():
free = 42

def foo():
return free

class DummyLazyWrappingContext(LazyWrappingContext):
def __init__(self, f):
super().__init__(f)

self.count = 0

def __enter__(self):
self.count += 1
return super().__enter__()

(wc := DummyLazyWrappingContext(foo)).wrap()

assert not DummyLazyWrappingContext.is_wrapped(foo)

for _ in range(n := 10):
assert foo() == free

assert DummyLazyWrappingContext.is_wrapped(foo)

assert wc.count == n

wc.count = 0

wc.unwrap()

for _ in range(10):
assert not DummyLazyWrappingContext.is_wrapped(foo)

assert foo() == free

assert wc.count == 0
Loading