@@ -14,7 +14,7 @@ use rivet_guard_core::{
1414 ServiceUnavailable , WebSocketServiceHibernate , WebSocketServiceTimeout ,
1515 WebSocketServiceUnavailable ,
1616 } ,
17- proxy_service:: ResponseBody ,
17+ proxy_service:: { ResponseBody , is_ws_hibernate } ,
1818 request_context:: RequestContext ,
1919 websocket_handle:: WebSocketReceiver ,
2020} ;
@@ -152,6 +152,11 @@ impl CustomServeTrait for PegboardGateway {
152152 . context ( "failed to read body" ) ?
153153 . to_bytes ( ) ;
154154
155+ let mut stopped_sub = self
156+ . ctx
157+ . subscribe :: < pegboard:: workflows:: actor:: Stopped > ( ( "actor_id" , self . actor_id ) )
158+ . await ?;
159+
155160 // Build subject to publish to
156161 let tunnel_subject =
157162 pegboard:: pubsub_subjects:: RunnerReceiverSubject :: new ( self . runner_id ) . to_string ( ) ;
@@ -212,6 +217,10 @@ impl CustomServeTrait for PegboardGateway {
212217 break ;
213218 }
214219 }
220+ _ = stopped_sub. next( ) => {
221+ tracing:: debug!( "actor stopped while waiting for request response" ) ;
222+ return Err ( ServiceUnavailable . build( ) ) ;
223+ }
215224 _ = drop_rx. changed( ) => {
216225 tracing:: warn!( "tunnel message timeout" ) ;
217226 return Err ( ServiceUnavailable . build( ) ) ;
@@ -278,6 +287,11 @@ impl CustomServeTrait for PegboardGateway {
278287 }
279288 }
280289
290+ let mut stopped_sub = self
291+ . ctx
292+ . subscribe :: < pegboard:: workflows:: actor:: Stopped > ( ( "actor_id" , self . actor_id ) )
293+ . await ?;
294+
281295 // Build subject to publish to
282296 let tunnel_subject =
283297 pegboard:: pubsub_subjects:: RunnerReceiverSubject :: new ( self . runner_id ) . to_string ( ) ;
@@ -339,6 +353,10 @@ impl CustomServeTrait for PegboardGateway {
339353 break ;
340354 }
341355 }
356+ _ = stopped_sub. next( ) => {
357+ tracing:: debug!( "actor stopped while waiting for websocket open" ) ;
358+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
359+ }
342360 _ = drop_rx. changed( ) => {
343361 tracing:: warn!( "websocket open timeout" ) ;
344362 return Err ( WebSocketServiceUnavailable . build( ) ) ;
@@ -364,7 +382,7 @@ impl CustomServeTrait for PegboardGateway {
364382 open_msg. can_hibernate
365383 } ;
366384
367- // Send reclaimed messages
385+ // Send pending messages
368386 self . shared_state
369387 . resend_pending_websocket_messages ( request_id)
370388 . await ?;
@@ -415,6 +433,15 @@ impl CustomServeTrait for PegboardGateway {
415433 return Err ( WebSocketServiceHibernate . build( ) ) ;
416434 }
417435 }
436+ _ = stopped_sub. next( ) => {
437+ tracing:: debug!( "actor stopped during websocket handler loop" ) ;
438+
439+ if can_hibernate {
440+ return Err ( WebSocketServiceHibernate . build( ) ) ;
441+ } else {
442+ return Err ( WebSocketServiceUnavailable . build( ) ) ;
443+ }
444+ }
418445 _ = drop_rx. changed( ) => {
419446 tracing:: warn!( "websocket message timeout" ) ;
420447 return Err ( WebSocketServiceTimeout . build( ) ) ;
@@ -532,28 +559,33 @@ impl CustomServeTrait for PegboardGateway {
532559 ( res, _) => res,
533560 } ;
534561
535- // Send WebSocket close message to runner
536- let ( close_code, close_reason) = match & mut lifecycle_res {
537- // Taking here because it won't be used again
538- Ok ( LifecycleResult :: ClientClose ( Some ( close) ) ) => {
539- ( close. code , Some ( std:: mem:: take ( & mut close. reason ) ) )
540- }
541- Ok ( _) => ( CloseCode :: Normal . into ( ) , None ) ,
542- Err ( _) => ( CloseCode :: Error . into ( ) , Some ( "ws.downstream_closed" . into ( ) ) ) ,
543- } ;
544- let close_message = protocol:: ToClientTunnelMessageKind :: ToClientWebSocketClose (
545- protocol:: ToClientWebSocketClose {
546- code : Some ( close_code. into ( ) ) ,
547- reason : close_reason. map ( |x| x. as_str ( ) . to_string ( ) ) ,
548- } ,
549- ) ;
550-
551- if let Err ( err) = self
552- . shared_state
553- . send_message ( request_id, close_message)
554- . await
562+ // Send close frame to runner if not hibernating
563+ if lifecycle_res
564+ . as_ref ( )
565+ . map_or_else ( is_ws_hibernate, |_| false )
555566 {
556- tracing:: error!( ?err, "error sending close message" ) ;
567+ let ( close_code, close_reason) = match & mut lifecycle_res {
568+ // Taking here because it won't be used again
569+ Ok ( LifecycleResult :: ClientClose ( Some ( close) ) ) => {
570+ ( close. code , Some ( std:: mem:: take ( & mut close. reason ) ) )
571+ }
572+ Ok ( _) => ( CloseCode :: Normal . into ( ) , None ) ,
573+ Err ( _) => ( CloseCode :: Error . into ( ) , Some ( "ws.downstream_closed" . into ( ) ) ) ,
574+ } ;
575+ let close_message = protocol:: ToClientTunnelMessageKind :: ToClientWebSocketClose (
576+ protocol:: ToClientWebSocketClose {
577+ code : Some ( close_code. into ( ) ) ,
578+ reason : close_reason. map ( |x| x. as_str ( ) . to_string ( ) ) ,
579+ } ,
580+ ) ;
581+
582+ if let Err ( err) = self
583+ . shared_state
584+ . send_message ( request_id, close_message)
585+ . await
586+ {
587+ tracing:: error!( ?err, "error sending close message" ) ;
588+ }
557589 }
558590
559591 // Send WebSocket close message to client
@@ -579,6 +611,15 @@ impl CustomServeTrait for PegboardGateway {
579611 client_ws : WebSocketHandle ,
580612 unique_request_id : Uuid ,
581613 ) -> Result < HibernationResult > {
614+ // Immediately rewake if we have pending messages
615+ if self
616+ . shared_state
617+ . has_pending_websocket_messages ( unique_request_id. into_bytes ( ) )
618+ . await ?
619+ {
620+ return Ok ( HibernationResult :: Continue ) ;
621+ }
622+
582623 // Start keepalive task
583624 let ctx = self . ctx . clone ( ) ;
584625 let actor_id = self . actor_id ;
0 commit comments