diff --git a/crates/chat-cli/src/api_client/model.rs b/crates/chat-cli/src/api_client/model.rs index 1a72023f60..ea792b4cec 100644 --- a/crates/chat-cli/src/api_client/model.rs +++ b/crates/chat-cli/src/api_client/model.rs @@ -569,6 +569,9 @@ pub enum ChatResponseStream { conversation_id: Option, utterance_id: Option, }, + MetadataEvent { + usage: Option, + }, SupplementaryWebLinksEvent(()), ToolUseEvent { tool_use_id: String, @@ -581,6 +584,12 @@ pub enum ChatResponseStream { Unknown, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MetadataUsage { + pub input_tokens: Option, + pub output_tokens: Option, +} + impl ChatResponseStream { /// Returns the length of the content of the message event - ie, the number of bytes of content /// contained within the message. @@ -596,6 +605,7 @@ impl ChatResponseStream { ChatResponseStream::IntentsEvent(_) => 0, ChatResponseStream::InvalidStateEvent { .. } => 0, ChatResponseStream::MessageMetadataEvent { .. } => 0, + ChatResponseStream::MetadataEvent { .. } => 0, ChatResponseStream::SupplementaryWebLinksEvent(_) => 0, ChatResponseStream::ToolUseEvent { input, .. 
} => input.as_ref().map(|s| s.len()).unwrap_or_default(), ChatResponseStream::Unknown => 0, @@ -642,6 +652,14 @@ impl From for Ch conversation_id, utterance_id, }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::MetadataEvent(metadata) => { + ChatResponseStream::MetadataEvent { + usage: metadata.token_usage.map(|u| MetadataUsage { + input_tokens: Some(u.uncached_input_tokens as u64), + output_tokens: Some(u.output_tokens as u64), + }), + } + }, amzn_codewhisperer_streaming_client::types::ChatResponseStream::ToolUseEvent( amzn_codewhisperer_streaming_client::types::ToolUseEvent { tool_use_id, @@ -698,6 +716,14 @@ impl From for ChatR conversation_id, utterance_id, }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::MetadataEvent(metadata) => { + ChatResponseStream::MetadataEvent { + usage: metadata.token_usage.map(|u| MetadataUsage { + input_tokens: Some(u.uncached_input_tokens as u64), + output_tokens: Some(u.output_tokens as u64), + }), + } + }, amzn_qdeveloper_streaming_client::types::ChatResponseStream::ToolUseEvent( amzn_qdeveloper_streaming_client::types::ToolUseEvent { tool_use_id, diff --git a/crates/chat-cli/src/cli/chat/cli/clear.rs b/crates/chat-cli/src/cli/chat/cli/clear.rs index b706bbbdb8..39fdbec65a 100644 --- a/crates/chat-cli/src/cli/chat/cli/clear.rs +++ b/crates/chat-cli/src/cli/chat/cli/clear.rs @@ -13,7 +13,6 @@ use crate::cli::chat::{ ChatSession, ChatState, }; -use crate::theme::StyledText; #[deny(missing_docs)] #[derive(Debug, PartialEq, Args)] @@ -24,20 +23,20 @@ impl ClearArgs { pub async fn execute(self, session: &mut ChatSession) -> Result { execute!( session.stderr, - StyledText::secondary_fg(), + style::SetForegroundColor(style::Color::DarkGrey), style::Print( "\nAre you sure? This will erase the conversation history and context from hooks for the current session. 
" ), style::Print("["), - StyledText::success_fg(), + style::SetForegroundColor(style::Color::Green), style::Print("y"), - StyledText::secondary_fg(), + style::SetForegroundColor(style::Color::DarkGrey), style::Print("/"), - StyledText::success_fg(), + style::SetForegroundColor(style::Color::Green), style::Print("n"), - StyledText::secondary_fg(), + style::SetForegroundColor(style::Color::DarkGrey), style::Print("]:\n\n"), - StyledText::reset(), + style::ResetColor, cursor::Show, )?; @@ -60,12 +59,61 @@ impl ClearArgs { execute!( session.stderr, - StyledText::success_fg(), + style::SetForegroundColor(style::Color::Green), style::Print("\nConversation history cleared.\n\n"), - StyledText::reset(), + style::ResetColor, )?; } Ok(ChatState::default()) } } + +#[cfg(test)] +mod tests { + use crossterm::{ + execute, + style, + }; + + #[test] + fn test_clear_prompt_renders_correctly() { + let mut buffer = Vec::new(); + + // Test the actual implementation pattern used in clear command + let result = execute!( + &mut buffer, + style::SetForegroundColor(style::Color::DarkGrey), + style::Print("Test "), + style::Print("["), + style::SetForegroundColor(style::Color::Green), + style::Print("y"), + style::SetForegroundColor(style::Color::DarkGrey), + style::Print("/"), + style::SetForegroundColor(style::Color::Green), + style::Print("n"), + style::SetForegroundColor(style::Color::DarkGrey), + style::Print("]"), + style::ResetColor, + ); + + assert!(result.is_ok()); + + let output = String::from_utf8(buffer).unwrap(); + eprintln!("Output: {:?}", output); + + // Verify the text content is correct + assert!(output.contains("Test"), "Output should contain 'Test'"); + assert!(output.contains("["), "Output should contain '['"); + assert!(output.contains("y"), "Output should contain 'y'"); + assert!(output.contains("/"), "Output should contain '/'"); + assert!(output.contains("n"), "Output should contain 'n'"); + assert!(output.contains("]"), "Output should contain ']'"); + + // Verify 
ANSI escape sequences are present + assert!(output.contains("\x1b["), "Output should contain ANSI escape sequences"); + + // Verify reset code is present + assert!(output.contains("\x1b[0m"), "Output should contain reset code"); + } +} diff --git a/crates/chat-cli/src/cli/chat/cli/hooks.rs b/crates/chat-cli/src/cli/chat/cli/hooks.rs index fa83288c4c..e1a852de31 100644 --- a/crates/chat-cli/src/cli/chat/cli/hooks.rs +++ b/crates/chat-cli/src/cli/chat/cli/hooks.rs @@ -18,6 +18,7 @@ use crossterm::{ queue, terminal, }; +use tracing::warn; use eyre::{ Result, eyre, @@ -123,6 +124,7 @@ impl HookExecutor { cwd: &str, prompt: Option<&str>, tool_context: Option, + mut ctrlc_rx: tokio::sync::broadcast::Receiver<()>, ) -> Result, ChatError> { let mut cached = vec![]; let mut futures = FuturesUnordered::new(); @@ -163,7 +165,10 @@ impl HookExecutor { // Process results as they complete let mut results = vec![]; let start_time = Instant::now(); - while let Some((hook, result, duration)) = futures.next().await { + + tokio::select! 
{ + res = async { + while let Some((hook, result, duration)) = futures.next().await { // If output is enabled, handle that first if let Some(spinner) = spinner.as_mut() { spinner.stop(); @@ -239,6 +244,13 @@ impl HookExecutor { } else { spinner = Some(Spinner::new(Spinners::Dots, spinner_text(complete, total))); } + } + Ok::<(), std::io::Error>(()) + } => { res?; }, + Ok(_) = ctrlc_rx.recv() => { + warn!("🔴 CTRL+C caught in run_hooks, cancelling hook execution"); + return Err(ChatError::Interrupted { tool_uses: None }); + } } drop(futures); diff --git a/crates/chat-cli/src/cli/chat/cli/usage/usage_data_provider.rs b/crates/chat-cli/src/cli/chat/cli/usage/usage_data_provider.rs index aec77e6d8f..e83982ec0f 100644 --- a/crates/chat-cli/src/cli/chat/cli/usage/usage_data_provider.rs +++ b/crates/chat-cli/src/cli/chat/cli/usage/usage_data_provider.rs @@ -18,7 +18,7 @@ pub(super) async fn get_detailed_usage_data( let state = session .conversation - .backend_conversation_state(os, true, &mut std::io::stderr()) + .backend_conversation_state(os, true, &mut std::io::stderr(), tokio::sync::broadcast::channel(1).1) .await?; let data = state.calculate_conversation_size(); diff --git a/crates/chat-cli/src/cli/chat/context.rs b/crates/chat-cli/src/cli/chat/context.rs index 6cb760ac34..57734a656e 100644 --- a/crates/chat-cli/src/cli/chat/context.rs +++ b/crates/chat-cli/src/cli/chat/context.rs @@ -250,12 +250,13 @@ impl ContextManager { os: &crate::os::Os, prompt: Option<&str>, tool_context: Option, + ctrlc_rx: tokio::sync::broadcast::Receiver<()>, ) -> Result, ChatError> { let mut hooks = self.hooks.clone(); hooks.retain(|t, _| *t == trigger); let cwd = os.env.current_dir()?.to_string_lossy().to_string(); self.hook_executor - .run_hooks(hooks, output, &cwd, prompt, tool_context) + .run_hooks(hooks, output, &cwd, prompt, tool_context, ctrlc_rx) .await } } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 
a9c424900b..3c63e3f948 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -149,6 +149,11 @@ pub struct ConversationState { /// Tangent mode checkpoint - stores main conversation when in tangent mode #[serde(default, skip_serializing_if = "Option::is_none")] tangent_state: Option, + /// Cumulative token usage across the conversation + #[serde(default)] + pub cumulative_input_tokens: u64, + #[serde(default)] + pub cumulative_output_tokens: u64, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -212,6 +217,8 @@ impl ConversationState { checkpoint_manager: None, mcp_enabled, tangent_state: None, + cumulative_input_tokens: 0, + cumulative_output_tokens: 0, } } @@ -410,6 +417,16 @@ impl ConversationState { debug_assert!(self.next_message.is_some(), "next_message should exist"); let next_user_message = self.next_message.take().expect("next user message should exist"); + // Track cumulative token usage + if let Some(ref metadata) = request_metadata { + if let Some(input) = metadata.input_tokens { + self.cumulative_input_tokens += input; + } + if let Some(output) = metadata.output_tokens { + self.cumulative_output_tokens += output; + } + } + self.append_assistant_transcript(&message); self.history.push_back(HistoryEntry { user: next_user_message, @@ -427,6 +444,15 @@ impl ConversationState { self.conversation_id.as_ref() } + /// Checks if cumulative token usage exceeds the threshold percentage. + /// Returns true if compaction should be triggered proactively. + pub fn should_compact_proactively(&self, threshold_percent: u8) -> bool { + let context_window = context_window_tokens(self.model_info.as_ref()); + let cumulative_tokens = self.cumulative_input_tokens + self.cumulative_output_tokens; + let threshold = (context_window as u64 * threshold_percent as u64) / 100; + cumulative_tokens >= threshold + } + /// Returns the message id associated with the last assistant message, if present. 
/// /// This is equivalent to `utterance_id` in the Q API. @@ -510,13 +536,14 @@ impl ConversationState { os: &Os, stderr: &mut impl Write, run_perprompt_hooks: bool, + ctrlc_rx: tokio::sync::broadcast::Receiver<()>, ) -> Result { debug_assert!(self.next_message.is_some()); self.enforce_conversation_invariants(); self.history.drain(self.valid_history_range.1..); self.history.drain(..self.valid_history_range.0); - let context = self.backend_conversation_state(os, run_perprompt_hooks, stderr).await?; + let context = self.backend_conversation_state(os, run_perprompt_hooks, stderr, ctrlc_rx).await?; if !context.dropped_context_files.is_empty() { execute!( stderr, @@ -572,6 +599,7 @@ impl ConversationState { os: &Os, run_perprompt_hooks: bool, output: &mut impl Write, + ctrlc_rx: tokio::sync::broadcast::Receiver<()>, ) -> Result, ChatError> { self.update_state(false).await; self.enforce_conversation_invariants(); @@ -587,6 +615,7 @@ impl ConversationState { os, user_prompt, None, // tool_context + ctrlc_rx.resubscribe(), ) .await?; agent_spawn_context = format_hook_context(&agent_spawn, HookTrigger::AgentSpawn); @@ -599,6 +628,7 @@ impl ConversationState { os, next_message.prompt(), None, // tool_context + ctrlc_rx.resubscribe(), ) .await?; if let Some(ctx) = format_hook_context(&per_prompt, HookTrigger::UserPromptSubmit) { @@ -691,7 +721,7 @@ impl ConversationState { summary_content.push_str(CONTEXT_ENTRY_END_HEADER); } - let conv_state = self.backend_conversation_state(os, false, &mut vec![]).await?; + let conv_state = self.backend_conversation_state(os, false, &mut vec![], tokio::sync::broadcast::channel(1).1).await?; let mut summary_message = Some(UserMessage::new_prompt(summary_content.clone(), None)); // Create the history according to the passed compact strategy. 
@@ -872,7 +902,7 @@ Return only the JSON configuration, no additional text.", /// Calculate the total character count in the conversation pub async fn calculate_char_count(&mut self, os: &Os) -> Result { Ok(self - .backend_conversation_state(os, false, &mut vec![]) + .backend_conversation_state(os, false, &mut vec![], tokio::sync::broadcast::channel(1).1) .await? .char_count()) } @@ -1751,4 +1781,222 @@ mod tests { conversation.exit_tangent_mode_with_tail(); assert_eq!(conversation.history.len(), main_history_len); } + + #[test] + fn test_overflow_at_exact_boundary() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + for i in 0..5000 { + history.push_back(HistoryEntry { + user: UserMessage::new_prompt(format!("msg{}", i), None), + assistant: AssistantMessage::new_response(None, format!("resp{}", i)), + request_metadata: None, + }); + } + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let total_messages = (end - start) * 2; + + assert!( + total_messages <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "At 5000 entries (10000 messages), should be at limit but got {}", + total_messages + ); + } + + #[test] + fn test_overflow_past_boundary() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + for i in 0..5001 { + history.push_back(HistoryEntry { + user: UserMessage::new_prompt(format!("msg{}", i), None), + assistant: AssistantMessage::new_response(None, format!("resp{}", i)), + request_metadata: None, + }); + } + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let total_messages = (end - start) * 2; + + println!("History entries: 5001"); + println!("Valid range: {} to {}", start, end); + println!("Valid entries: {}", end - start); + println!("Total messages sent to backend: {}", total_messages); + println!("Limit: {}", MAX_CONVERSATION_STATE_HISTORY_LEN); + 
+ assert!( + total_messages <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "BUG: At 5001 entries, got {} messages (exceeds limit {})", + total_messages, + MAX_CONVERSATION_STATE_HISTORY_LEN + ); + } + + #[test] + fn test_overflow_with_context_buffer() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + for i in 0..4998 { + history.push_back(HistoryEntry { + user: UserMessage::new_prompt(format!("msg{}", i), None), + assistant: AssistantMessage::new_response(None, format!("resp{}", i)), + request_metadata: None, + }); + } + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let total_messages = (end - start) * 2; + let total_with_context = total_messages + 6; + + println!("History entries: 4998"); + println!("Valid range: {} to {}", start, end); + println!("Total messages: {}", total_messages); + println!("Total with 6 context messages: {}", total_with_context); + println!("Limit: {}", MAX_CONVERSATION_STATE_HISTORY_LEN); + + // BUG: This is at EXACTLY the limit with no safety buffer + assert!( + total_with_context < MAX_CONVERSATION_STATE_HISTORY_LEN, + "BUG: At 4998 entries, total with context ({}) should be LESS than limit ({}), not equal. No safety buffer!", + total_with_context, + MAX_CONVERSATION_STATE_HISTORY_LEN + ); + } + + #[test] + fn test_overflow_with_more_context() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + for i in 0..4998 { + history.push_back(HistoryEntry { + user: UserMessage::new_prompt(format!("msg{}", i), None), + assistant: AssistantMessage::new_response(None, format!("resp{}", i)), + request_metadata: None, + }); + } + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let total_messages = (end - start) * 2; + // Real scenarios often have more than 6 context messages (summaries, file context, etc.) 
+ let total_with_realistic_context = total_messages + 10; + + println!("History entries: 4998"); + println!("Total messages: {}", total_messages); + println!("Total with 10 context messages: {}", total_with_realistic_context); + println!("Limit: {}", MAX_CONVERSATION_STATE_HISTORY_LEN); + + assert!( + total_with_realistic_context <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "BUG: With realistic context (10 messages), total {} exceeds limit {}", + total_with_realistic_context, + MAX_CONVERSATION_STATE_HISTORY_LEN + ); + } + + #[test] + fn test_trimming_threshold_detection() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + for i in 0..4997 { + history.push_back(HistoryEntry { + user: UserMessage::new_prompt(format!("msg{}", i), None), + assistant: AssistantMessage::new_response(None, format!("resp{}", i)), + request_metadata: None, + }); + } + + let (start, _) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + println!("At 4997 entries: start={}", start); + assert_eq!(start, 0, "At 4997 entries, no trimming should occur"); + + history.push_back(HistoryEntry { + user: UserMessage::new_prompt("msg4997".to_string(), None), + assistant: AssistantMessage::new_response(None, "resp4997".to_string()), + request_metadata: None, + }); + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + println!("At 4998 entries: start={}, end={}", start, end); + assert!(start > 0, "At 4998 entries, trimming should occur"); + + let total_with_context = ((end - start) * 2) + 6; + println!("Total with context: {}", total_with_context); + + assert!( + total_with_context <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "BUG: After trimming at 4998, total {} exceeds limit {}", + total_with_context, + MAX_CONVERSATION_STATE_HISTORY_LEN + ); + } + + #[test] + fn test_proactive_compaction_threshold_calculation() { + use crate::cli::chat::cli::model::{ModelInfo, 
context_window_tokens}; + + // Test with 100K context window + let model_100k = ModelInfo { + model_id: "test-model".to_string(), + model_name: Some("Test Model".to_string()), + description: None, + context_window_tokens: 100_000, + }; + + let context_window = context_window_tokens(Some(&model_100k)); + let threshold_98 = (context_window as u64 * 98) / 100; + + assert_eq!(context_window, 100_000); + assert_eq!(threshold_98, 98_000); + + // At 97% - should NOT trigger + let tokens_97 = 97_000u64; + assert!(tokens_97 < threshold_98); + + // At exactly 98% - SHOULD trigger + let tokens_98 = 98_000u64; + assert!(tokens_98 >= threshold_98); + + // At 99% - SHOULD trigger + let tokens_99 = 99_000u64; + assert!(tokens_99 >= threshold_98); + } + + #[test] + fn test_no_overflow_with_98_percent_threshold() { + use crate::cli::chat::cli::model::ModelInfo; + + // Claude Sonnet 4 with 200K context + let model = ModelInfo { + model_id: "claude-sonnet-4".to_string(), + model_name: Some("Claude Sonnet 4".to_string()), + description: None, + context_window_tokens: 200_000, + }; + + // 98% of 200K = 196K + let threshold = (model.context_window_tokens as u64 * 98) / 100; + assert_eq!(threshold, 196_000); + + // Verify 98% threshold leaves 2% buffer (4K tokens) + let buffer = model.context_window_tokens as u64 - threshold; + assert_eq!(buffer, 4_000); + + // This buffer should be enough for: + // - Assistant response tokens + // - Tool spec overhead + // - System messages + // - Estimation errors + assert!(buffer >= 2_000, "Buffer should be at least 2K tokens"); + } } diff --git a/crates/chat-cli/src/cli/chat/conversation_overflow_test.rs b/crates/chat-cli/src/cli/chat/conversation_overflow_test.rs new file mode 100644 index 0000000000..0463893030 --- /dev/null +++ b/crates/chat-cli/src/cli/chat/conversation_overflow_test.rs @@ -0,0 +1,159 @@ +#[cfg(test)] +mod overflow_tests { + use std::collections::{HashMap, VecDeque}; + use crate::api_client::model::Tool; + use 
crate::cli::chat::consts::MAX_CONVERSATION_STATE_HISTORY_LEN; + use crate::cli::chat::conversation::{HistoryEntry, enforce_conversation_invariants}; + use crate::cli::chat::message::{AssistantMessage, UserMessage}; + use crate::cli::chat::tools::ToolOrigin; + + fn create_history_entry(content: &str) -> HistoryEntry { + HistoryEntry { + user: UserMessage::new_prompt(content.to_string(), None), + assistant: AssistantMessage::new_response(None, content.to_string()), + request_metadata: None, + } + } + + #[test] + fn test_overflow_at_boundary() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + // Create exactly 5000 history entries (10000 messages when flattened) + for i in 0..5000 { + history.push_back(create_history_entry(&format!("msg{}", i))); + } + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let valid_history_len = end - start; + let total_messages = valid_history_len * 2; + + // BUG: This should fail but doesn't - we're at exactly the limit + assert!( + total_messages <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "Expected total messages {} to be <= {}, but overflow detection didn't trigger", + total_messages, + MAX_CONVERSATION_STATE_HISTORY_LEN + ); + } + + #[test] + #[should_panic(expected = "Expected overflow detection to trigger")] + fn test_overflow_one_past_boundary() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + // Create 5001 history entries (10002 messages when flattened) + for i in 0..5001 { + history.push_back(create_history_entry(&format!("msg{}", i))); + } + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let valid_history_len = end - start; + let total_messages = valid_history_len * 2; + + // BUG: This WILL overflow (10002 > 10000) but trimming happens too late + if total_messages > MAX_CONVERSATION_STATE_HISTORY_LEN { + 
panic!("Expected overflow detection to trigger before sending {} messages (limit: {})", + total_messages, MAX_CONVERSATION_STATE_HISTORY_LEN); + } + } + + #[test] + fn test_overflow_with_context_buffer() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + // Create 4998 entries (9996 messages) - should be safe with 6-message buffer + for i in 0..4998 { + history.push_back(create_history_entry(&format!("msg{}", i))); + } + + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let valid_history_len = end - start; + let total_messages = valid_history_len * 2; + + // With 6 context messages, total would be 9996 + 6 = 10002 (OVERFLOW!) + let total_with_context = total_messages + 6; + assert!( + total_with_context <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "BUG: Total messages with context {} exceeds limit {} - trimming should have occurred earlier", + total_with_context, + MAX_CONVERSATION_STATE_HISTORY_LEN + ); + } + + #[test] + fn test_trimming_threshold() { + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + // Test the exact threshold where trimming should occur + // Current buggy condition: (history.len() * 2) > MAX_CONVERSATION_STATE_HISTORY_LEN - 6 + // This means: history.len() > 4997 + + // At 4997 entries: 4997 * 2 = 9994, which is NOT > 9994, so no trimming + for i in 0..4997 { + history.push_back(create_history_entry(&format!("msg{}", i))); + } + + let initial_len = history.len(); + let (start, end) = enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let trimmed = start > 0; + + assert!( + !trimmed, + "BUG: At {} entries (9994 messages), trimming should not occur yet, but it did", + initial_len + ); + + // At 4998 entries: 4998 * 2 = 9996, which IS > 9994, so trimming occurs + history.push_back(create_history_entry("msg4997")); + let (start, end) = 
enforce_conversation_invariants(&mut history, &mut next_message, &tools); + let trimmed = start > 0; + + assert!( + trimmed, + "At 4998 entries (9996 messages), trimming should occur" + ); + + // But by now we've already exceeded the safe limit with context messages! + let valid_history_len = end - start; + let total_with_context = (valid_history_len * 2) + 6; + assert!( + total_with_context <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "BUG: After trimming, total {} still exceeds limit {}", + total_with_context, + MAX_CONVERSATION_STATE_HISTORY_LEN + ); + } + + #[test] + fn test_proactive_trimming_needed() { + // This test demonstrates what the threshold SHOULD be + let safe_threshold = (MAX_CONVERSATION_STATE_HISTORY_LEN - 100) / 2; // Leave 100 message buffer + + let mut history = VecDeque::new(); + let mut next_message = None; + let tools: HashMap> = HashMap::new(); + + // Fill to the safe threshold + for i in 0..safe_threshold { + history.push_back(create_history_entry(&format!("msg{}", i))); + } + + let total_messages = history.len() * 2; + let buffer_remaining = MAX_CONVERSATION_STATE_HISTORY_LEN - total_messages; + + assert!( + buffer_remaining >= 100, + "Safe threshold should leave at least 100 messages of buffer, but only {} remaining", + buffer_remaining + ); + } +} diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 099c0c8761..d126a4564a 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -22,6 +22,8 @@ mod parser; mod prompt; mod prompt_parser; pub mod server_messenger; +#[cfg(test)] +mod test_blocking; use crate::cli::chat::checkpoint::CHECKPOINT_MESSAGE_MAX_LENGTH; use crate::constants::ui_text; #[cfg(unix)] @@ -783,6 +785,7 @@ impl ChatSession { loop { match ctrl_c().await { Ok(_) => { + warn!("🔴 CTRL+C SIGNAL RECEIVED - broadcasting to {} subscribers", ctrlc_tx.receiver_count()); let _ = ctrlc_tx .send(()) .map_err(|err| error!(?err, "failed to send ctrlc to broadcast 
channel")); @@ -827,6 +830,7 @@ impl ChatSession { let mut ctrl_c_stream = self.ctrlc_rx.resubscribe(); let result = match self.inner.take().expect("state must always be Some") { ChatState::PromptUser { skip_printing_tools } => { + warn!("🟡 Entering PromptUser state"); match (self.interactive, self.tool_uses.is_empty()) { (false, true) => { self.inner = Some(ChatState::Exit); @@ -843,7 +847,10 @@ impl ChatSession { ChatState::HandleInput { input } => { tokio::select! { res = self.handle_input(os, input) => res, - Ok(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: Some(self.tool_uses.clone()) }) + Ok(_) = ctrl_c_stream.recv() => { + warn!("🔴 CTRL+C caught in HandleInput state"); + Err(ChatError::Interrupted { tool_uses: Some(self.tool_uses.clone()) }) + } } }, ChatState::CompactHistory { @@ -858,7 +865,10 @@ impl ChatSession { let tool_uses_clone = self.tool_uses.clone(); tokio::select! { res = self.tool_use_execute(os) => res, - Ok(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: Some(tool_uses_clone) }) + Ok(_) = ctrl_c_stream.recv() => { + warn!("🔴 CTRL+C caught in ExecuteTools state"); + Err(ChatError::Interrupted { tool_uses: Some(tool_uses_clone) }) + } } }, ChatState::ValidateTools { tool_uses } => { @@ -874,6 +884,7 @@ impl ChatSession { tokio::select! { res = self.handle_response(os, conversation_state, request_metadata_clone) => res, Ok(_) = ctrl_c_stream.recv() => { + warn!("🔴 CTRL+C caught in HandleResponseStream state"); debug!(?request_metadata, "ctrlc received"); // Wait for handle_response to finish handling the ctrlc. 
tokio::time::sleep(Duration::from_millis(5)).await; @@ -949,7 +960,7 @@ impl ChatSession { .abandon_tool_use(tool_uses, "The user interrupted the tool execution.".to_string()); let _ = self .conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) + .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe()) .await?; self.conversation.push_assistant_message( os, @@ -1767,7 +1778,7 @@ impl ChatSession { if should_retry { Ok(ChatState::HandleResponseStream( self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) + .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe()) .await?, )) } else { @@ -2112,13 +2123,13 @@ impl ChatSession { } } - self.conversation.append_user_transcript(&user_input); Ok(ChatState::HandleInput { input: user_input }) } async fn handle_input(&mut self, os: &mut Os, mut user_input: String) -> Result { queue!(self.stderr, style::Print('\n'))?; user_input = sanitize_unicode_tags(&user_input); + self.conversation.append_user_transcript(&user_input); let input = user_input.trim(); // handle image path @@ -2347,7 +2358,7 @@ impl ChatSession { let conv_state = self .conversation - .as_sendable_conversation_state(os, &mut self.stderr, true) + .as_sendable_conversation_state(os, &mut self.stderr, true, self.ctrlc_rx.resubscribe()) .await?; self.send_tool_use_telemetry(os).await; @@ -2767,6 +2778,7 @@ impl ChatSession { os, None, Some(tool_context), + self.ctrlc_rx.resubscribe(), ) .await; } @@ -2797,7 +2809,7 @@ impl ChatSession { self.send_tool_use_telemetry(os).await; return Ok(ChatState::HandleResponseStream( self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) + .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe()) .await?, )); } @@ -2815,6 +2827,16 @@ impl ChatSession { state: crate::api_client::model::ConversationState, request_metadata_lock: Arc>>, ) -> Result { + // Check if we 
should proactively compact at 98% token usage + if self.conversation.should_compact_proactively(98) { + warn!("Token usage at 98% threshold, triggering proactive compaction"); + return Ok(ChatState::CompactHistory { + prompt: None, + show_summary: false, + strategy: CompactStrategy::default(), + }); + } + let mut rx = self.send_message(os, state, request_metadata_lock, None).await?; let request_id = rx.request_id().map(String::from); @@ -2964,7 +2986,7 @@ impl ChatSession { self.send_tool_use_telemetry(os).await; return Ok(ChatState::HandleResponseStream( self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) + .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe()) .await?, )); }, @@ -3001,7 +3023,7 @@ impl ChatSession { self.send_tool_use_telemetry(os).await; return Ok(ChatState::HandleResponseStream( self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) + .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe()) .await?, )); }, @@ -3051,7 +3073,7 @@ impl ChatSession { self.send_tool_use_telemetry(os).await; return Ok(ChatState::HandleResponseStream( self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) + .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe()) .await?, )); }, @@ -3256,6 +3278,7 @@ impl ChatSession { os, None, None, + self.ctrlc_rx.resubscribe(), ) .await; } @@ -3366,7 +3389,7 @@ impl ChatSession { return Ok(ChatState::HandleResponseStream( self.conversation - .as_sendable_conversation_state(os, &mut self.stderr, false) + .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe()) .await?, )); } @@ -3393,6 +3416,7 @@ impl ChatSession { os, None, // prompt Some(tool_context), + self.ctrlc_rx.resubscribe(), ) .await?; @@ -3437,7 +3461,7 @@ impl ChatSession { self.conversation.add_tool_results(tool_results); return 
Ok(ChatState::HandleResponseStream(
                 self.conversation
-                    .as_sendable_conversation_state(os, &mut self.stderr, false)
+                    .as_sendable_conversation_state(os, &mut self.stderr, false, self.ctrlc_rx.resubscribe())
                     .await?,
             ));
         }
