diff --git a/netra/exporters/filtering_span_exporter.py b/netra/exporters/filtering_span_exporter.py index cea41e1..8df9c1e 100644 --- a/netra/exporters/filtering_span_exporter.py +++ b/netra/exporters/filtering_span_exporter.py @@ -9,7 +9,7 @@ from netra.exporters.utils import add_blocked_trace_id, get_trace_id, is_trace_id_blocked, is_trial_blocked from netra.processors.local_filtering_span_processor import ( - BLOCKED_LOCAL_PARENT_MAP, + blocked_local_parent_map_snapshot, ) logger = logging.getLogger(__name__) @@ -113,12 +113,7 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: # Merge with registry of locally blocked spans captured by processor to handle # cases where children export before their blocked parent (SimpleSpanProcessor) - merged_map: Dict[Any, Any] = {} - try: - if BLOCKED_LOCAL_PARENT_MAP: - merged_map.update(BLOCKED_LOCAL_PARENT_MAP) - except Exception: - pass + merged_map: Dict[Any, Any] = blocked_local_parent_map_snapshot() merged_map.update(blocked_parent_map) if merged_map: diff --git a/netra/processors/local_filtering_span_processor.py b/netra/processors/local_filtering_span_processor.py index 89f3d4c..dd832e3 100644 --- a/netra/processors/local_filtering_span_processor.py +++ b/netra/processors/local_filtering_span_processor.py @@ -1,7 +1,8 @@ import json import logging +import threading from contextlib import contextmanager -from typing import List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence from opentelemetry import baggage from opentelemetry import context as otel_context @@ -15,9 +16,43 @@ # Attribute key to copy resolved local blocked patterns onto each span _LOCAL_BLOCKED_SPANS_ATTR_KEY = "netra.local_blocked_spans" -# Registry of locally blocked spans: span_id -> parent_context -# This lets exporters reparent children reliably even when children export before parents -BLOCKED_LOCAL_PARENT_MAP: dict[object, object] = {} +# Registry of locally blocked spans: span_id -> parent_context. +# This lets exporters reparent children reliably even when children export +# before parents. All access must go through the accessor functions below +# to ensure thread-safety. +_blocked_local_parent_map: Dict[Any, Any] = {} +_blocked_local_parent_lock = threading.Lock() + + +def blocked_local_parent_map_put(span_id: Any, parent_context: Any) -> None: + """Register a locally-blocked span's parent context. + + Args: + span_id: The span ID of the blocked span. + parent_context: The parent ``SpanContext`` to reparent children to. + """ + with _blocked_local_parent_lock: + _blocked_local_parent_map[span_id] = parent_context + + +def blocked_local_parent_map_pop(span_id: Any) -> None: + """Remove a span entry from the blocked-parent registry. + + Args: + span_id: The span ID to remove. + """ + with _blocked_local_parent_lock: + _blocked_local_parent_map.pop(span_id, None) + + +def blocked_local_parent_map_snapshot() -> Dict[Any, Any]: + """Return a shallow copy of the blocked-parent registry. + + Returns: + A dict copy safe to iterate without holding the lock. + """ + with _blocked_local_parent_lock: + return dict(_blocked_local_parent_map) class LocalFilteringSpanProcessor(SpanProcessor): # type: ignore[misc] @@ -62,7 +97,7 @@ def on_start(self, span: trace.Span, parent_context: Optional[otel_context.Conte parent_span.get_span_context() if hasattr(parent_span, "get_span_context") else None ) if span_id is not None and parent_span_context is not None: - BLOCKED_LOCAL_PARENT_MAP[span_id] = parent_span_context + blocked_local_parent_map_put(span_id, parent_span_context) # Mark on the span for visibility/debugging try: span.set_attribute("netra.local_blocked", True) @@ -87,7 +122,7 @@ def on_end(self, span: trace.Span) -> None: # noqa: D401 ctx = getattr(span, "context", None) span_id = getattr(ctx, "span_id", None) if ctx else None if span_id is not None: - BLOCKED_LOCAL_PARENT_MAP.pop(span_id, None) + blocked_local_parent_map_pop(span_id) except Exception: pass return diff --git a/netra/session_manager.py b/netra/session_manager.py index 9522016..d1b8766 100644 --- a/netra/session_manager.py +++ b/netra/session_manager.py @@ -1,5 +1,6 @@ import json import logging +import threading from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -20,7 +21,14 @@ class ConversationType(str, Enum): class SessionManager: - """Manages session and user context for applications.""" + """Manages session and user context for applications. + + All mutable class-level state is protected by ``_lock`` so that + concurrent threads (or ``asyncio.to_thread`` calls) cannot corrupt + the internal stacks and registries. + """ + + _lock = threading.Lock() # Class variable to track the current span _current_span: Optional[trace.Span] = None @@ -46,7 +54,8 @@ def set_current_span(cls, span: Optional[trace.Span]) -> None: Args: span: The current span to store """ - cls._current_span = span + with cls._lock: + cls._current_span = span @classmethod def get_current_span(cls) -> Optional[trace.Span]: @@ -56,7 +65,8 @@ def get_current_span(cls) -> Optional[trace.Span]: Returns: The stored current span or None if not set """ - return cls._current_span + with cls._lock: + return cls._current_span @classmethod def register_span(cls, name: str, span: trace.Span) -> None: @@ -68,13 +78,13 @@ def register_span(cls, name: str, span: trace.Span) -> None: span: The span to register """ try: - stack = cls._spans_by_name.get(name) - if stack is None: - cls._spans_by_name[name] = [span] - else: - stack.append(span) - # Track globally as active - cls._active_spans.append(span) + with cls._lock: + stack = cls._spans_by_name.get(name) + if stack is None: + cls._spans_by_name[name] = [span] + else: + stack.append(span) + cls._active_spans.append(span) except Exception: logger.exception("Failed to register span '%s'", name) @@ -88,21 +98,20 @@ def unregister_span(cls, name: str, span: trace.Span) -> None: span: The span to unregister """ try: - stack = cls._spans_by_name.get(name) - if not stack: - return - # Remove the last matching instance (normal case) - for i in range(len(stack) - 1, -1, -1): - if stack[i] is span: - stack.pop(i) - break - if not stack: - cls._spans_by_name.pop(name, None) - # Also remove from global active list (remove last matching instance) - for i in range(len(cls._active_spans) - 1, -1, -1): - if cls._active_spans[i] is span: - cls._active_spans.pop(i) - break + with cls._lock: + stack = cls._spans_by_name.get(name) + if not stack: + return + for i in range(len(stack) - 1, -1, -1): + if stack[i] is span: + stack.pop(i) + break + if not stack: + cls._spans_by_name.pop(name, None) + for i in range(len(cls._active_spans) - 1, -1, -1): + if cls._active_spans[i] is span: + cls._active_spans.pop(i) + break except Exception: logger.exception("Failed to unregister span '%s'", name) @@ -131,10 +140,11 @@ def get_span_by_name(cls, name: str) -> Optional[trace.Span]: Returns: The most recently registered span with the given name, or None if not found """ - stack = cls._spans_by_name.get(name) - if stack: - return stack[-1] - return None + with cls._lock: + stack = cls._spans_by_name.get(name) + if stack: + return stack[-1] + return None @classmethod def push_entity(cls, entity_type: str, entity_name: str) -> None: @@ -145,14 +155,15 @@ def push_entity(cls, entity_type: str, entity_name: str) -> None: entity_type: Type of entity (workflow, task, agent, span) entity_name: Name of the entity """ - if entity_type == "workflow": - cls._workflow_stack.append(entity_name) - elif entity_type == "task": - cls._task_stack.append(entity_name) - elif entity_type == "agent": - cls._agent_stack.append(entity_name) - elif entity_type == "span": - cls._span_stack.append(entity_name) + with cls._lock: + if entity_type == "workflow": + cls._workflow_stack.append(entity_name) + elif entity_type == "task": + cls._task_stack.append(entity_name) + elif entity_type == "agent": + cls._agent_stack.append(entity_name) + elif entity_type == "span": + cls._span_stack.append(entity_name) @classmethod def pop_entity(cls, entity_type: str) -> Optional[str]: @@ -165,15 +176,16 @@ def pop_entity(cls, entity_type: str) -> Optional[str]: Returns: Entity name or None if stack is empty """ - if entity_type == "workflow" and cls._workflow_stack: - return cls._workflow_stack.pop() - elif entity_type == "task" and cls._task_stack: - return cls._task_stack.pop() - elif entity_type == "agent" and cls._agent_stack: - return cls._agent_stack.pop() - elif entity_type == "span" and cls._span_stack: - return cls._span_stack.pop() - return None + with cls._lock: + if entity_type == "workflow" and cls._workflow_stack: + return cls._workflow_stack.pop() + elif entity_type == "task" and cls._task_stack: + return cls._task_stack.pop() + elif entity_type == "agent" and cls._agent_stack: + return cls._agent_stack.pop() + elif entity_type == "span" and cls._span_stack: + return cls._span_stack.pop() + return None @classmethod def get_current_entity_attributes(cls) -> Dict[str, str]: @@ -183,33 +195,31 @@ def get_current_entity_attributes(cls) -> Dict[str, str]: Returns: Dictionary of entity attributes to add to spans """ - attributes = {} + with cls._lock: + attributes = {} - # Add current workflow if exists - if cls._workflow_stack: - attributes[f"{Config.LIBRARY_NAME}.workflow.name"] = cls._workflow_stack[-1] + if cls._workflow_stack: + attributes[f"{Config.LIBRARY_NAME}.workflow.name"] = cls._workflow_stack[-1] - # Add current task if exists - if cls._task_stack: - attributes[f"{Config.LIBRARY_NAME}.task.name"] = cls._task_stack[-1] + if cls._task_stack: + attributes[f"{Config.LIBRARY_NAME}.task.name"] = cls._task_stack[-1] - # Add current agent if exists - if cls._agent_stack: - attributes[f"{Config.LIBRARY_NAME}.agent.name"] = cls._agent_stack[-1] + if cls._agent_stack: + attributes[f"{Config.LIBRARY_NAME}.agent.name"] = cls._agent_stack[-1] - # Add current span if exists - if cls._span_stack: - attributes[f"{Config.LIBRARY_NAME}.span.name"] = cls._span_stack[-1] + if cls._span_stack: + attributes[f"{Config.LIBRARY_NAME}.span.name"] = cls._span_stack[-1] - return attributes + return attributes @classmethod def clear_entity_stacks(cls) -> None: """Clear all entity stacks.""" - cls._workflow_stack.clear() - cls._task_stack.clear() - cls._agent_stack.clear() - cls._span_stack.clear() + with cls._lock: + cls._workflow_stack.clear() + cls._task_stack.clear() + cls._agent_stack.clear() + cls._span_stack.clear() @classmethod def get_stack_info(cls) -> Dict[str, List[str]]: @@ -219,12 +229,13 @@ def get_stack_info(cls) -> Dict[str, List[str]]: Returns: Dictionary containing all stack contents """ - return { - "workflows": cls._workflow_stack.copy(), - "tasks": cls._task_stack.copy(), - "agents": cls._agent_stack.copy(), - "spans": cls._span_stack.copy(), - } + with cls._lock: + return { + "workflows": cls._workflow_stack.copy(), + "tasks": cls._task_stack.copy(), + "agents": cls._agent_stack.copy(), + "spans": cls._span_stack.copy(), + } @staticmethod def set_session_context( @@ -318,13 +329,15 @@ def add_conversation(cls, conversation_type: ConversationType, role: str, conten span = trace.get_current_span() if not (span and getattr(span, "is_recording", lambda: False)()): # Fallback: use the most recent active span from SessionManager - if not cls._active_spans: + with cls._lock: + active_snapshot = list(cls._active_spans) + + if not active_snapshot: logger.warning("No active span to add conversation attribute.") return - # Find the most recent *recording* span (the last item can be a finished span) recording_span: Optional[trace.Span] = None - for span in reversed(cls._active_spans): + for span in reversed(active_snapshot): try: if span and getattr(span, "is_recording", lambda: False)(): recording_span = span