Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions engine/packages/pegboard-gateway/src/keepalive_task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use anyhow::Result;
use gas::prelude::*;
use pegboard::tunnel::id::{GatewayId, RequestId};
use rand::Rng;
use std::time::Duration;
use tokio::sync::watch;

use super::LifecycleResult;

/// Periodically pings writes keepalive in UDB. This is used to restore hibernating request IDs on
/// next actor start.
///
///Only ran for hibernating requests.
pub async fn task(
ctx: StandaloneCtx,
actor_id: Id,
gateway_id: GatewayId,
request_id: RequestId,
mut keepalive_abort_rx: watch::Receiver<()>,
) -> Result<LifecycleResult> {
let mut ping_interval = tokio::time::interval(Duration::from_millis(
(ctx.config()
.pegboard()
.hibernating_request_eligible_threshold()
/ 2)
.try_into()?,
));
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

// Discard the first tick since it fires immediately and we've already called this
// above
ping_interval.tick().await;

loop {
tokio::select! {
_ = ping_interval.tick() => {}
_ = keepalive_abort_rx.changed() => {
return Ok(LifecycleResult::Aborted);
}
}

// Jitter sleep to prevent stampeding herds
let jitter = { rand::thread_rng().gen_range(0..128) };
tokio::time::sleep(Duration::from_millis(jitter)).await;

ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id,
gateway_id,
request_id,
})
.await?;
}
}
117 changes: 65 additions & 52 deletions engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,27 @@ use gas::prelude::*;
use http_body_util::{BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use pegboard::tunnel::id::{self as tunnel_id, RequestId};
use rand::Rng;
use rivet_error::*;
use rivet_guard_core::{
WebSocketHandle,
custom_serve::{CustomServeTrait, HibernationResult},
errors::{ServiceUnavailable, WebSocketServiceUnavailable},
proxy_service::{is_ws_hibernate, ResponseBody},
proxy_service::{ResponseBody, is_ws_hibernate},
request_context::RequestContext,
websocket_handle::WebSocketReceiver,
WebSocketHandle,
};
use rivet_runner_protocol as protocol;
use rivet_util::serde::HashableMap;
use std::{sync::Arc, time::Duration};
use tokio::{
sync::{watch, Mutex},
task::JoinHandle,
};
use tokio::sync::{Mutex, watch};
use tokio_tungstenite::tungstenite::{
protocol::frame::{coding::CloseCode, CloseFrame},
Message,
protocol::frame::{CloseFrame, coding::CloseCode},
};

use crate::shared_state::{InFlightRequestHandle, SharedState};

mod keepalive_task;
mod metrics;
mod ping_task;
pub mod shared_state;
Expand Down Expand Up @@ -396,6 +393,7 @@ impl CustomServeTrait for PegboardGateway {
let (tunnel_to_ws_abort_tx, tunnel_to_ws_abort_rx) = watch::channel(());
let (ws_to_tunnel_abort_tx, ws_to_tunnel_abort_rx) = watch::channel(());
let (ping_abort_tx, ping_abort_rx) = watch::channel(());
let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());

let tunnel_to_ws = tokio::spawn(tunnel_to_ws_task::task(
self.shared_state.clone(),
Expand Down Expand Up @@ -423,8 +421,14 @@ impl CustomServeTrait for PegboardGateway {
let ws_to_tunnel_abort_tx2 = ws_to_tunnel_abort_tx.clone();
let ping_abort_tx2 = ping_abort_tx.clone();

// Wait for both tasks to complete
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) = tokio::join!(
// Clone variables needed for keepalive task
let ctx_clone = self.ctx.clone();
let actor_id_clone = self.actor_id;
let gateway_id_clone = self.shared_state.gateway_id();
let request_id_clone = request_id;

// Wait for all tasks to complete
let (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res) = tokio::join!(
async {
let res = tunnel_to_ws.await?;

Expand All @@ -434,6 +438,7 @@ impl CustomServeTrait for PegboardGateway {

let _ = ping_abort_tx.send(());
let _ = ws_to_tunnel_abort_tx.send(());
let _ = keepalive_abort_tx.send(());
} else {
tracing::debug!(?res, "tunnel to ws task completed");
}
Expand All @@ -449,6 +454,7 @@ impl CustomServeTrait for PegboardGateway {

let _ = ping_abort_tx2.send(());
let _ = tunnel_to_ws_abort_tx.send(());
let _ = keepalive_abort_tx.send(());
} else {
tracing::debug!(?res, "ws to tunnel task completed");
}
Expand All @@ -464,25 +470,56 @@ impl CustomServeTrait for PegboardGateway {

let _ = ws_to_tunnel_abort_tx2.send(());
let _ = tunnel_to_ws_abort_tx2.send(());
let _ = keepalive_abort_tx.send(());
} else {
tracing::debug!(?res, "ping task completed");
}

res
},
async {
if !can_hibernate {
return Ok(LifecycleResult::Aborted);
}

let keepalive = tokio::spawn(keepalive_task::task(
ctx_clone,
actor_id_clone,
gateway_id_clone,
request_id_clone,
keepalive_abort_rx,
));

let res = keepalive.await?;

// Abort others if not aborted
if !matches!(res, Ok(LifecycleResult::Aborted)) {
tracing::debug!(?res, "keepalive task completed, aborting others");

let _ = ws_to_tunnel_abort_tx2.send(());
let _ = tunnel_to_ws_abort_tx2.send(());
let _ = ping_abort_tx2.send(());
} else {
tracing::debug!(?res, "keepalive task completed");
}

res
},
);

// Determine single result from all tasks
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res) {
let mut lifecycle_res = match (tunnel_to_ws_res, ws_to_tunnel_res, ping_res, keepalive_res)
{
// Prefer error
(Err(err), _, _) => Err(err),
(_, Err(err), _) => Err(err),
(_, _, Err(err)) => Err(err),
// Prefer non aborted result if both succeed
(Ok(res), Ok(LifecycleResult::Aborted), _) => Ok(res),
(Ok(LifecycleResult::Aborted), Ok(res), _) => Ok(res),
(Err(err), _, _, _) => Err(err),
(_, Err(err), _, _) => Err(err),
(_, _, Err(err), _) => Err(err),
(_, _, _, Err(err)) => Err(err),
// Prefer non aborted result if all succeed
(Ok(res), Ok(LifecycleResult::Aborted), _, _) => Ok(res),
(Ok(LifecycleResult::Aborted), Ok(res), _, _) => Ok(res),
// Unlikely case
(res, _, _) => res,
(res, _, _, _) => res,
};

// Send close frame to runner if not hibernating
Expand Down Expand Up @@ -564,43 +601,19 @@ impl CustomServeTrait for PegboardGateway {
}

// Start keepalive task
let ctx = self.ctx.clone();
let actor_id = self.actor_id;
let gateway_id = self.shared_state.gateway_id();
let request_id = unique_request_id;
let keepalive_handle: JoinHandle<Result<()>> = tokio::spawn(async move {
let mut ping_interval = tokio::time::interval(Duration::from_millis(
(ctx.config()
.pegboard()
.hibernating_request_eligible_threshold()
/ 2)
.try_into()?,
));
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

// Discard the first tick since it fires immediately and we've already called this
// above
ping_interval.tick().await;

loop {
ping_interval.tick().await;

// Jitter sleep to prevent stampeding herds
let jitter = { rand::thread_rng().gen_range(0..128) };
tokio::time::sleep(Duration::from_millis(jitter)).await;

ctx.op(pegboard::ops::actor::hibernating_request::upsert::Input {
actor_id,
gateway_id,
request_id,
})
.await?;
}
});
let (keepalive_abort_tx, keepalive_abort_rx) = watch::channel(());
let keepalive_handle = tokio::spawn(keepalive_task::task(
self.ctx.clone(),
self.actor_id,
self.shared_state.gateway_id(),
unique_request_id,
keepalive_abort_rx,
));

let res = self.handle_websocket_hibernation_inner(client_ws).await;

keepalive_handle.abort();
let _ = keepalive_abort_tx.send(());
let _ = keepalive_handle.await;

match &res {
Ok(HibernationResult::Continue) => {}
Expand Down
Loading