@@ -3472,7 +3496,7 @@ impl ChatSession {
 
         Ok(ChatState::HandleResponseStream(
             self.conversation
-                .as_sendable_conversation_state(os, &mut self.stderr, true)
+                .as_sendable_conversation_state(os, &mut self.stderr, true, self.ctrlc_rx.resubscribe())
                 .await?,
         ))
     }
@@ -3584,8 +3608,11 @@ impl ChatSession {
                     .unwrap_or_default();
                     ctrl_c = true;
                 },
                 (Ok(None), true) => return None, // Exit if Ctrl+C was pressed twice
-                (Err(_), _) => return None,
+                (Err(e), _) => {
+                    warn!(?e, "failed to read user input");
+                    return None;
+                },
             }
         }
     }
@@ -4696,4 +4723,58 @@ mod tests {
             assert_eq!(actual, *expected, "expected {} for input {}", expected, input);
         }
     }
+
+    /// Hidden unicode characters in the user's input must not end up in the conversation
+    /// transcript: sanitization has to run before the message is appended.
+    #[tokio::test]
+    async fn test_transcript_excludes_hidden_unicode() {
+        let mut os = Os::new().await.unwrap();
+        os.client.set_mock_output(serde_json::json!([
+            ["Response"],
+        ]));
+
+        let agents = get_test_agents(&os).await;
+        let tool_manager = ToolManager::default();
+        let tool_config = serde_json::from_str::<HashMap<String, ToolSpec>>(include_str!("tools/tool_index.json"))
+            .expect("Tools failed to load");
+
+        // Input with a hidden unicode character that should be sanitized
+        let input_with_hidden_char = format!("test{}", '\u{200B}'); // Zero-width space
+
+        let mut session = ChatSession::new(
+            &mut os,
+            "test_conv_id",
+            agents,
+            None,
+            InputSource::new_mock(vec![
+                input_with_hidden_char,
+                "exit".to_string(),
+            ]),
+            false,
+            || Some(80),
+            tool_manager,
+            None,
+            tool_config,
+            true,
+            false,
+            None,
+        )
+        .await
+        .unwrap();
+
+        session.spawn(&mut os).await.unwrap();
+
+        // NOTE(review): the original author marked this test as failing against the code that
+        // appends to the transcript before sanitizing; confirm the ordering fix lands with it.
+        let transcript = &session.conversation.transcript;
+        let user_msg = transcript
+            .iter()
+            .find(|msg| msg.starts_with("> test"))
+            .expect("transcript should contain the user's message");
+
+        assert!(
+            !user_msg.contains('\u{200B}'),
+            "Transcript should not contain hidden unicode characters"
+        );
+    }
 }
