diff --git a/sentry_sdk/scope.py b/sentry_sdk/scope.py index 454a82db85..5dc8adaeef 100644 --- a/sentry_sdk/scope.py +++ b/sentry_sdk/scope.py @@ -696,6 +696,13 @@ def get_active_propagation_context(self) -> "PropagationContext": isolation_scope._propagation_context = PropagationContext() return isolation_scope._propagation_context + def set_custom_sampling_context( + self, custom_sampling_context: "dict[str, Any]" + ) -> None: + self.get_active_propagation_context()._set_custom_sampling_context( + custom_sampling_context + ) + def clear(self) -> None: """Clears the entire scope.""" self._level: "Optional[LogLevelStr]" = None diff --git a/sentry_sdk/tracing_utils.py b/sentry_sdk/tracing_utils.py index 90b4d84389..0bdb819b04 100644 --- a/sentry_sdk/tracing_utils.py +++ b/sentry_sdk/tracing_utils.py @@ -417,6 +417,7 @@ class PropagationContext: "parent_span_id", "parent_sampled", "baggage", + "custom_sampling_context", ) def __init__( @@ -450,6 +451,8 @@ def __init__( if baggage is None and dynamic_sampling_context is not None: self.baggage = Baggage(dynamic_sampling_context) + self.custom_sampling_context: "Optional[dict[str, Any]]" = None + @classmethod def from_incoming_data( cls, incoming_data: "Dict[str, Any]" @@ -537,6 +540,11 @@ def update(self, other_dict: "Dict[str, Any]") -> None: except AttributeError: pass + def _set_custom_sampling_context( + self, custom_sampling_context: "dict[str, Any]" + ) -> None: + self.custom_sampling_context = custom_sampling_context + def __repr__(self) -> str: return "".format( self._trace_id, @@ -1413,13 +1421,18 @@ def _make_sampling_decision( traces_sampler_defined = callable(client.options.get("traces_sampler")) if traces_sampler_defined: sampling_context = { - "name": name, - "trace_id": propagation_context.trace_id, - "parent_span_id": propagation_context.parent_span_id, - "parent_sampled": propagation_context.parent_sampled, - "attributes": dict(attributes) if attributes else {}, + "span_context": { + "name": name, + "trace_id": propagation_context.trace_id, + "parent_span_id": propagation_context.parent_span_id, + "parent_sampled": propagation_context.parent_sampled, + "attributes": dict(attributes) if attributes else {}, + }, } + if propagation_context.custom_sampling_context: + sampling_context.update(propagation_context.custom_sampling_context) + sample_rate = client.options["traces_sampler"](sampling_context) else: if propagation_context.parent_sampled is not None: diff --git a/tests/tracing/test_span_streaming.py b/tests/tracing/test_span_streaming.py index 2adaddbdd7..cfcd928631 100644 --- a/tests/tracing/test_span_streaming.py +++ b/tests/tracing/test_span_streaming.py @@ -120,7 +120,7 @@ def test_span_sampled_when_created(sentry_init, capture_envelopes): # at start_span() time def traces_sampler(sampling_context): - assert "delayed_attribute" not in sampling_context["attributes"] + assert "delayed_attribute" not in sampling_context["span_context"]["attributes"] return 1.0 sentry_init( @@ -169,9 +169,11 @@ def test_start_span_attributes(sentry_init, capture_envelopes): def test_start_span_attributes_in_traces_sampler(sentry_init, capture_envelopes): def traces_sampler(sampling_context): - assert "attributes" in sampling_context - assert "my_attribute" in sampling_context["attributes"] - assert sampling_context["attributes"]["my_attribute"] == "my_value" + assert "attributes" in sampling_context["span_context"] + assert "my_attribute" in sampling_context["span_context"]["attributes"] + assert ( + sampling_context["span_context"]["attributes"]["my_attribute"] == "my_value" + ) return 1.0 sentry_init( @@ -202,16 +204,16 @@ def test_sampling_context(sentry_init, capture_envelopes): def traces_sampler(sampling_context): nonlocal received_trace_id - assert "trace_id" in sampling_context - received_trace_id = sampling_context["trace_id"] + assert "trace_id" in sampling_context["span_context"] + received_trace_id = sampling_context["span_context"]["trace_id"] - assert "parent_span_id" in sampling_context - assert sampling_context["parent_span_id"] is None + assert "parent_span_id" in sampling_context["span_context"] + assert sampling_context["span_context"]["parent_span_id"] is None - assert "parent_sampled" in sampling_context - assert sampling_context["parent_sampled"] is None + assert "parent_sampled" in sampling_context["span_context"] + assert sampling_context["span_context"]["parent_sampled"] is None - assert "attributes" in sampling_context + assert "attributes" in sampling_context["span_context"] return 1.0 @@ -233,6 +235,62 @@ def traces_sampler(sampling_context): assert len(spans) == 1 +def test_custom_sampling_context(sentry_init): + class MyClass: ... + + my_class = MyClass() + + def traces_sampler(sampling_context): + assert "class" in sampling_context + assert "string" in sampling_context + assert sampling_context["class"] == my_class + assert sampling_context["string"] == "my string" + return 1.0 + + sentry_init( + traces_sampler=traces_sampler, + _experiments={"trace_lifecycle": "stream"}, + ) + + sentry_sdk.get_current_scope().set_custom_sampling_context( + { + "class": my_class, + "string": "my string", + } + ) + + with sentry_sdk.traces.start_span(name="span"): + ... + + +def test_custom_sampling_context_update_to_context_value_persists(sentry_init): + def traces_sampler(sampling_context): + if sampling_context["span_context"]["attributes"]["first"] is True: + assert sampling_context["custom_value"] == 1 + else: + assert sampling_context["custom_value"] == 2 + return 1.0 + + sentry_init( + traces_sampler=traces_sampler, + _experiments={"trace_lifecycle": "stream"}, + ) + + sentry_sdk.traces.new_trace() + + sentry_sdk.get_current_scope().set_custom_sampling_context({"custom_value": 1}) + + with sentry_sdk.traces.start_span(name="span", attributes={"first": True}): + ... + + sentry_sdk.traces.new_trace() + + sentry_sdk.get_current_scope().set_custom_sampling_context({"custom_value": 2}) + + with sentry_sdk.traces.start_span(name="span", attributes={"first": False}): + ... + + def test_span_attributes(sentry_init, capture_envelopes): sentry_init( traces_sample_rate=1.0, @@ -305,10 +363,10 @@ class Class: def test_traces_sampler_drops_span(sentry_init, capture_envelopes): def traces_sampler(sampling_context): - assert "attributes" in sampling_context - assert "drop" in sampling_context["attributes"] + assert "attributes" in sampling_context["span_context"] + assert "drop" in sampling_context["span_context"]["attributes"] - if sampling_context["attributes"]["drop"] is True: + if sampling_context["span_context"]["attributes"]["drop"] is True: return 0.0 return 1.0 @@ -342,7 +400,7 @@ def test_traces_sampler_called_once_per_segment(sentry_init): def traces_sampler(sampling_context): nonlocal traces_sampler_called, span_name_in_traces_sampler traces_sampler_called += 1 - span_name_in_traces_sampler = sampling_context["name"] + span_name_in_traces_sampler = sampling_context["span_context"]["name"] return 1.0 sentry_init(