Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/artifacts/errors/actor.not_found.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 27 additions & 32 deletions engine/packages/actor-kv/tests/list_edge_cases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ async fn test_list_edge_cases() -> Result<()> {
None,
)
.await?;
assert_eq!(no_match.len(), 0, "should return empty for non-matching prefix");
assert_eq!(
no_match.len(),
0,
"should return empty for non-matching prefix"
);

// Test 3: Range where start > end (should return empty)
tracing::info!("test 3: range where start > end");
Expand Down Expand Up @@ -137,8 +141,7 @@ async fn test_list_edge_cases() -> Result<()> {
)
.await?;

let (null_keys, null_values, _) =
kv::get(db, actor_id, vec![null_key.clone()]).await?;
let (null_keys, null_values, _) = kv::get(db, actor_id, vec![null_key.clone()]).await?;
assert_eq!(null_keys.len(), 1, "should retrieve key with null byte");
assert_eq!(null_values[0], b"null_value");

Expand Down Expand Up @@ -191,9 +194,7 @@ async fn test_list_edge_cases() -> Result<()> {
let (empty_prefix, _, _) = kv::list(
db,
actor_id,
rp::KvListQuery::KvListPrefixQuery(rp::KvListPrefixQuery {
key: vec![],
}),
rp::KvListQuery::KvListPrefixQuery(rp::KvListPrefixQuery { key: vec![] }),
false,
None,
)
Expand All @@ -204,13 +205,7 @@ async fn test_list_edge_cases() -> Result<()> {

// Test 8: Prefix longer than any stored key
tracing::info!("test 8: prefix longer than stored keys");
kv::put(
db,
actor_id,
vec![b"ab".to_vec()],
vec![b"val".to_vec()],
)
.await?;
kv::put(db, actor_id, vec![b"ab".to_vec()], vec![b"val".to_vec()]).await?;

let (long_prefix, _, _) = kv::list(
db,
Expand Down Expand Up @@ -295,14 +290,26 @@ async fn test_list_edge_cases() -> Result<()> {
)
.await?;

let (zero_limit, _, _) =
kv::list(db, actor_id, rp::KvListQuery::KvListAllQuery, false, Some(0)).await?;
let (zero_limit, _, _) = kv::list(
db,
actor_id,
rp::KvListQuery::KvListAllQuery,
false,
Some(0),
)
.await?;
assert_eq!(zero_limit.len(), 0, "limit of 0 should return empty");

// Test 11: Limit of 1
tracing::info!("test 11: limit of 1");
let (one_limit, _, _) =
kv::list(db, actor_id, rp::KvListQuery::KvListAllQuery, false, Some(1)).await?;
let (one_limit, _, _) = kv::list(
db,
actor_id,
rp::KvListQuery::KvListAllQuery,
false,
Some(1),
)
.await?;
assert_eq!(one_limit.len(), 1, "limit of 1 should return 1 key");

// Test 12: Limit larger than total keys
Expand All @@ -328,18 +335,8 @@ async fn test_list_edge_cases() -> Result<()> {
kv::put(
db,
actor_id,
vec![
b"a".to_vec(),
b"b".to_vec(),
b"c".to_vec(),
b"d".to_vec(),
],
vec![
b"1".to_vec(),
b"2".to_vec(),
b"3".to_vec(),
b"4".to_vec(),
],
vec![b"a".to_vec(), b"b".to_vec(), b"c".to_vec(), b"d".to_vec()],
vec![b"1".to_vec(), b"2".to_vec(), b"3".to_vec(), b"4".to_vec()],
)
.await?;

Expand All @@ -359,9 +356,7 @@ async fn test_list_edge_cases() -> Result<()> {
let (prefix_reverse, _, _) = kv::list(
db,
actor_id,
rp::KvListQuery::KvListPrefixQuery(rp::KvListPrefixQuery {
key: vec![],
}),
rp::KvListQuery::KvListPrefixQuery(rp::KvListPrefixQuery { key: vec![] }),
true,
None,
)
Expand Down
22 changes: 0 additions & 22 deletions engine/packages/guard/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,28 +38,6 @@ pub struct WrongAddrProtocol {
pub received: &'static str,
}

#[derive(RivetError, Serialize)]
#[error(
"guard",
"actor_not_found",
"Actor not found.",
"Actor with ID {actor_id} not found."
)]
pub struct ActorNotFound {
pub actor_id: Id,
}

#[derive(RivetError, Serialize)]
#[error(
"guard",
"actor_destroyed",
"Actor destroyed.",
"Actor {actor_id} was destroyed."
)]
pub struct ActorDestroyed {
pub actor_id: Id,
}

#[derive(RivetError, Serialize)]
#[error(
"guard",
Expand Down
4 changes: 2 additions & 2 deletions engine/packages/guard/src/routing/pegboard_gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,11 @@ async fn route_request_inner(
.op(pegboard::ops::actor::get_for_gateway::Input { actor_id })
.await?
else {
return Err(errors::ActorNotFound { actor_id }.build());
return Err(pegboard::errors::Actor::NotFound.build());
};

if actor.destroyed {
return Err(errors::ActorDestroyed { actor_id }.build());
return Err(pegboard::errors::Actor::NotFound.build());
}

// Wake actor if sleeping
Expand Down
14 changes: 12 additions & 2 deletions engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ use crate::shared_state::{InFlightRequestHandle, SharedState};

pub mod shared_state;

const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(2);
const WEBSOCKET_OPEN_TIMEOUT: Duration = Duration::from_secs(15);
const TUNNEL_ACK_TIMEOUT: Duration = Duration::from_secs(5);

#[derive(RivetError, Serialize, Deserialize)]
#[error(
Expand Down Expand Up @@ -230,7 +231,7 @@ impl CustomServeTrait for PegboardGateway {

Err(ServiceUnavailable.build())
};
let response_start = tokio::time::timeout(TUNNEL_ACK_TIMEOUT, fut)
let response_start = tokio::time::timeout(WEBSOCKET_OPEN_TIMEOUT, fut)
.await
.map_err(|_| {
tracing::warn!("timed out waiting for tunnel ack");
Expand Down Expand Up @@ -412,6 +413,11 @@ impl CustomServeTrait for PegboardGateway {
client_ws.send(msg).await?;
}
protocol::ToServerTunnelMessageKind::ToServerWebSocketMessageAck(ack) => {
tracing::debug!(
request_id=?Uuid::from_bytes(request_id),
ack_index=?ack.index,
"received WebSocketMessageAck from runner"
);
shared_state
.ack_pending_websocket_messages(request_id, ack.index)
.await?;
Expand Down Expand Up @@ -617,6 +623,10 @@ impl CustomServeTrait for PegboardGateway {
.has_pending_websocket_messages(unique_request_id.into_bytes())
.await?
{
tracing::debug!(
?unique_request_id,
"detected pending requests on websocket hibernation, rewaking actor"
);
return Ok(HibernationResult::Continue);
}

Expand Down
74 changes: 58 additions & 16 deletions engine/packages/pegboard-gateway/src/shared_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

const GC_INTERVAL: Duration = Duration::from_secs(15);
const MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(30);
const MAX_PENDING_MSGS_SIZE_PER_REQ: u64 = util::size::mebibytes(1);
const HWS_MESSAGE_ACK_TIMEOUT: Duration = Duration::from_secs(30);
const HWS_MAX_PENDING_MSGS_SIZE_PER_REQ: u64 = util::size::mebibytes(1);

pub struct InFlightRequestHandle {
pub msg_rx: mpsc::Receiver<protocol::ToServerTunnelMessageKind>,
Expand Down Expand Up @@ -185,7 +186,7 @@
if let (Some(hs), Some(ws_msg_index)) = (&mut req.hibernation_state, ws_msg_index) {
hs.total_pending_ws_msgs_size += message_serialized.len() as u64;

if hs.total_pending_ws_msgs_size > MAX_PENDING_MSGS_SIZE_PER_REQ
if hs.total_pending_ws_msgs_size > HWS_MAX_PENDING_MSGS_SIZE_PER_REQ
|| hs.pending_ws_msgs.len() >= u16::MAX as usize
{
return Err(WebsocketPendingLimitReached {}.build());
Expand Down Expand Up @@ -230,40 +231,54 @@
let Some(mut in_flight) =
self.in_flight_requests.get_async(&msg.request_id).await
else {
tracing::debug!(
tracing::warn!(
request_id=?Uuid::from_bytes(msg.request_id),
"in flight has already been disconnected"
message_id=?Uuid::from_bytes(msg.message_id),
"in flight has already been disconnected, cannot ack message"
);
continue;
};

if let protocol::ToServerTunnelMessageKind::TunnelAck = &msg.message_kind {
let prev_len = in_flight.pending_msgs.len();

tracing::debug!(message_id=?Uuid::from_bytes(msg.message_id), "received tunnel ack");

in_flight
.pending_msgs
.retain(|m| m.message_id != msg.message_id);

if prev_len == in_flight.pending_msgs.len() {
tracing::warn!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"pending message does not exist or ack received after message body"
)
} else {
tracing::debug!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"received TunnelAck, removed from pending"
);
}
} else {
// Send message to the request handler to emulate the real network action
tracing::debug!(
request_id=?Uuid::from_bytes(msg.request_id),
message_id=?Uuid::from_bytes(msg.message_id),
"forwarding message to request handler"
);
let _ = in_flight.msg_tx.send(msg.message_kind).await;
let _ = in_flight.msg_tx.send(msg.message_kind.clone()).await;

// Send ack back to runner
let ups_clone = self.ups.clone();
let receiver_subject = in_flight.receiver_subject.clone();
let request_id = msg.request_id;
let message_id = msg.message_id;
let ack_message = protocol::ToClient::ToClientTunnelMessage(
protocol::ToClientTunnelMessage {
request_id: msg.request_id,
message_id: msg.message_id,
request_id,
message_id,
gateway_reply_to: None,
message_kind: protocol::ToClientTunnelMessageKind::TunnelAck,
},
Expand All @@ -279,15 +294,29 @@
}
};
tokio::spawn(async move {
if let Err(err) = ups_clone
match ups_clone
.publish(
&receiver_subject,
&ack_message_serialized,
PublishOpts::one(),
)
.await
{
tracing::warn!(?err, "failed to ack message")
Ok(_) => {
tracing::debug!(
request_id=?Uuid::from_bytes(request_id),
message_id=?Uuid::from_bytes(message_id),
"sent TunnelAck to runner"
);
}
Err(err) => {
tracing::warn!(
?err,
request_id=?Uuid::from_bytes(request_id),
message_id=?Uuid::from_bytes(message_id),
"failed to send TunnelAck to runner"
);
}
}
});
}
Expand Down Expand Up @@ -366,11 +395,15 @@
};

let Some(hs) = &mut req.hibernation_state else {
tracing::warn!("cannot ack ws messages, hibernation is not enabled");
tracing::warn!(
request_id=?Uuid::from_bytes(request_id),
"cannot ack ws messages, hibernation is not enabled"
);
return Ok(());
};

let len = hs.pending_ws_msgs.len().try_into()?;
let len_before = hs.pending_ws_msgs.len();
let len = len_before.try_into()?;
let mut iter_index = 0u16;
hs.pending_ws_msgs.retain(|_| {
let msg_index = hs
Expand All @@ -385,6 +418,15 @@
keep
});

let len_after = hs.pending_ws_msgs.len();
tracing::debug!(
request_id=?Uuid::from_bytes(request_id),
ack_index,
removed_count=len_before - len_after,
remaining_count=len_after,
"acked pending websocket messages"
);

Ok(())
}

Expand Down Expand Up @@ -425,9 +467,9 @@
/// Gateway channel is closed and there are no pending messages
GatewayClosed,
/// Any tunnel message not acked (TunnelAck)
MessageNotAcked,
MessageNotAcked { message_id: Uuid },

Check failure on line 470 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `message_id` is never read

Check failure on line 470 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `message_id` is never read

Check failure on line 470 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `message_id` is never read
/// WebSocket pending messages (ToServerWebSocketMessageAck)
WebSocketMessageNotAcked,
WebSocketMessageNotAcked { last_ws_msg_index: u16 },

Check failure on line 472 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `last_ws_msg_index` is never read

Check failure on line 472 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `last_ws_msg_index` is never read

Check failure on line 472 in engine/packages/pegboard-gateway/src/shared_state.rs

View workflow job for this annotation

GitHub Actions / Check

field `last_ws_msg_index` is never read
}

let now = Instant::now();
Expand Down Expand Up @@ -460,17 +502,17 @@
if now.duration_since(earliest_pending_msg.send_instant)
<= MESSAGE_ACK_TIMEOUT
{
break 'reason Some(MsgGcReason::MessageNotAcked);
break 'reason Some(MsgGcReason::MessageNotAcked{message_id:Uuid::from_bytes(earliest_pending_msg.message_id)});
}
}

if let Some(hs) = &req.hibernation_state
&& let Some(earliest_pending_ws_msg) = hs.pending_ws_msgs.first()
{
if now.duration_since(earliest_pending_ws_msg.send_instant)
<= MESSAGE_ACK_TIMEOUT
<= HWS_MESSAGE_ACK_TIMEOUT
{
break 'reason Some(MsgGcReason::WebSocketMessageNotAcked);
break 'reason Some(MsgGcReason::WebSocketMessageNotAcked{last_ws_msg_index: hs.last_ws_msg_index});
}
}

Expand Down
2 changes: 1 addition & 1 deletion engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ async fn handle_tunnel_message(
if let Some(req) = active_requests.get(&request_id) {
req.gateway_reply_to.clone()
} else {
tracing::warn!("no active request for tunnel message, may have timed out");
tracing::warn!(request_id=?Uuid::from_bytes(msg.request_id), message_id=?Uuid::from_bytes(msg.message_id), "no active request for tunnel message, may have timed out");
return Ok(());
}
};
Expand Down
Loading
Loading