diff --git a/crates/chat-cli/src/cli/chat/parser.rs b/crates/chat-cli/src/cli/chat/parser.rs
index eea45d891d..abd0d4ac5f 100644
--- a/crates/chat-cli/src/cli/chat/parser.rs
+++ b/crates/chat-cli/src/cli/chat/parser.rs
@@ -320,6 +320,10 @@ struct ResponseParser {
     received_response_size: usize,
     time_to_first_chunk: Option<Duration>,
     time_between_chunks: Vec<Duration>,
+    /// Input token count taken from a `MetadataEvent` carrying usage, if one was received.
+    input_tokens: Option<u64>,
+    /// Output token count taken from a `MetadataEvent` carrying usage, if one was received.
+    output_tokens: Option<u64>,
 }
 
 impl ResponseParser {
@@ -353,6 +357,8 @@ impl ResponseParser {
             received_response_size: 0,
             time_to_first_chunk: None,
             time_between_chunks: Vec::new(),
+            input_tokens: None,
+            output_tokens: None,
             request_metadata,
             cancel_token,
         }
@@ -612,6 +618,12 @@ impl ResponseParser {
                     ChatResponseStream::ToolUseEvent { input, .. } => {
                         self.received_response_size += input.as_ref().map(String::len).unwrap_or_default();
                     },
+                    ChatResponseStream::MetadataEvent { usage } => {
+                        if let Some(u) = usage {
+                            self.input_tokens = u.input_tokens;
+                            self.output_tokens = u.output_tokens;
+                        }
+                    },
                     _ => {
                         warn!(?r, "received unexpected event from the response stream");
                     },
@@ -659,6 +671,8 @@ impl ResponseParser {
                 .map(|t| (t.id.clone(), t.name.clone()))
                 .collect::<_>(),
             model_id: self.model_id.clone(),
+            input_tokens: self.input_tokens,
+            output_tokens: self.output_tokens,
         }
     }
 }
