diff --git a/src/dispatch/signal/service.py b/src/dispatch/signal/service.py index fe2419b0f0c7..ca745e357611 100644 --- a/src/dispatch/signal/service.py +++ b/src/dispatch/signal/service.py @@ -971,9 +971,13 @@ def get_cases_for_signal_by_resolution_reason( def get_signal_stats( - *, db_session: Session, entity_value: str, entity_type_id: int, num_days: int | None + *, db_session: Session, entity_value: str, entity_type_id: int, signal_id: int | None = None, num_days: int | None = None ) -> Optional[SignalStats]: - """Gets a signal statistics for a given named entity and type.""" + """ + Gets signal statistics for a given named entity and type. + + If signal_id is provided, only returns stats for that specific signal definition. + """ entity_subquery = ( db_session.query( func.jsonb_build_array( @@ -1027,6 +1031,7 @@ def get_signal_stats( and_( Entity.value == entity_value, Entity.entity_type_id == entity_type_id, + SignalInstance.signal_id == signal_id if signal_id else True, SignalInstance.created_at >= date_threshold if date_threshold else True, ) ) diff --git a/src/dispatch/signal/views.py b/src/dispatch/signal/views.py index 276a497630ff..c05f33397802 100644 --- a/src/dispatch/signal/views.py +++ b/src/dispatch/signal/views.py @@ -290,7 +290,7 @@ def return_signal_stats( entity_type_id: int = Query(..., description="The ID of the entity type"), num_days: int = Query(None, description="The number of days to look back"), ): - """Gets a signal statistics given a named entity and entity type id.""" + """Gets signal statistics given a named entity and entity type id.""" signal_data = get_signal_stats( db_session=db_session, entity_value=entity_value, @@ -300,6 +300,32 @@ def return_signal_stats( return signal_data +@router.get("/{signal_id}/stats", response_model=SignalStats) +def return_single_signal_stats( + db_session: DbSession, + signal_id: Union[str, PrimaryKey], + entity_value: str = Query(..., description="The name of the entity"), + entity_type_id: int = Query(..., description="The ID of the entity type"), + num_days: int = Query(None, description="The number of days to look back"), +): + """Gets signal statistics for a specific signal given a named entity and entity type id.""" + signal = get_by_primary_or_external_id(db_session=db_session, signal_id=signal_id) + if not signal: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=[{"msg": "A signal with this id does not exist."}], + ) + + signal_data = get_signal_stats( + db_session=db_session, + entity_value=entity_value, + entity_type_id=entity_type_id, + signal_id=signal.id, + num_days=num_days, + ) + return signal_data + + @router.get("/{signal_id}", response_model=SignalRead) def get_signal(db_session: DbSession, signal_id: Union[str, PrimaryKey]): """Gets a signal by its id.""" diff --git a/tests/signal/test_signal_data_service.py b/tests/signal/test_signal_stats_service.py similarity index 68% rename from tests/signal/test_signal_data_service.py rename to tests/signal/test_signal_stats_service.py index 914594dbd782..34b4261da779 100644 --- a/tests/signal/test_signal_data_service.py +++ b/tests/signal/test_signal_stats_service.py @@ -161,3 +161,78 @@ def test_get_signal_stats_not_found(session, entity_type): assert signal_data.num_signal_instances_snoozed == 0 assert signal_data.num_snoozes_active == 0 assert signal_data.num_snoozes_expired == 0 + + +def test_get_signal_stats_with_signal_id_filter(session, entity, entity_type): + """Test get_signal_stats with a specific signal ID filter.""" + from dispatch.signal.service import get_signal_stats + from dispatch.signal.models import Signal, SignalInstance + + # Setup: Create two different signals + signal1 = Signal(name="Test Signal 1", variant="test-signal-1", project_id=1) + signal2 = Signal(name="Test Signal 2", variant="test-signal-2", project_id=1) + session.add_all([signal1, signal2]) + session.flush() + + # Associate entity with entity type + entity.entity_type = entity_type + + # Create a signal instance for signal1 + signal_instance1 = SignalInstance(signal=signal1, project_id=1) + signal_instance1.entities.append(entity) + + # Create two signal instances for signal2 + signal_instance2 = SignalInstance(signal=signal2, project_id=1) + signal_instance2.entities.append(entity) + signal_instance3 = SignalInstance(signal=signal2, project_id=1) + signal_instance3.entities.append(entity) + + session.add_all([signal_instance1, signal_instance2, signal_instance3]) + session.commit() + + # Execute: Call the service function without signal_id (should count both signals) + signal_data_all = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + num_days=None, + ) + + # Execute: Call the service function with signal_id for signal1 + signal_data_signal1 = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + signal_id=signal1.id, + num_days=None, + ) + + # Execute: Call the service function with signal_id for signal2 + signal_data_signal2 = get_signal_stats( + db_session=session, + entity_value=entity.value, + entity_type_id=entity_type.id, + signal_id=signal2.id, + num_days=None, + ) + + # Assert: Without signal_id filter, we should count both instances + assert ( + signal_data_all.num_signal_instances_alerted + + signal_data_all.num_signal_instances_snoozed + == 3 + ) + + # Assert: With signal1 filter, we should count only signal_instance1 + assert ( + signal_data_signal1.num_signal_instances_alerted + + signal_data_signal1.num_signal_instances_snoozed + == 1 + ) + + # Assert: With signal2 filter, we should count only signal_instance2 + assert ( + signal_data_signal2.num_signal_instances_alerted + + signal_data_signal2.num_signal_instances_snoozed + == 2 + )