@@ -710,6 +724,10 @@ pub struct RequestMetadata {
     pub model_id: Option<String>,
     /// Meta tags for the request.
     pub message_meta_tags: Vec<MessageMetaTag>,
+    /// Input tokens recorded for this request; `None` when no usage metadata was received.
+    pub input_tokens: Option<u64>,
+    /// Output tokens recorded for this request; `None` when no usage metadata was received.
+    pub output_tokens: Option<u64>,
 }
 
 fn system_time_to_unix_ms(time: SystemTime) -> u64 {
diff --git a/crates/chat-cli/src/cli/chat/test_blocking.rs b/crates/chat-cli/src/cli/chat/test_blocking.rs
new file mode 100644
index 0000000000..39e4943dcb
--- /dev/null
+++ b/crates/chat-cli/src/cli/chat/test_blocking.rs
@@ -0,0 +1,141 @@
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+    use std::sync::atomic::{AtomicBool, Ordering};
+    use std::time::Duration;
+    use tokio::sync::broadcast;
+
+    /// Verifies that `tokio::select!` still observes a Ctrl+C broadcast while a blocking call
+    /// (standing in for `rustyline`'s readline) runs via `spawn_blocking` rather than directly
+    /// on the async task.
+    #[tokio::test]
+    async fn test_spawn_blocking_allows_ctrlc_handling() {
+        let (ctrlc_tx, mut ctrlc_rx) = broadcast::channel::<()>(1);
+        let ctrlc_handled = Arc::new(AtomicBool::new(false));
+        let ctrlc_handled_clone = ctrlc_handled.clone();
+
+        // Spawn a task that simulates the blocking readline call
+        let blocking_task = tokio::spawn(async move {
+            tokio::select! {
+                // This simulates a blocking operation like rustyline::readline()
+                _ = tokio::task::spawn_blocking(|| {
+                    // Block for 2 seconds to simulate readline waiting for input
+                    std::thread::sleep(Duration::from_secs(2));
+                }) => {
+                    // If we get here, the blocking operation completed
+                },
+                // This should fire when Ctrl+C is pressed
+                Ok(_) = ctrlc_rx.recv() => {
+                    // If we get here, Ctrl+C was processed successfully
+                    ctrlc_handled_clone.store(true, Ordering::SeqCst);
+                },
+            }
+        });
+
+        // Wait a bit for the task to start blocking
+        tokio::time::sleep(Duration::from_millis(100)).await;
+
+        // Simulate Ctrl+C being pressed
+        let _ = ctrlc_tx.send(());
+
+        // Wait a bit to see if the signal was processed
+        tokio::time::sleep(Duration::from_millis(200)).await;
+
+        // The flag is set only by the Ctrl+C branch of the select! above.
+        assert!(
+            ctrlc_handled.load(Ordering::SeqCst),
+            "Ctrl+C should be processed even during blocking operations when using spawn_blocking"
+        );
+
+        // Clean up
+        blocking_task.abort();
+    }
+
+    /// Baseline for the `run_hooks` scenario: a task that never polls the Ctrl+C channel keeps
+    /// running to completion after Ctrl+C is broadcast.
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_hooks_without_ctrlc_handling() {
+        let (ctrlc_tx, _ctrlc_rx) = broadcast::channel::<()>(1);
+        let hook_completed = Arc::new(AtomicBool::new(false));
+        let hook_completed_clone = hook_completed.clone();
+
+        // Simulate hook execution that takes time but has NO Ctrl+C handling
+        let hook_task = tokio::spawn(async move {
+            // Simulate a hook that takes 1 second
+            tokio::time::sleep(Duration::from_secs(1)).await;
+            hook_completed_clone.store(true, Ordering::SeqCst);
+        });
+
+        // Simulate Ctrl+C being pressed shortly after the hook starts
+        tokio::time::sleep(Duration::from_millis(50)).await;
+        let _ = ctrlc_tx.send(());
+
+        // Wait a bit to see if hook responds to Ctrl+C (it won't)
+        tokio::time::sleep(Duration::from_millis(200)).await;
+
+        // Hook should still be running because it doesn't check for Ctrl+C
+        assert!(
+            !hook_completed.load(Ordering::SeqCst),
+            "Hook should still be running - doesn't respond to Ctrl+C"
+        );
+
+        // Wait for hook to actually complete
+        let _ = hook_task.await;
+
+        // Now it should be done
+        assert!(
+            hook_completed.load(Ordering::SeqCst),
+            "Hook eventually completes but ignored Ctrl+C - this is the bug"
+        );
+    }
+
+    /// Shows the intended pattern: hook execution is raced against the Ctrl+C channel and the
+    /// in-flight hook tasks are aborted when Ctrl+C wins.
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_hooks_with_ctrlc_handling() {
+        let (ctrlc_tx, mut ctrlc_rx) = broadcast::channel::<()>(1);
+
+        // Simulate hook execution WITH Ctrl+C handling
+        let hook_task = tokio::spawn(async move {
+            let mut hooks = vec![];
+            for _ in 0..3 {
+                hooks.push(tokio::spawn(async {
+                    tokio::time::sleep(Duration::from_secs(2)).await;
+                    "hook result"
+                }));
+            }
+
+            // Race hook completion against Ctrl+C
+            tokio::select! {
+                _ = async {
+                    for hook in hooks.iter_mut() {
+                        let _ = hook.await;
+                    }
+                } => Ok("Completed"),
+                Ok(_) = ctrlc_rx.recv() => {
+                    // Dropping a JoinHandle only detaches its task, so abort each hook explicitly
+                    for hook in &hooks {
+                        hook.abort();
+                    }
+                    Err("Cancelled by Ctrl+C")
+                },
+            }
+        });
+
+        // Simulate Ctrl+C being pressed while hooks are running
+        tokio::time::sleep(Duration::from_millis(100)).await;
+        let _ = ctrlc_tx.send(());
+
+        // Hook task should complete quickly because it handles Ctrl+C
+        let result = tokio::time::timeout(Duration::from_millis(500), hook_task).await;
+
+        assert!(
+            result.is_ok(),
+            "Hook execution should complete quickly when Ctrl+C is handled"
+        );
+        assert!(
+            result.unwrap().unwrap().is_err(),
+            "Hook should return error indicating cancellation"
+        );
+    }
+}