diff --git a/Cargo.lock b/Cargo.lock index 57bbd15efc8140..216da9d00528cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2795,6 +2795,67 @@ dependencies = [ "memchr", ] +[[package]] +name = "cherrypick_agent" +version = "0.1.0" +dependencies = [ + "async-trait", + "chrono", + "futures 0.3.31", + "git2", + "notify 7.0.0", + "rusqlite", + "serde", + "serde_json", + "sha2", + "tempfile", + "thiserror 2.0.17", + "tiktoken-rs", + "tokio", + "tokio-rusqlite", + "zed-reqwest", +] + +[[package]] +name = "cherrypick_pr" +version = "0.1.0" +dependencies = [ + "chrono", + "git2", + "lru", + "notify 7.0.0", + "rusqlite", + "serde", + "serde_json", + "sha2", + "tempfile", + "thiserror 2.0.17", + "tokio", + "tokio-rusqlite", +] + +[[package]] +name = "cherrypick_ui" +version = "0.1.0" +dependencies = [ + "anyhow", + "askpass", + "cherrypick_agent", + "cherrypick_pr", + "git", + "git2", + "git_graph", + "git_ui", + "gpui", + "log", + "project", + "settings", + "sha2", + "tokio", + "ui", + "workspace", +] + [[package]] name = "chrono" version = "0.4.42" @@ -6150,6 +6211,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.16.2" @@ -8138,6 +8205,15 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "hashlink" version = "0.10.0" @@ -8892,6 +8968,17 @@ dependencies = [ "libc", ] +[[package]] +name = "inotify" +version = "0.10.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd168d97690d0b8c412d6b6c10360277f4d7ee495c5d0d5d5fe0854923255cc" +dependencies = [ + "bitflags 1.3.2", + "inotify-sys", + "libc", +] + [[package]] name = "inotify" version = "0.11.0" @@ -11202,6 +11289,25 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "notify" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c533b4c39709f9ba5005d8002048266593c1cfaf3c5f0739d5b8ab0c6c504009" +dependencies = [ + "bitflags 2.10.0", + "filetime", + "fsevent-sys", + "inotify 0.10.2", + "kqueue", + "libc", + "log", + "mio 1.1.0", + "notify-types 1.0.1", + "walkdir", + "windows-sys 0.52.0", +] + [[package]] name = "notify" version = "8.2.0" @@ -11214,7 +11320,7 @@ dependencies = [ "libc", "log", "mio 1.1.0", - "notify-types", + "notify-types 2.0.0", "walkdir", "windows-sys 0.60.2", ] @@ -11230,6 +11336,15 @@ dependencies = [ "notify 6.1.1", ] +[[package]] +name = "notify-types" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "585d3cb5e12e01aed9e8a1f70d5c6b5e86fe2a6e48fc8cd0b3e0b8df6f6eb174" +dependencies = [ + "instant", +] + [[package]] name = "notify-types" version = "2.0.0" @@ -15117,6 +15232,20 @@ dependencies = [ "zeromq", ] +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.10.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink 0.9.1", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rust-embed" version = "8.11.0" @@ -18223,6 +18352,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rusqlite" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65501378eb676f400c57991f42cbd0986827ab5c5200c53f206d710fb32a945" +dependencies = [ + 
"crossbeam-channel", + "rusqlite", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -22380,6 +22520,7 @@ dependencies = [ "breadcrumbs", "call", "channel", + "cherrypick_ui", "chrono", "clap", "cli", diff --git a/Cargo.toml b/Cargo.toml index 2fac513d4ee395..5b13c5b5b162c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,9 @@ members = [ "crates/buffer_diff", "crates/call", "crates/channel", + "crates/cherrypick_ui", + "crates/cherrypick_pr", + "crates/cherrypick_agent", "crates/cli", "crates/client", "crates/clock", @@ -284,6 +287,9 @@ breadcrumbs = { path = "crates/breadcrumbs" } buffer_diff = { path = "crates/buffer_diff" } call = { path = "crates/call" } channel = { path = "crates/channel" } +cherrypick_ui = { path = "crates/cherrypick_ui" } +cherrypick_pr = { path = "crates/cherrypick_pr" } +cherrypick_agent = { path = "crates/cherrypick_agent" } cli = { path = "crates/cli" } client = { path = "crates/client" } clock = { path = "crates/clock" } @@ -614,6 +620,7 @@ linkify = "0.10.0" libwebrtc = "0.3.26" livekit = { version = "0.7.32", features = ["tokio", "rustls-tls-native-roots"] } log = { version = "0.4.16", features = ["kv_unstable_serde", "serde"] } +lru = "0.12" lsp-types = { git = "https://github.com/zed-industries/lsp-types", rev = "f4dfa89a21ca35cd929b70354b1583fabae325f8" } mach2 = "0.5" markup5ever_rcdom = "0.3.0" @@ -623,6 +630,7 @@ moka = { version = "0.12.10", features = ["sync"] } nanoid = "0.4" nbformat = "1.2.0" nix = "0.29" +notify = "7" nucleo = "0.5" num-format = "0.4.4" objc = "0.2" @@ -699,6 +707,7 @@ rsa = "0.9.6" runtimelib = { version = "1.4.0", default-features = false, features = [ "async-dispatcher-runtime", "aws-lc-rs" ] } +rusqlite = { version = "0.32", features = ["bundled"] } rust-embed = { version = "8.11", features = ["include-exclude"] } rustc-hash = "2.1.0" rustls = { version = "0.23.26" } @@ -733,6 +742,7 @@ sysinfo = "0.37.0" take-until = "0.2.0" tempfile = "3.20.0" thiserror = "2.0.12" 
+tiktoken-rs = { git = "https://github.com/zed-industries/tiktoken-rs", rev = "2570c4387a8505fb8f1d3f3557454b474f1e8271" } time = { version = "0.3", features = [ "macros", "parsing", @@ -743,6 +753,7 @@ time = { version = "0.3", features = [ ] } tiny_http = "0.12" tokio = { version = "1" } +tokio-rusqlite = "0.6" tokio-socks = { version = "0.5.2", default-features = false, features = [ "futures-io", "tokio", diff --git a/crates/cherrypick_agent/Cargo.toml b/crates/cherrypick_agent/Cargo.toml new file mode 100644 index 00000000000000..d06a21aa1ac12f --- /dev/null +++ b/crates/cherrypick_agent/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "cherrypick_agent" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/lib.rs" + +[dependencies] +git2.workspace = true +tokio = { workspace = true, features = ["rt", "sync", "macros", "time", "process", "io-util"] } +thiserror.workspace = true +serde.workspace = true +serde_json.workspace = true +reqwest.workspace = true +async-trait.workspace = true +futures.workspace = true +chrono.workspace = true +rusqlite.workspace = true +tokio-rusqlite.workspace = true +tiktoken-rs.workspace = true +notify.workspace = true +sha2.workspace = true + +[dev-dependencies] +tempfile.workspace = true +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/crates/cherrypick_agent/src/chat/mod.rs b/crates/cherrypick_agent/src/chat/mod.rs new file mode 100644 index 00000000000000..2e8c6c83df6ed2 --- /dev/null +++ b/crates/cherrypick_agent/src/chat/mod.rs @@ -0,0 +1,295 @@ +pub mod store; + +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::{mpsc, watch}; + +use crate::context::ContextEngine; +use crate::error::{AgentError, Result}; +use crate::provider::LlmProvider; +use crate::provider::types::{ + CompletionRequest, Message, MessageContent, RiskLevel, StreamChunk, ToolCall, +}; +use 
crate::tools::ToolExecutor; + +const MAX_ITERATIONS: u32 = 10; +const MAX_DURATION: Duration = Duration::from_secs(120); +const SYSTEM_POLICY: &str = r#"You are CherryPick AI, a git-aware coding assistant integrated into the CherryPick git client. You help users understand their repositories, review changes, write commits, and manage branches. + +Safety rules (immutable): +- Never force-push without explicit user confirmation +- Never delete branches without explicit user confirmation +- Never modify files outside the repository working directory +- Never read or expose sensitive files (.env, credentials, keys) +- Always explain what you're about to do before executing write operations"#; + +#[derive(Debug, Clone)] +pub enum AgentEvent { + TextDelta(String), + ToolCallStarted { + id: String, + name: String, + risk_level: RiskLevel, + }, + ToolCallCompleted { + id: String, + result: String, + is_error: bool, + }, + ConfirmationNeeded { + tool_name: String, + risk_level: RiskLevel, + preview: String, + }, + TurnComplete, + Error(String), +} + +pub struct AgentConfig { + pub model: String, + pub max_tokens: u32, + pub temperature: Option, + pub max_iterations: u32, + pub max_duration: Duration, +} + +impl Default for AgentConfig { + fn default() -> Self { + Self { + model: "claude-sonnet-4-20250514".to_string(), + max_tokens: 4096, + temperature: None, + max_iterations: MAX_ITERATIONS, + max_duration: MAX_DURATION, + } + } +} + +pub struct AgentService { + provider: Arc, + tool_executor: ToolExecutor, + context_engine: ContextEngine, + config: AgentConfig, + history: Vec, +} + +impl AgentService { + pub fn new( + provider: Arc, + tool_executor: ToolExecutor, + context_engine: ContextEngine, + config: AgentConfig, + ) -> Self { + Self { + provider, + tool_executor, + context_engine, + config, + history: Vec::new(), + } + } + + pub async fn send_message( + &mut self, + user_message: &str, + repo_path: &Path, + event_tx: mpsc::UnboundedSender, + cancel: watch::Receiver, 
+ ) -> Result<()> { + self.history.push(Message::user(user_message)); + + let tool_defs = self.tool_executor.definitions(); + let start = std::time::Instant::now(); + let mut iterations = 0u32; + + loop { + if *cancel.borrow() { + return Err(AgentError::Cancelled); + } + + if iterations >= self.config.max_iterations { + return Err(AgentError::MaxIterations(self.config.max_iterations)); + } + + if start.elapsed() > self.config.max_duration { + return Err(AgentError::MaxDuration); + } + + iterations += 1; + + let truncated = ContextEngine::truncate_history( + &self.history, + self.context_engine.budget().warm, + ); + + let request = CompletionRequest { + model: self.config.model.clone(), + messages: truncated, + system: Some(SYSTEM_POLICY.to_string()), + tools: tool_defs.clone(), + max_tokens: self.config.max_tokens, + temperature: self.config.temperature, + }; + + let (chunk_tx, mut chunk_rx) = mpsc::unbounded_channel(); + + let provider = self.provider.clone(); + let provider_handle = tokio::spawn(async move { + provider.stream_completion(request, chunk_tx).await + }); + + let mut text_buffer = String::new(); + let mut tool_calls: Vec = Vec::new(); + let mut current_tool_id = String::new(); + let mut current_tool_name = String::new(); + let mut current_tool_args = String::new(); + + while let Some(chunk) = chunk_rx.recv().await { + if *cancel.borrow() { + return Err(AgentError::Cancelled); + } + + match chunk { + StreamChunk::TextDelta(text) => { + text_buffer.push_str(&text); + let _ = event_tx.send(AgentEvent::TextDelta(text)); + } + StreamChunk::ToolCallStart { id, name } => { + current_tool_id = id; + current_tool_name = name.clone(); + current_tool_args.clear(); + let risk = self + .tool_executor + .risk_level(&name) + .unwrap_or(RiskLevel::ReadOnly); + let _ = event_tx.send(AgentEvent::ToolCallStarted { + id: current_tool_id.clone(), + name, + risk_level: risk, + }); + } + StreamChunk::ToolCallDelta(json) => { + current_tool_args.push_str(&json); + } + 
StreamChunk::ToolCallEnd => { + let args: serde_json::Value = + serde_json::from_str(¤t_tool_args).unwrap_or_default(); + tool_calls.push(ToolCall { + id: current_tool_id.clone(), + name: current_tool_name.clone(), + arguments: args, + }); + } + StreamChunk::Done => break, + StreamChunk::Error(e) => { + let _ = event_tx.send(AgentEvent::Error(e.clone())); + return Err(AgentError::Provider(e)); + } + StreamChunk::Usage(_) => {} + } + } + + let _ = provider_handle.await; + + let mut assistant_content = Vec::new(); + if !text_buffer.is_empty() { + assistant_content.push(MessageContent::Text { + text: text_buffer.clone(), + }); + } + for tc in &tool_calls { + assistant_content.push(MessageContent::ToolUse { + id: tc.id.clone(), + name: tc.name.clone(), + input: tc.arguments.clone(), + }); + } + + if !assistant_content.is_empty() { + self.history.push(Message { + role: crate::provider::types::Role::Assistant, + content: assistant_content, + }); + } + + if tool_calls.is_empty() { + let _ = event_tx.send(AgentEvent::TurnComplete); + return Ok(()); + } + + for tc in &tool_calls { + let risk = self + .tool_executor + .risk_level(&tc.name) + .unwrap_or(RiskLevel::ReadOnly); + + if risk.requires_confirmation() { + let preview = self + .tool_executor + .preview(tc) + .unwrap_or_else(|| tc.name.clone()); + let _ = event_tx.send(AgentEvent::ConfirmationNeeded { + tool_name: tc.name.clone(), + risk_level: risk, + preview, + }); + // For now, auto-approve in the agentic loop. + // In the full UI implementation, this would wait for user input + // via a oneshot channel before proceeding. 
+ } + + let result = self.tool_executor.execute(tc, repo_path).await; + let (output, is_error) = match result { + Ok(output) => { + let _ = event_tx.send(AgentEvent::ToolCallCompleted { + id: tc.id.clone(), + result: output.clone(), + is_error: false, + }); + (output, false) + } + Err(e) => { + let err_msg = e.to_string(); + let _ = event_tx.send(AgentEvent::ToolCallCompleted { + id: tc.id.clone(), + result: err_msg.clone(), + is_error: true, + }); + (err_msg, true) + } + }; + + self.history + .push(Message::tool_result(&tc.id, &output, is_error)); + } + } + } + + pub fn history(&self) -> &[Message] { + &self.history + } + + pub fn clear_history(&mut self) { + self.history.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_config() { + let config = AgentConfig::default(); + assert_eq!(config.max_iterations, MAX_ITERATIONS); + assert!(config.model.contains("claude")); + } + + #[test] + fn system_policy_contains_safety_rules() { + assert!(SYSTEM_POLICY.contains("Never force-push")); + assert!(SYSTEM_POLICY.contains("sensitive files")); + } +} diff --git a/crates/cherrypick_agent/src/chat/store.rs b/crates/cherrypick_agent/src/chat/store.rs new file mode 100644 index 00000000000000..8b840a86e3bb53 --- /dev/null +++ b/crates/cherrypick_agent/src/chat/store.rs @@ -0,0 +1,256 @@ +use chrono::Utc; +use rusqlite::params; +use tokio_rusqlite::Connection; + +use crate::error::{AgentError, Result}; + +const SCHEMA: &str = r#" +CREATE TABLE IF NOT EXISTS conversations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repo_path TEXT, + title TEXT NOT NULL DEFAULT 'New Chat', + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + conversation_id INTEGER NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + role TEXT NOT NULL, + content_json TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')) 
+); + +CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id); +"#; + +#[derive(Debug, Clone)] +pub struct Conversation { + pub id: i64, + pub repo_path: Option, + pub title: String, + pub created_at: String, + pub updated_at: String, +} + +#[derive(Debug, Clone)] +pub struct StoredMessage { + pub id: i64, + pub conversation_id: i64, + pub role: String, + pub content_json: String, + pub created_at: String, +} + +pub struct ChatStore { + conn: Connection, +} + +impl ChatStore { + pub async fn open(path: &str) -> Result { + let conn = Connection::open(path) + .await + .map_err(|e| AgentError::Database(e.to_string()))?; + conn.call(|conn| { + conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?; + conn.execute_batch(SCHEMA)?; + Ok(()) + }) + .await + .map_err(|e| AgentError::Database(e.to_string()))?; + Ok(Self { conn }) + } + + pub async fn open_in_memory() -> Result { + let conn = Connection::open_in_memory() + .await + .map_err(|e| AgentError::Database(e.to_string()))?; + conn.call(|conn| { + conn.execute_batch("PRAGMA foreign_keys=ON;")?; + conn.execute_batch(SCHEMA)?; + Ok(()) + }) + .await + .map_err(|e| AgentError::Database(e.to_string()))?; + Ok(Self { conn }) + } + + pub async fn create_conversation(&self, repo_path: Option<&str>, title: &str) -> Result { + let rp = repo_path.map(String::from); + let title = title.to_string(); + let now = Utc::now().to_rfc3339(); + self.conn + .call(move |conn| { + conn.execute( + "INSERT INTO conversations (repo_path, title, created_at, updated_at) VALUES (?1, ?2, ?3, ?3)", + params![rp, title, now], + )?; + Ok(conn.last_insert_rowid()) + }) + .await + .map_err(|e| AgentError::Database(e.to_string())) + } + + pub async fn list_conversations(&self, repo_path: Option<&str>) -> Result> { + let rp = repo_path.map(String::from); + self.conn + .call(move |conn| { + let mut stmt = if rp.is_some() { + conn.prepare( + "SELECT id, repo_path, title, created_at, updated_at FROM conversations 
+ WHERE repo_path = ?1 ORDER BY updated_at DESC", + )? + } else { + conn.prepare( + "SELECT id, repo_path, title, created_at, updated_at FROM conversations + ORDER BY updated_at DESC", + )? + }; + + let rows = if let Some(ref rp) = rp { + stmt.query_map(params![rp], map_conversation)? + } else { + stmt.query_map([], map_conversation)? + }; + + let mut convos = Vec::new(); + for row in rows { + convos.push(row?); + } + Ok(convos) + }) + .await + .map_err(|e| AgentError::Database(e.to_string())) + } + + pub async fn save_message( + &self, + conversation_id: i64, + role: &str, + content_json: &str, + ) -> Result { + let role = role.to_string(); + let content = content_json.to_string(); + let now = Utc::now().to_rfc3339(); + self.conn + .call(move |conn| { + conn.execute( + "INSERT INTO messages (conversation_id, role, content_json, created_at) + VALUES (?1, ?2, ?3, ?4)", + params![conversation_id, role, content, now], + )?; + conn.execute( + "UPDATE conversations SET updated_at = ?1 WHERE id = ?2", + params![now, conversation_id], + )?; + Ok(conn.last_insert_rowid()) + }) + .await + .map_err(|e| AgentError::Database(e.to_string())) + } + + pub async fn load_messages(&self, conversation_id: i64) -> Result> { + self.conn + .call(move |conn| { + let mut stmt = conn.prepare( + "SELECT id, conversation_id, role, content_json, created_at + FROM messages WHERE conversation_id = ?1 ORDER BY id ASC", + )?; + let rows = stmt.query_map(params![conversation_id], |row| { + Ok(StoredMessage { + id: row.get(0)?, + conversation_id: row.get(1)?, + role: row.get(2)?, + content_json: row.get(3)?, + created_at: row.get(4)?, + }) + })?; + let mut msgs = Vec::new(); + for row in rows { + msgs.push(row?); + } + Ok(msgs) + }) + .await + .map_err(|e| AgentError::Database(e.to_string())) + } + + pub async fn delete_conversation(&self, id: i64) -> Result<()> { + self.conn + .call(move |conn| { + conn.execute("DELETE FROM conversations WHERE id = ?1", params![id])?; + Ok(()) + }) + .await + 
.map_err(|e| AgentError::Database(e.to_string())) + } +} + +fn map_conversation(row: &rusqlite::Row) -> rusqlite::Result { + Ok(Conversation { + id: row.get(0)?, + repo_path: row.get(1)?, + title: row.get(2)?, + created_at: row.get(3)?, + updated_at: row.get(4)?, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn create_and_list_conversations() { + let store = ChatStore::open_in_memory().await.unwrap(); + let id = store + .create_conversation(Some("/repo"), "Test Chat") + .await + .unwrap(); + assert!(id > 0); + + let convos = store.list_conversations(Some("/repo")).await.unwrap(); + assert_eq!(convos.len(), 1); + assert_eq!(convos[0].title, "Test Chat"); + } + + #[tokio::test] + async fn save_and_load_messages() { + let store = ChatStore::open_in_memory().await.unwrap(); + let conv_id = store + .create_conversation(None, "Chat") + .await + .unwrap(); + + store + .save_message(conv_id, "user", r#"[{"type":"text","text":"hello"}]"#) + .await + .unwrap(); + store + .save_message(conv_id, "assistant", r#"[{"type":"text","text":"hi"}]"#) + .await + .unwrap(); + + let msgs = store.load_messages(conv_id).await.unwrap(); + assert_eq!(msgs.len(), 2); + assert_eq!(msgs[0].role, "user"); + assert_eq!(msgs[1].role, "assistant"); + } + + #[tokio::test] + async fn delete_conversation_cascades() { + let store = ChatStore::open_in_memory().await.unwrap(); + let conv_id = store + .create_conversation(None, "Chat") + .await + .unwrap(); + store + .save_message(conv_id, "user", "[]") + .await + .unwrap(); + + store.delete_conversation(conv_id).await.unwrap(); + let msgs = store.load_messages(conv_id).await.unwrap(); + assert!(msgs.is_empty()); + } +} diff --git a/crates/cherrypick_agent/src/context/mod.rs b/crates/cherrypick_agent/src/context/mod.rs new file mode 100644 index 00000000000000..e63a9a047a15db --- /dev/null +++ b/crates/cherrypick_agent/src/context/mod.rs @@ -0,0 +1,199 @@ +use std::path::Path; + +use crate::provider::types::Message; + 
+const DEFAULT_TOKEN_BUDGET: usize = 100_000; +const HOT_TIER_RATIO: f32 = 0.20; +const WARM_TIER_RATIO: f32 = 0.50; +const COLD_TIER_RATIO: f32 = 0.30; + +pub struct ContextBudget { + pub total: usize, + pub hot: usize, + pub warm: usize, + pub cold: usize, +} + +impl ContextBudget { + pub fn new(total: usize) -> Self { + let total_f = total as f32; + Self { + total, + hot: (total_f * HOT_TIER_RATIO) as usize, + warm: (total_f * WARM_TIER_RATIO) as usize, + cold: (total_f * COLD_TIER_RATIO) as usize, + } + } +} + +impl Default for ContextBudget { + fn default() -> Self { + Self::new(DEFAULT_TOKEN_BUDGET) + } +} + +pub struct ContextEngine { + budget: ContextBudget, + repo_map: Option, +} + +impl ContextEngine { + pub fn new(budget: ContextBudget) -> Self { + Self { + budget, + repo_map: None, + } + } + + pub fn build_repo_map(&mut self, repo_path: &Path) { + let mut entries = Vec::new(); + collect_files(repo_path, repo_path, &mut entries, 200); + self.repo_map = Some(entries.join("\n")); + } + + pub fn repo_map(&self) -> Option<&str> { + self.repo_map.as_deref() + } + + pub fn estimate_tokens(text: &str) -> usize { + text.len() / 4 + } + + pub fn truncate_history( + messages: &[Message], + max_tokens: usize, + ) -> Vec { + let mut total = 0; + let mut result: Vec = Vec::new(); + + for msg in messages.iter().rev() { + let tokens = msg + .content + .iter() + .map(|c| match c { + crate::provider::types::MessageContent::Text { text } => { + Self::estimate_tokens(text) + } + crate::provider::types::MessageContent::ToolResult { content, .. } => { + Self::estimate_tokens(content) + } + crate::provider::types::MessageContent::ToolUse { input, .. 
} => { + Self::estimate_tokens(&input.to_string()) + } + }) + .sum::(); + + if total + tokens > max_tokens { + break; + } + total += tokens; + result.push(msg.clone()); + } + + result.reverse(); + result + } + + pub fn budget(&self) -> &ContextBudget { + &self.budget + } +} + +fn collect_files( + dir: &Path, + root: &Path, + entries: &mut Vec, + max: usize, +) { + if entries.len() >= max { + return; + } + + let read_dir = match std::fs::read_dir(dir) { + Ok(d) => d, + Err(_) => return, + }; + + for entry in read_dir { + let entry = match entry { + Ok(e) => e, + Err(_) => continue, + }; + + let path = entry.path(); + let name = path.file_name().and_then(|n| n.to_str()).unwrap_or(""); + + if name.starts_with('.') || name == "target" || name == "node_modules" { + continue; + } + + if path.is_dir() { + collect_files(&path, root, entries, max); + } else if path.is_file() { + let rel = path.strip_prefix(root).unwrap_or(&path); + entries.push(rel.to_string_lossy().to_string()); + } + + if entries.len() >= max { + break; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::provider::types::Message; + + #[test] + fn budget_allocation() { + let budget = ContextBudget::new(100_000); + assert_eq!(budget.hot, 20_000); + assert_eq!(budget.warm, 50_000); + assert_eq!(budget.cold, 30_000); + assert_eq!(budget.hot + budget.warm + budget.cold, budget.total); + } + + #[test] + fn estimate_tokens_rough() { + let text = "Hello, world! 
This is a test."; + let tokens = ContextEngine::estimate_tokens(text); + assert!(tokens > 0); + assert!(tokens < text.len()); + } + + #[test] + fn truncate_history_respects_budget() { + let messages: Vec = (0..100) + .map(|i| Message::user(&format!("Message {i} with some content"))) + .collect(); + let truncated = ContextEngine::truncate_history(&messages, 100); + assert!(truncated.len() < messages.len()); + assert!(!truncated.is_empty()); + } + + #[test] + fn truncate_preserves_order() { + let messages = vec![ + Message::user("first"), + Message::user("second"), + Message::user("third"), + ]; + let truncated = ContextEngine::truncate_history(&messages, 10000); + assert_eq!(truncated.len(), 3); + assert_eq!(truncated[0].text_content(), Some("first")); + } + + #[test] + fn repo_map_collection() { + let tmp = tempfile::tempdir().unwrap(); + std::fs::write(tmp.path().join("a.rs"), "fn main() {}").unwrap(); + std::fs::write(tmp.path().join("b.rs"), "fn test() {}").unwrap(); + + let mut engine = ContextEngine::new(ContextBudget::default()); + engine.build_repo_map(tmp.path()); + let map = engine.repo_map().unwrap(); + assert!(map.contains("a.rs")); + assert!(map.contains("b.rs")); + } +} diff --git a/crates/cherrypick_agent/src/error.rs b/crates/cherrypick_agent/src/error.rs new file mode 100644 index 00000000000000..cafbc195d8eb62 --- /dev/null +++ b/crates/cherrypick_agent/src/error.rs @@ -0,0 +1,79 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AgentError { + #[error("Provider error: {0}")] + Provider(String), + + #[error("Rate limited: retry after {retry_after_secs}s")] + RateLimited { retry_after_secs: u64 }, + + #[error("Context too long: {0} tokens exceeds budget")] + ContextTooLong(usize), + + #[error("Tool execution failed: {0}")] + ToolExecution(String), + + #[error("Tool not found: {0}")] + ToolNotFound(String), + + #[error("Tool call denied by user")] + ToolDenied, + + #[error("Max iterations ({0}) exceeded")] + MaxIterations(u32), + + 
#[error("Max duration exceeded")] + MaxDuration, + + #[error("Cancelled")] + Cancelled, + + #[error("Key not found for provider: {0}")] + KeyNotFound(String), + + #[error("Skill not found: {0}")] + SkillNotFound(String), + + #[error("Skill parse error: {0}")] + SkillParse(String), + + #[error("MCP error: {0}")] + Mcp(String), + + #[error("Database error: {0}")] + Database(String), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + #[error("HTTP error: {0}")] + Http(#[from] reqwest::Error), + + #[error("{0}")] + Other(String), +} + +pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn error_display() { + let err = AgentError::ToolNotFound("test_tool".into()); + assert!(err.to_string().contains("test_tool")); + } + + #[test] + fn rate_limited_displays_retry() { + let err = AgentError::RateLimited { + retry_after_secs: 30, + }; + assert!(err.to_string().contains("30")); + } +} diff --git a/crates/cherrypick_agent/src/keys.rs b/crates/cherrypick_agent/src/keys.rs new file mode 100644 index 00000000000000..2fc77a02f5230f --- /dev/null +++ b/crates/cherrypick_agent/src/keys.rs @@ -0,0 +1,143 @@ +use crate::error::Result; + +pub trait KeyBackend: Send + Sync { + fn get(&self, key: &str) -> Result>; + fn set(&self, key: &str, value: &str) -> Result<()>; + fn delete(&self, key: &str) -> Result<()>; +} + +pub struct KeyManager { + backend: Box, +} + +impl KeyManager { + pub fn new(backend: Box) -> Self { + Self { backend } + } + + pub fn with_env() -> Self { + Self::new(Box::new(EnvVarBackend)) + } + + pub fn get_key(&self, provider: &str, profile: &str) -> Result> { + let key = format!("{provider}:{profile}"); + + if let Some(val) = self.backend.get(&key)? 
{ + return Ok(Some(val)); + } + + let env_key = format!( + "{}_API_KEY", + provider.to_uppercase().replace('-', "_") + ); + if let Ok(val) = std::env::var(&env_key) { + return Ok(Some(val)); + } + + Ok(None) + } + + pub fn set_key(&self, provider: &str, profile: &str, value: &str) -> Result<()> { + let key = format!("{provider}:{profile}"); + self.backend.set(&key, value) + } + + pub fn delete_key(&self, provider: &str, profile: &str) -> Result<()> { + let key = format!("{provider}:{profile}"); + self.backend.delete(&key) + } + + pub fn has_key(&self, provider: &str, profile: &str) -> bool { + self.get_key(provider, profile) + .map(|k| k.is_some()) + .unwrap_or(false) + } +} + +struct EnvVarBackend; + +impl KeyBackend for EnvVarBackend { + fn get(&self, key: &str) -> Result> { + let env_key = key.replace(':', "_").to_uppercase(); + Ok(std::env::var(&env_key).ok()) + } + + fn set(&self, _key: &str, _value: &str) -> Result<()> { + Ok(()) + } + + fn delete(&self, _key: &str) -> Result<()> { + Ok(()) + } +} + +pub struct InMemoryBackend { + data: std::sync::Mutex>, +} + +impl InMemoryBackend { + pub fn new() -> Self { + Self { + data: std::sync::Mutex::new(std::collections::HashMap::new()), + } + } +} + +impl KeyBackend for InMemoryBackend { + fn get(&self, key: &str) -> Result> { + Ok(self.data.lock().unwrap().get(key).cloned()) + } + + fn set(&self, key: &str, value: &str) -> Result<()> { + self.data + .lock() + .unwrap() + .insert(key.to_string(), value.to_string()); + Ok(()) + } + + fn delete(&self, key: &str) -> Result<()> { + self.data.lock().unwrap().remove(key); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_manager() -> KeyManager { + KeyManager::new(Box::new(InMemoryBackend::new())) + } + + #[test] + fn set_and_get_key() { + let mgr = test_manager(); + mgr.set_key("anthropic", "default", "sk-test").unwrap(); + let key = mgr.get_key("anthropic", "default").unwrap(); + assert_eq!(key, Some("sk-test".to_string())); + } + + #[test] + fn 
get_nonexistent_key() { + let mgr = test_manager(); + let key = mgr.get_key("anthropic", "default").unwrap(); + assert!(key.is_none()); + } + + #[test] + fn delete_key() { + let mgr = test_manager(); + mgr.set_key("anthropic", "default", "sk-test").unwrap(); + mgr.delete_key("anthropic", "default").unwrap(); + assert!(!mgr.has_key("anthropic", "default")); + } + + #[test] + fn has_key_check() { + let mgr = test_manager(); + assert!(!mgr.has_key("anthropic", "default")); + mgr.set_key("anthropic", "default", "sk-test").unwrap(); + assert!(mgr.has_key("anthropic", "default")); + } +} diff --git a/crates/cherrypick_agent/src/lib.rs b/crates/cherrypick_agent/src/lib.rs new file mode 100644 index 00000000000000..6f9b1451748480 --- /dev/null +++ b/crates/cherrypick_agent/src/lib.rs @@ -0,0 +1,20 @@ +pub mod chat; +pub mod context; +pub mod error; +pub mod keys; +pub mod mcp; +pub mod provider; +pub mod skills; +pub mod tools; + +pub use chat::{AgentConfig, AgentEvent, AgentService}; +pub use context::ContextEngine; +pub use error::{AgentError, Result}; +pub use keys::KeyManager; +pub use mcp::{McpClient, McpClientConfig, McpManager}; +pub use provider::types::{ + CompletionRequest, Message, MessageContent, RiskLevel, Role, StreamChunk, ToolCall, + ToolDefinition, Usage, +}; +pub use skills::SkillLoader; +pub use tools::ToolExecutor; diff --git a/crates/cherrypick_agent/src/mcp.rs b/crates/cherrypick_agent/src/mcp.rs new file mode 100644 index 00000000000000..1be29c6239d5a5 --- /dev/null +++ b/crates/cherrypick_agent/src/mcp.rs @@ -0,0 +1,333 @@ +use std::collections::HashMap; +use std::process::Stdio; + +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, Command}; +use tokio::time::{timeout, Duration}; + +use crate::error::{AgentError, Result}; +use crate::provider::types::ToolDefinition; + +const REQUEST_TIMEOUT: Duration = Duration::from_secs(30); +const MAX_MESSAGE_SIZE: usize = 1_048_576; + 
// ---- crates/cherrypick_agent/src/mcp.rs (continued) ----
// NOTE(review): generic parameters here were stripped by extraction
// (`Vec,`, `HashMap,`, `Option,`, `Result>`); restored below from usage.

/// Launch configuration for one MCP server process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpClientConfig {
    pub name: String,
    pub command: String,
    pub args: Vec<String>,
    pub env: HashMap<String, String>,
}

#[derive(Serialize)]
struct JsonRpcRequest {
    jsonrpc: String,
    id: u64,
    method: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<serde_json::Value>,
}

#[derive(Deserialize)]
struct JsonRpcResponse {
    // NOTE(review): `id` is deserialized but never compared against the request id
    // in `send_request` — an out-of-order response would be misattributed.
    id: u64,
    #[serde(default)]
    result: Option<serde_json::Value>,
    #[serde(default)]
    error: Option<JsonRpcError>,
}

#[derive(Deserialize)]
struct JsonRpcError {
    code: i64,
    message: String,
}

/// One spawned MCP server plus the tools it advertised at startup.
pub struct McpClient {
    config: McpClientConfig,
    process: Option<Child>,
    next_id: u64,
    discovered_tools: Vec<ToolDefinition>,
}

impl McpClient {
    pub fn new(config: McpClientConfig) -> Self {
        Self {
            config,
            process: None,
            next_id: 1,
            discovered_tools: Vec::new(),
        }
    }

    /// Spawn the configured command with piped stdio, then run tool discovery.
    pub async fn start(&mut self) -> Result<()> {
        let mut cmd = Command::new(&self.config.command);
        cmd.args(&self.config.args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::null());

        for (k, v) in &self.config.env {
            cmd.env(k, v);
        }

        let child = cmd.spawn().map_err(|e| {
            AgentError::Mcp(format!(
                "Failed to start MCP server '{}': {}",
                self.config.name, e
            ))
        })?;

        self.process = Some(child);
        self.discover_tools().await?;
        Ok(())
    }

    /// Ask the server for its tool list; names are namespaced as
    /// `mcp_<server>__<tool>` so they cannot collide with built-in tools.
    async fn discover_tools(&mut self) -> Result<()> {
        let response = self.send_request("tools/list", None).await?;

        if let Some(result) = response {
            if let Some(tools) = result.get("tools").and_then(|t| t.as_array()) {
                self.discovered_tools = tools
                    .iter()
                    .filter_map(|t| {
                        Some(ToolDefinition {
                            name: format!(
                                "mcp_{}__{}",
                                self.config.name,
                                t.get("name")?.as_str()?
                            ),
                            description: t
                                .get("description")
                                .and_then(|d| d.as_str())
                                .unwrap_or("")
                                .to_string(),
                            input_schema: t
                                .get("inputSchema")
                                .cloned()
                                .unwrap_or(serde_json::json!({"type": "object"})),
                        })
                    })
                    .collect();
            }
        }

        Ok(())
    }

    /// Invoke `tool_name`; returns the concatenated text content blocks, or the
    /// raw JSON result when no text content is present.
    pub async fn call_tool(
        &mut self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<String> {
        let params = serde_json::json!({
            "name": tool_name,
            "arguments": arguments,
        });

        let response = self.send_request("tools/call", Some(params)).await?;

        match response {
            Some(result) => {
                if let Some(content) = result.get("content").and_then(|c| c.as_array()) {
                    let text_parts: Vec<&str> = content
                        .iter()
                        .filter_map(|c| c.get("text").and_then(|t| t.as_str()))
                        .collect();
                    Ok(text_parts.join("\n"))
                } else {
                    Ok(result.to_string())
                }
            }
            None => Ok(String::new()),
        }
    }

    /// Write one JSON-RPC request line and read one response line, with a timeout.
    async fn send_request(
        &mut self,
        method: &str,
        params: Option<serde_json::Value>,
    ) -> Result<Option<serde_json::Value>> {
        let process = self
            .process
            .as_mut()
            .ok_or_else(|| AgentError::Mcp("MCP server not running".into()))?;

        let id = self.next_id;
        self.next_id += 1;

        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id,
            method: method.to_string(),
            params,
        };

        let mut request_bytes = serde_json::to_vec(&request)?;
        request_bytes.push(b'\n');

        if request_bytes.len() > MAX_MESSAGE_SIZE {
            return Err(AgentError::Mcp("Request too large".into()));
        }

        let stdin = process
            .stdin
            .as_mut()
            .ok_or_else(|| AgentError::Mcp("No stdin".into()))?;
        stdin.write_all(&request_bytes).await?;
        stdin.flush().await?;

        // NOTE(review): a fresh BufReader per request may read-ahead and then
        // drop buffered bytes when it goes out of scope; a long-lived reader
        // stored on `self` would be safer if servers ever pipeline output.
        let stdout = process
            .stdout
            .as_mut()
            .ok_or_else(|| AgentError::Mcp("No stdout".into()))?;
        let mut reader = BufReader::new(stdout);
        let mut line = String::new();

        let read_result = timeout(REQUEST_TIMEOUT, reader.read_line(&mut line)).await;

        match read_result {
            Ok(Ok(0)) => Err(AgentError::Mcp("Server closed connection".into())),
            Ok(Ok(_)) => {
                if line.len() > MAX_MESSAGE_SIZE {
                    return Err(AgentError::Mcp("Response too large".into()));
                }
                let response: JsonRpcResponse = serde_json::from_str(&line)
                    .map_err(|e| AgentError::Mcp(format!("Invalid response: {e}")))?;
                if let Some(error) = response.error {
                    Err(AgentError::Mcp(format!(
                        "RPC error {}: {}",
                        error.code, error.message
                    )))
                } else {
                    Ok(response.result)
                }
            }
            Ok(Err(e)) => Err(AgentError::Mcp(format!("Read error: {e}"))),
            Err(_) => Err(AgentError::Mcp("Request timed out".into())),
        }
    }

    pub fn tools(&self) -> &[ToolDefinition] {
        &self.discovered_tools
    }

    /// Kill the child process if running; idempotent.
    pub async fn stop(&mut self) -> Result<()> {
        if let Some(ref mut process) = self.process {
            let _ = process.kill().await;
            self.process = None;
        }
        Ok(())
    }

    pub fn is_running(&self) -> bool {
        self.process.is_some()
    }

    pub fn name(&self) -> &str {
        &self.config.name
    }
}

impl Drop for McpClient {
    fn drop(&mut self) {
        // Best-effort, non-blocking kill so a dropped client does not leak
        // the child process.
        if let Some(ref mut process) = self.process {
            let _ = process.start_kill();
        }
    }
}

/// Owns all configured MCP clients, keyed by server name.
pub struct McpManager {
    clients: HashMap<String, McpClient>,
}

impl McpManager {
    pub fn new() -> Self {
        Self {
            clients: HashMap::new(),
        }
    }

    pub fn add_server(&mut self, config: McpClientConfig) {
        let name = config.name.clone();
        self.clients.insert(name, McpClient::new(config));
    }

    /// Start every server; per-server failures are collected, not fatal.
    pub async fn start_all(&mut self) -> Vec<(String, Result<()>)> {
        let mut results = Vec::new();
        let names: Vec<String> = self.clients.keys().cloned().collect();
        for name in names {
            if let Some(client) = self.clients.get_mut(&name) {
                let result = client.start().await;
                results.push((name, result));
            }
        }
        results
    }

    pub async fn stop_all(&mut self) {
        for client in self.clients.values_mut() {
            let _ = client.stop().await;
        }
    }

    pub fn all_tools(&self) -> Vec<ToolDefinition> {
        self.clients
            .values()
            .flat_map(|c| c.tools().iter().cloned())
            .collect()
    }

    pub fn get_client(&mut self, name: &str) -> Option<&mut McpClient> {
        self.clients.get_mut(name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mcp_manager_creation() {
        let mgr = McpManager::new();
        assert!(mgr.all_tools().is_empty());
    }

    #[test]
    fn add_server_config() {
        let mut mgr = McpManager::new();
        mgr.add_server(McpClientConfig {
            name: "test".into(),
            command: "echo".into(),
            args: vec![],
            env: HashMap::new(),
        });
        assert!(mgr.get_client("test").is_some());
    }

    #[test]
    fn client_not_running_initially() {
        let client = McpClient::new(McpClientConfig {
            name: "test".into(),
            command: "echo".into(),
            args: vec![],
            env: HashMap::new(),
        });
        assert!(!client.is_running());
        assert_eq!(client.name(), "test");
    }

    #[test]
    fn json_rpc_request_serialization() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".into(),
            id: 1,
            method: "tools/list".into(),
            params: None,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("tools/list"));
        assert!(!json.contains("params"));
    }
}

// ==== new file: crates/cherrypick_agent/src/provider/anthropic.rs ====

use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;

use crate::error::{AgentError, Result};
use super::LlmProvider;
use super::types::{CompletionRequest, StreamChunk, Usage};

const ANTHROPIC_API_URL: &str = "https://api.anthropic.com/v1/messages";
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Streaming client for Anthropic's Messages API.
pub struct AnthropicProvider {
    api_key: String,
    client: Client,
}

impl AnthropicProvider {
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            client: Client::new(),
        }
    }
}

// Wire format of a Messages API request body.
#[derive(Serialize)]
struct AnthropicRequest {
    model: String,
    messages: serde_json::Value,
    max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    system:
Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    // assumed f32 (matches CompletionRequest::temperature) — TODO confirm
    temperature: Option<f32>,
    stream: bool,
}

// One server-sent event from the streaming Messages API; fields are optional
// because each event type carries a different subset.
#[derive(Deserialize, Debug)]
struct SseEvent {
    #[serde(rename = "type")]
    event_type: String,
    #[serde(default)]
    delta: Option<SseDelta>,
    #[serde(default)]
    content_block: Option<ContentBlock>,
    #[serde(default)]
    usage: Option<SseUsage>,
    #[serde(default)]
    error: Option<SseError>,
    // Currently unread; integer width assumed u64 — TODO confirm against API.
    #[serde(default)]
    index: Option<u64>,
}

#[derive(Deserialize, Debug, Default)]
struct SseDelta {
    #[serde(rename = "type", default)]
    delta_type: String,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    partial_json: Option<String>,
}

#[derive(Deserialize, Debug)]
struct ContentBlock {
    #[serde(rename = "type")]
    block_type: String,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    text: Option<String>,
}

#[derive(Deserialize, Debug)]
struct SseUsage {
    input_tokens: u32,
    output_tokens: u32,
}

#[derive(Deserialize, Debug)]
struct SseError {
    message: String,
}

#[async_trait]
impl LlmProvider for AnthropicProvider {
    fn name(&self) -> &str {
        "anthropic"
    }

    /// POST the request with `stream: true` and forward parsed SSE events to `tx`.
    /// Send failures on `tx` are ignored: a dropped receiver just ends delivery.
    async fn stream_completion(
        &self,
        request: CompletionRequest,
        tx: mpsc::UnboundedSender<StreamChunk>,
    ) -> Result<()> {
        // Tools cross the wire as raw JSON objects in Anthropic's schema.
        let tools_json: Vec<serde_json::Value> = request
            .tools
            .iter()
            .map(|t| {
                serde_json::json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.input_schema,
                })
            })
            .collect();

        let messages_json = serde_json::to_value(&request.messages)?;

        let body = AnthropicRequest {
            model: request.model,
            messages: messages_json,
            max_tokens: request.max_tokens,
            system: request.system,
            tools: tools_json,
            temperature: request.temperature,
            stream: true,
        };

        let body_json = serde_json::to_string(&body)?;
        let response = self
            .client
            .post(ANTHROPIC_API_URL)
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", ANTHROPIC_VERSION)
            .header("content-type", "application/json")
            .body(body_json)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            if status.as_u16() == 429 {
                return Err(AgentError::RateLimited {
                    retry_after_secs: 60,
                });
            }
            if status.as_u16() == 401 {
                return Err(AgentError::Provider("Invalid API key".into()));
            }
            return Err(AgentError::Provider(format!(
                "HTTP {}: {}",
                status, error_text
            )));
        }

        // NOTE(review): the whole SSE body is buffered before parsing, so
        // "streaming" only begins once the response completes; consider
        // `bytes_stream()` for true incremental delivery.
        let text = response.text().await?;

        for line in text.lines() {
            let line = line.trim();
            if line.is_empty() || line.starts_with(':') {
                continue;
            }
            if let Some(data) = line.strip_prefix("data: ") {
                if data == "[DONE]" {
                    let _ = tx.send(StreamChunk::Done);
                    break;
                }
                if let Ok(event) = serde_json::from_str::<SseEvent>(data) {
                    match event.event_type.as_str() {
                        "content_block_start" => {
                            if let Some(block) = &event.content_block {
                                if block.block_type == "tool_use" {
                                    let _ = tx.send(StreamChunk::ToolCallStart {
                                        id: block.id.clone().unwrap_or_default(),
                                        name: block.name.clone().unwrap_or_default(),
                                    });
                                }
                                if block.block_type == "text" {
                                    if let Some(text) = &block.text {
                                        if !text.is_empty() {
                                            let _ = tx.send(StreamChunk::TextDelta(text.clone()));
                                        }
                                    }
                                }
                            }
                        }
                        "content_block_delta" => {
                            if let Some(delta) = &event.delta {
                                if let Some(text) = &delta.text {
                                    let _ = tx.send(StreamChunk::TextDelta(text.clone()));
                                }
                                if let Some(json) = &delta.partial_json {
                                    let _ = tx.send(StreamChunk::ToolCallDelta(json.clone()));
                                }
                            }
                        }
                        "content_block_stop" => {
                            let _ = tx.send(StreamChunk::ToolCallEnd);
                        }
                        "message_delta" => {
                            if let Some(usage) = &event.usage {
                                let _ = tx.send(StreamChunk::Usage(Usage {
                                    input_tokens: usage.input_tokens,
                                    output_tokens: usage.output_tokens,
                                }));
                            }
                        }
                        "message_start" => {
                            if let Some(usage) = &event.usage {
                                let _ = tx.send(StreamChunk::Usage(Usage {
                                    input_tokens: usage.input_tokens,
                                    output_tokens: usage.output_tokens,
                                }));
                            }
                        }
                        "message_stop" => {
                            let _ = tx.send(StreamChunk::Done);
                        }
                        "error" => {
                            if let Some(err) = &event.error {
                                let _ = tx.send(StreamChunk::Error(err.message.clone()));
                            }
                        }
                        _ => {}
                    }
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn provider_name() {
        let provider = AnthropicProvider::new("test-key".into());
        assert_eq!(provider.name(), "anthropic");
    }

    #[test]
    fn sse_event_deserialization() {
        let json = r#"{"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}"#;
        let event: SseEvent = serde_json::from_str(json).unwrap();
        assert_eq!(event.event_type, "content_block_delta");
        assert_eq!(event.delta.as_ref().unwrap().text, Some("Hello".into()));
    }

    #[test]
    fn sse_tool_use_start() {
        let json = r#"{"type":"content_block_start","content_block":{"type":"tool_use","id":"toolu_123","name":"read_file"}}"#;
        let event: SseEvent = serde_json::from_str(json).unwrap();
        let block = event.content_block.unwrap();
        assert_eq!(block.block_type, "tool_use");
        assert_eq!(block.name, Some("read_file".into()));
    }

    #[test]
    fn sse_error_event() {
        let json = r#"{"type":"error","error":{"message":"rate limited"}}"#;
        let event: SseEvent = serde_json::from_str(json).unwrap();
        assert_eq!(event.error.unwrap().message, "rate limited");
    }
}

// ==== new file: crates/cherrypick_agent/src/provider/mod.rs ====

pub mod types;
pub mod anthropic;

use async_trait::async_trait;
use tokio::sync::mpsc;

use crate::error::Result;
use types::{CompletionRequest, StreamChunk};

/// A streaming chat-completion backend (e.g. Anthropic).
#[async_trait]
pub trait LlmProvider: Send + Sync {
    fn name(&self) -> &str;

    async fn stream_completion(
        &self,
        request: CompletionRequest,
        tx: mpsc::UnboundedSender<StreamChunk>,
    ) -> Result<()>;
}

// ==== new file: crates/cherrypick_agent/src/provider/types.rs ====

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    #[serde(rename = "user")]
    User,
    #[serde(rename = "assistant")]
    Assistant,
    #[serde(rename = "system")]
    System,
}

/// One content block of a message, mirroring Anthropic's tagged wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum MessageContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
        is_error: bool,
    },
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<MessageContent>,
}

impl Message {
    pub fn user(text: &str) -> Self {
        Self {
            role: Role::User,
            content: vec![MessageContent::Text {
                text: text.to_string(),
            }],
        }
    }

    pub fn assistant(text: &str) -> Self {
        Self {
            role: Role::Assistant,
            content: vec![MessageContent::Text {
                text: text.to_string(),
            }],
        }
    }

    /// Tool results are sent back as a *user* message per the Messages API.
    pub fn tool_result(tool_use_id: &str, content: &str, is_error: bool) -> Self {
        Self {
            role: Role::User,
            content: vec![MessageContent::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                content: content.to_string(),
                is_error,
            }],
        }
    }

    /// First text block, if any.
    pub fn text_content(&self) -> Option<&str> {
        self.content.iter().find_map(|c| match c {
            MessageContent::Text { text } => Some(text.as_str()),
            _ => None,
        })
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id:
String,
    pub name: String,
    pub arguments: serde_json::Value,
}

/// Provider-agnostic completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub messages: Vec<Message>,
    pub system: Option<String>,
    pub tools: Vec<ToolDefinition>,
    pub max_tokens: u32,
    // assumed f32 — stripped by extraction; TODO confirm against callers
    pub temperature: Option<f32>,
}

/// Incremental events emitted while streaming a completion.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    TextDelta(String),
    ToolCallStart { id: String, name: String },
    ToolCallDelta(String),
    ToolCallEnd,
    Usage(Usage),
    Done,
    Error(String),
}

#[derive(Debug, Clone, Default)]
pub struct Usage {
    pub input_tokens: u32,
    pub output_tokens: u32,
}

/// Ordered risk classes for tool execution. `Ord` derives from declaration
/// order, so everything above `ReadOnly` requires confirmation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum RiskLevel {
    ReadOnly,
    WriteLocal,
    WriteRemote,
    Network,
    Destructive,
}

impl RiskLevel {
    pub fn requires_confirmation(&self) -> bool {
        *self >= RiskLevel::WriteLocal
    }

    /// Short user-facing label for confirmation prompts.
    pub fn label(&self) -> &'static str {
        match self {
            Self::ReadOnly => "Read-only",
            Self::WriteLocal => "Write (local)",
            Self::WriteRemote => "Write (remote)",
            Self::Network => "Network",
            Self::Destructive => "Destructive",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn user_message_creation() {
        let msg = Message::user("hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text_content(), Some("hello"));
    }

    #[test]
    fn assistant_message_creation() {
        let msg = Message::assistant("response");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.text_content(), Some("response"));
    }

    #[test]
    fn tool_result_message() {
        let msg = Message::tool_result("id-1", "output", false);
        assert_eq!(msg.role, Role::User);
        match &msg.content[0] {
            MessageContent::ToolResult {
                tool_use_id,
                is_error,
                ..
            } => {
                assert_eq!(tool_use_id, "id-1");
                assert!(!is_error);
            }
            _ => panic!("Expected ToolResult"),
        }
    }

    #[test]
    fn risk_level_ordering() {
        assert!(RiskLevel::ReadOnly < RiskLevel::WriteLocal);
        assert!(RiskLevel::WriteLocal < RiskLevel::Destructive);
    }

    #[test]
    fn risk_requires_confirmation() {
        assert!(!RiskLevel::ReadOnly.requires_confirmation());
        assert!(RiskLevel::WriteLocal.requires_confirmation());
        assert!(RiskLevel::Destructive.requires_confirmation());
    }

    #[test]
    fn message_serialization() {
        let msg = Message::user("test");
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.text_content(), Some("test"));
    }
}

// ==== new file: crates/cherrypick_agent/src/skills/mod.rs ====

use std::collections::HashMap;
use std::path::Path;

use crate::error::{AgentError, Result};

/// A reusable prompt recipe, optionally bound to a slash-command trigger.
#[derive(Debug, Clone)]
pub struct Skill {
    pub name: String,
    pub description: String,
    pub trigger: Option<String>,
    pub allowed_tools: Vec<String>,
    pub prompt_template: String,
}

/// Holds built-in skills plus any loaded from disk; disk skills shadow
/// built-ins on name lookup (see `get`).
pub struct SkillLoader {
    skills: HashMap<String, Skill>,
    builtin: HashMap<String, Skill>,
}

impl SkillLoader {
    pub fn new() -> Self {
        let mut loader = Self {
            skills: HashMap::new(),
            builtin: HashMap::new(),
        };
        loader.register_builtins();
        loader
    }

    fn register_builtins(&mut self) {
        self.builtin.insert(
            "explain-commit".into(),
            Skill {
                name: "explain-commit".into(),
                description: "Explain what a commit does".into(),
                trigger: Some("/explain".into()),
                allowed_tools: vec!["git_log".into(), "git_diff".into(), "read_file".into()],
                prompt_template: "Explain the most recent commit in this repository. Use git_log to find it, then git_diff to see the changes, and explain what was done and why.".into(),
            },
        );

        self.builtin.insert(
            "review-changes".into(),
            Skill {
                name: "review-changes".into(),
                description: "Review current staged/unstaged changes".into(),
                trigger: Some("/review".into()),
                allowed_tools: vec!["git_status".into(), "git_diff".into(), "read_file".into()],
                prompt_template: "Review the current changes in this repository. Check git_status, then review the diffs. Look for bugs, style issues, and potential improvements.".into(),
            },
        );

        self.builtin.insert(
            "generate-commit-message".into(),
            Skill {
                name: "generate-commit-message".into(),
                description: "Generate a commit message for staged changes".into(),
                trigger: Some("/commit-msg".into()),
                allowed_tools: vec!["git_status".into(), "git_diff".into()],
                prompt_template: "Generate a concise, conventional commit message for the currently staged changes. Use git_diff with staged=true to see what's staged.".into(),
            },
        );

        self.builtin.insert(
            "find-related".into(),
            Skill {
                name: "find-related".into(),
                description: "Find files related to a topic".into(),
                trigger: Some("/find".into()),
                allowed_tools: vec!["search_code".into(), "list_files".into(), "read_file".into()],
                prompt_template: "Find all files related to the user's query. Use search_code and list_files to locate relevant code.".into(),
            },
        );
    }

    /// Load `*.md` skill files from `dir`; unparseable files are skipped.
    /// Returns the number of skills successfully loaded.
    pub fn load_from_directory(&mut self, dir: &Path) -> Result<usize> {
        let mut count = 0;
        let entries = std::fs::read_dir(dir).map_err(AgentError::Io)?;

        for entry in entries {
            let entry = entry.map_err(AgentError::Io)?;
            let path = entry.path();

            if path.extension().and_then(|e| e.to_str()) == Some("md") {
                if let Ok(skill) = parse_skill_file(&path) {
                    self.skills.insert(skill.name.clone(), skill);
                    count += 1;
                }
            }
        }

        Ok(count)
    }

    /// Name lookup: disk-loaded skills take precedence over built-ins.
    pub fn get(&self, name: &str) -> Option<&Skill> {
        self.skills.get(name).or_else(|| self.builtin.get(name))
    }

    /// Match the first whitespace-delimited token of `input` (e.g. "/review")
    /// against skill triggers; built-ins are searched first.
    pub fn resolve_slash_command(&self, input: &str) -> Option<&Skill> {
        let trigger = input.split_whitespace().next()?;
        self.builtin
            .values()
            .chain(self.skills.values())
            .find(|s| s.trigger.as_deref() == Some(trigger))
    }

    /// All skills, sorted by name for stable display.
    pub fn list_all(&self) -> Vec<&Skill> {
        let mut all: Vec<&Skill> = self.builtin.values().chain(self.skills.values()).collect();
        // Explicit comparator instead of `sort_by_key(|s| &s.name)`, whose
        // borrowed key type is fragile under `sort_by_key`'s signature.
        all.sort_by(|a, b| a.name.cmp(&b.name));
        all
    }
}

/// Parse a markdown skill file with optional `---`-delimited frontmatter
/// (keys: name, description, trigger, tools). The body becomes the prompt.
fn parse_skill_file(path: &Path) -> Result<Skill> {
    let content = std::fs::read_to_string(path)?;

    let (frontmatter, body) = if content.starts_with("---") {
        let rest = &content[3..];
        if let Some(end) = rest.find("---") {
            let fm = &rest[..end].trim();
            let body = &rest[end + 3..].trim();
            (Some(fm.to_string()), body.to_string())
        } else {
            (None, content)
        }
    } else {
        (None, content)
    };

    // Default the name to the file stem; frontmatter may override it.
    let mut name = path
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("unknown")
        .to_string();
    let mut description = String::new();
    let mut trigger = None;
    let mut allowed_tools = Vec::new();

    if let Some(fm) = frontmatter {
        for line in fm.lines() {
            let line = line.trim();
            if let Some(val) = line.strip_prefix("name:") {
                name = val.trim().trim_matches('"').to_string();
            } else if let Some(val) = line.strip_prefix("description:") {
                description =
val.trim().trim_matches('"').to_string();
            } else if let Some(val) = line.strip_prefix("trigger:") {
                trigger = Some(val.trim().trim_matches('"').to_string());
            } else if let Some(val) = line.strip_prefix("tools:") {
                // `tools: [a, b]` — strip brackets, split on commas.
                allowed_tools = val
                    .trim()
                    .trim_start_matches('[')
                    .trim_end_matches(']')
                    .split(',')
                    .map(|s| s.trim().trim_matches('"').to_string())
                    .filter(|s| !s.is_empty())
                    .collect();
            }
        }
    }

    Ok(Skill {
        name,
        description,
        trigger,
        allowed_tools,
        prompt_template: body,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builtin_skills_loaded() {
        let loader = SkillLoader::new();
        assert!(loader.get("explain-commit").is_some());
        assert!(loader.get("review-changes").is_some());
        assert!(loader.get("generate-commit-message").is_some());
        assert!(loader.get("find-related").is_some());
    }

    #[test]
    fn slash_command_resolution() {
        let loader = SkillLoader::new();
        let skill = loader.resolve_slash_command("/explain something");
        assert!(skill.is_some());
        assert_eq!(skill.unwrap().name, "explain-commit");
    }

    #[test]
    fn unknown_slash_command() {
        let loader = SkillLoader::new();
        assert!(loader.resolve_slash_command("/unknown").is_none());
    }

    #[test]
    fn list_all_skills() {
        let loader = SkillLoader::new();
        let all = loader.list_all();
        assert!(all.len() >= 4);
    }

    #[test]
    fn parse_skill_file_from_disk() {
        let tmp = tempfile::tempdir().unwrap();
        let skill_path = tmp.path().join("test-skill.md");
        std::fs::write(
            &skill_path,
            r#"---
name: test-skill
description: A test skill
trigger: /test
tools: [read_file, git_status]
---
Do something useful."#,
        )
        .unwrap();

        let skill = parse_skill_file(&skill_path).unwrap();
        assert_eq!(skill.name, "test-skill");
        assert_eq!(skill.trigger, Some("/test".into()));
        assert_eq!(skill.allowed_tools.len(), 2);
        assert!(skill.prompt_template.contains("something useful"));
    }

    #[test]
    fn load_skills_from_directory() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(
            tmp.path().join("skill-a.md"),
            "---\nname: a\n---\nSkill A",
        )
        .unwrap();
        std::fs::write(
            tmp.path().join("skill-b.md"),
            "---\nname: b\n---\nSkill B",
        )
        .unwrap();

        let mut loader = SkillLoader::new();
        let count = loader.load_from_directory(tmp.path()).unwrap();
        assert_eq!(count, 2);
        assert!(loader.get("a").is_some());
        assert!(loader.get("b").is_some());
    }
}

// ==== new file: crates/cherrypick_agent/src/tools/git_tools.rs ====

use std::path::Path;

use async_trait::async_trait;
use serde_json::json;

use crate::error::{AgentError, Result};
use crate::provider::types::{RiskLevel, ToolDefinition};
use super::ToolHandler;
use super::safe_path::resolve_safe_path;

/// Read a file inside the repository (path-traversal and sensitive-file guarded).
pub struct ReadFileTool;

#[async_trait]
impl ToolHandler for ReadFileTool {
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "read_file".to_string(),
            description: "Read the contents of a file".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string", "description": "Relative path to the file" }
                },
                "required": ["path"]
            }),
        }
    }

    fn risk_level(&self) -> RiskLevel {
        RiskLevel::ReadOnly
    }

    async fn execute(&self, args: serde_json::Value, repo_path: &Path) -> Result<String> {
        let path = args["path"]
            .as_str()
            .ok_or_else(|| AgentError::ToolExecution("Missing 'path' argument".into()))?;

        if super::is_sensitive_path(path) {
            return Err(AgentError::ToolExecution("Cannot read sensitive files".into()));
        }

        // Reject paths that escape the repository root.
        let full_path = resolve_safe_path(repo_path, path)?;

        std::fs::read_to_string(&full_path)
            .map_err(|e| AgentError::ToolExecution(format!("Failed to read file: {e}")))
    }

    fn preview(&self, args: &serde_json::Value) -> String {
        format!("Read file: {}", args["path"].as_str().unwrap_or("?"))
    }
}

/// List a directory inside the repository, sorted, with `d `/`f ` prefixes.
pub struct ListFilesTool;

#[async_trait]
impl ToolHandler for ListFilesTool {
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "list_files".to_string(),
            description: "List files in a directory".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string", "description": "Relative directory path" }
                }
            }),
        }
    }

    fn risk_level(&self) -> RiskLevel {
        RiskLevel::ReadOnly
    }

    async fn execute(&self, args: serde_json::Value, repo_path: &Path) -> Result<String> {
        let dir_path = args["path"].as_str().unwrap_or(".");
        let full_path = resolve_safe_path(repo_path, dir_path)?;

        let mut entries = Vec::new();
        let read_dir = std::fs::read_dir(&full_path)
            .map_err(|e| AgentError::ToolExecution(format!("Failed to list directory: {e}")))?;

        for entry in read_dir {
            let entry = entry.map_err(|e| AgentError::ToolExecution(e.to_string()))?;
            let name = entry.file_name().to_string_lossy().to_string();
            let ft = entry
                .file_type()
                .map_err(|e| AgentError::ToolExecution(e.to_string()))?;
            // Prefix keeps directories distinguishable after the sort below.
            let prefix = if ft.is_dir() { "d " } else { "f " };
            entries.push(format!("{prefix}{name}"));
        }

        entries.sort();
        Ok(entries.join("\n"))
    }

    fn preview(&self, args: &serde_json::Value) -> String {
        format!("List files in: {}", args["path"].as_str().unwrap_or("."))
    }
}

/// `git status`-like summary (index + worktree state, untracked included).
pub struct GitStatusTool;

#[async_trait]
impl ToolHandler for GitStatusTool {
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "git_status".to_string(),
            description: "Show the working tree status of the git repository".to_string(),
            input_schema: json!({ "type": "object", "properties": {} }),
        }
    }

    fn risk_level(&self) -> RiskLevel {
        RiskLevel::ReadOnly
    }

    async fn execute(&self, _args: serde_json::Value, repo_path: &Path) -> Result<String> {
        let repo = git2::Repository::discover(repo_path)
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        // Canonical git2 usage: bind the options, then pass `&mut` — avoids
        // chaining `&mut` setters on a temporary.
        let mut opts = git2::StatusOptions::new();
        opts.include_untracked(true).recurse_untracked_dirs(true);
        let statuses = repo
            .statuses(Some(&mut opts))
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        let mut output = Vec::new();
        for entry in statuses.iter() {
            let path = entry.path().unwrap_or("?");
            let status = entry.status();
            let indicator = format_status(status);
            output.push(format!("{indicator} {path}"));
        }

        if output.is_empty() {
            Ok("Working tree clean".to_string())
        } else {
            Ok(output.join("\n"))
        }
    }

    fn preview(&self, _args: &serde_json::Value) -> String {
        "Show git status".to_string()
    }
}

/// Two-character porcelain-style indicator: index column then worktree column.
fn format_status(status: git2::Status) -> String {
    let index = if status.is_index_new() {
        "A"
    } else if status.is_index_modified() {
        "M"
    } else if status.is_index_deleted() {
        "D"
    } else if status.is_index_renamed() {
        "R"
    } else {
        " "
    };

    let wt = if status.is_wt_new() {
        "?"
    } else if status.is_wt_modified() {
        "M"
    } else if status.is_wt_deleted() {
        "D"
    } else {
        " "
    };

    format!("{index}{wt}")
}

/// Patch-format diff of either staged changes (HEAD vs index) or the worktree.
pub struct GitDiffTool;

#[async_trait]
impl ToolHandler for GitDiffTool {
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "git_diff".to_string(),
            description: "Show changes in the working directory or between commits".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "staged": { "type": "boolean", "description": "Show staged changes only" }
                }
            }),
        }
    }

    fn risk_level(&self) -> RiskLevel {
        RiskLevel::ReadOnly
    }

    async fn execute(&self, args: serde_json::Value, repo_path: &Path) -> Result<String> {
        let staged = args["staged"].as_bool().unwrap_or(false);
        let repo = git2::Repository::discover(repo_path)
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        let diff = if staged {
            // A repository with no commits yet has no HEAD tree, hence Option.
            let head_tree = repo.head().ok().and_then(|h| h.peel_to_tree().ok());
            repo.diff_tree_to_index(head_tree.as_ref(), None, None)
        }
else {
            repo.diff_index_to_workdir(None, None)
        }
        .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        // Render the diff as a unified patch, re-adding the origin marker that
        // `Diff::print` strips from each line's content.
        let mut output = String::new();
        diff.print(git2::DiffFormat::Patch, |_delta, _hunk, line| {
            let prefix = match line.origin() {
                '+' => "+",
                '-' => "-",
                ' ' => " ",
                _ => "",
            };
            if let Ok(content) = std::str::from_utf8(line.content()) {
                output.push_str(prefix);
                output.push_str(content);
            }
            true
        })
        .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        if output.is_empty() {
            Ok("No changes".to_string())
        } else {
            Ok(output)
        }
    }

    fn preview(&self, args: &serde_json::Value) -> String {
        if args["staged"].as_bool().unwrap_or(false) {
            "Show staged diff".to_string()
        } else {
            "Show working directory diff".to_string()
        }
    }
}

/// One-line-per-commit history walk from HEAD.
pub struct GitLogTool;

#[async_trait]
impl ToolHandler for GitLogTool {
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "git_log".to_string(),
            description: "Show recent commit history".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "count": { "type": "integer", "description": "Number of commits to show (default 10)" }
                }
            }),
        }
    }

    fn risk_level(&self) -> RiskLevel {
        RiskLevel::ReadOnly
    }

    async fn execute(&self, args: serde_json::Value, repo_path: &Path) -> Result<String> {
        let count = args["count"].as_u64().unwrap_or(10) as usize;
        let repo = git2::Repository::discover(repo_path)
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        let mut revwalk = repo
            .revwalk()
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;
        revwalk
            .push_head()
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        let mut output = Vec::new();
        for (i, oid) in revwalk.enumerate() {
            if i >= count {
                break;
            }
            let oid = oid.map_err(|e| AgentError::ToolExecution(e.to_string()))?;
            let commit = repo
                .find_commit(oid)
                .map_err(|e| AgentError::ToolExecution(e.to_string()))?;
            // Bind the hash String first: `&oid.to_string()[..7]` would borrow
            // a temporary dropped at end of statement (E0716).
            let full_hash = oid.to_string();
            let short = &full_hash[..7];
            let msg = commit.message().unwrap_or("").lines().next().unwrap_or("");
            let author = commit.author().name().unwrap_or("").to_string();
            output.push(format!("{short} {author}: {msg}"));
        }

        Ok(output.join("\n"))
    }

    fn preview(&self, args: &serde_json::Value) -> String {
        let count = args["count"].as_u64().unwrap_or(10);
        format!("Show last {count} commits")
    }
}

/// Recursive plain-text search over repository files (capped at 100 matches).
pub struct SearchCodeTool;

#[async_trait]
impl ToolHandler for SearchCodeTool {
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "search_code".to_string(),
            description: "Search for a pattern in repository files".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "pattern": { "type": "string", "description": "Text pattern to search for" },
                    "path": { "type": "string", "description": "Optional path filter" }
                },
                "required": ["pattern"]
            }),
        }
    }

    fn risk_level(&self) -> RiskLevel {
        RiskLevel::ReadOnly
    }

    async fn execute(&self, args: serde_json::Value, repo_path: &Path) -> Result<String> {
        let pattern = args["pattern"]
            .as_str()
            .ok_or_else(|| AgentError::ToolExecution("Missing 'pattern'".into()))?;
        let search_path = args["path"].as_str().unwrap_or(".");
        let full_path = resolve_safe_path(repo_path, search_path)?;

        let mut results = Vec::new();
        search_recursive(&full_path, pattern, repo_path, &mut results, 100)?;

        if results.is_empty() {
            Ok(format!("No matches found for '{pattern}'"))
        } else {
            Ok(results.join("\n"))
        }
    }

    fn preview(&self, args: &serde_json::Value) -> String {
        format!("Search for: {}", args["pattern"].as_str().unwrap_or("?"))
    }
}

/// Depth-first substring search. Skips dotfiles, `target`, `node_modules`,
/// and symlinks; unreadable entries are silently ignored (best-effort tool).
fn search_recursive(
    dir: &Path,
    pattern: &str,
    repo_root: &Path,
    results: &mut Vec<String>,
    max_results: usize,
) -> Result<()> {
    if results.len() >= max_results {
        return Ok(());
    }

    let entries = match std::fs::read_dir(dir) {
        Ok(e) => e,
        Err(_) => return Ok(()),
    };

    for entry in entries {
        let entry = match entry {
            Ok(e) => e,
            Err(_) => continue,
        };

        let path = entry.path();
        let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
        if name.starts_with('.') || name == "target" || name == "node_modules" {
            continue;
        }

        // symlink_metadata so symlinks are detected without following them.
        let metadata = match std::fs::symlink_metadata(&path) {
            Ok(m) => m,
            Err(_) => continue,
        };
        if metadata.file_type().is_symlink() {
            continue;
        }

        if path.is_dir() {
            search_recursive(&path, pattern, repo_root, results, max_results)?;
        } else if path.is_file() {
            if let Ok(content) = std::fs::read_to_string(&path) {
                for (lineno, line) in content.lines().enumerate() {
                    if line.contains(pattern) {
                        let rel = path.strip_prefix(repo_root).unwrap_or(&path);
                        results.push(format!("{}:{}: {}", rel.display(), lineno + 1, line.trim()));
                        if results.len() >= max_results {
                            return Ok(());
                        }
                    }
                }
            }
        }
    }

    Ok(())
}

/// Stage the given paths into the git index (local write; needs confirmation).
pub struct StageFilesTool;

#[async_trait]
impl ToolHandler for StageFilesTool {
    fn definition(&self) -> ToolDefinition {
        ToolDefinition {
            name: "stage_files".to_string(),
            description: "Stage files for commit".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "paths": {
                        "type": "array",
                        "items": { "type": "string" },
                        "description": "Paths to stage"
                    }
                },
                "required": ["paths"]
            }),
        }
    }

    fn risk_level(&self) -> RiskLevel {
        RiskLevel::WriteLocal
    }

    // NOTE(review): this method is truncated at the end of the excerpt; the
    // remainder (index.write, return value) lies beyond this chunk.
    async fn execute(&self, args: serde_json::Value, repo_path: &Path) -> Result<String> {
        let paths: Vec<String> = serde_json::from_value(args["paths"].clone())
            .map_err(|e| AgentError::ToolExecution(format!("Invalid paths: {e}")))?;

        let repo = git2::Repository::discover(repo_path)
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;
        let mut index = repo
            .index()
            .map_err(|e| AgentError::ToolExecution(e.to_string()))?;

        for path in &paths {
            if super::is_sensitive_path(path) {
                return Err(AgentError::ToolExecution(format!(
                    "Cannot stage sensitive file: {path}"
                )));
            }
            index
                .add_path(Path::new(path))
                .map_err(|e|
AgentError::ToolExecution(format!("Failed to stage {path}: {e}")))?; + } + + index.write().map_err(|e| AgentError::ToolExecution(e.to_string()))?; + Ok(format!("Staged {} file(s)", paths.len())) + } + + fn preview(&self, args: &serde_json::Value) -> String { + let paths: Vec<&str> = args["paths"] + .as_array() + .map(|a| a.iter().filter_map(|v| v.as_str()).collect()) + .unwrap_or_default(); + format!("Stage: {}", paths.join(", ")) + } +} + +pub struct CreateCommitTool; + +#[async_trait] +impl ToolHandler for CreateCommitTool { + fn definition(&self) -> ToolDefinition { + ToolDefinition { + name: "create_commit".to_string(), + description: "Create a git commit with staged changes".to_string(), + input_schema: json!({ + "type": "object", + "properties": { + "message": { "type": "string", "description": "Commit message" } + }, + "required": ["message"] + }), + } + } + + fn risk_level(&self) -> RiskLevel { RiskLevel::WriteLocal } + + async fn execute(&self, args: serde_json::Value, repo_path: &Path) -> Result { + let message = args["message"].as_str() + .ok_or_else(|| AgentError::ToolExecution("Missing 'message'".into()))?; + + let repo = git2::Repository::discover(repo_path) + .map_err(|e| AgentError::ToolExecution(e.to_string()))?; + let sig = repo.signature() + .map_err(|e| AgentError::ToolExecution(e.to_string()))?; + let mut index = repo.index() + .map_err(|e| AgentError::ToolExecution(e.to_string()))?; + let tree_id = index.write_tree() + .map_err(|e| AgentError::ToolExecution(e.to_string()))?; + let tree = repo.find_tree(tree_id) + .map_err(|e| AgentError::ToolExecution(e.to_string()))?; + + let parents: Vec = match repo.head() { + Ok(head) => vec![head.peel_to_commit() + .map_err(|e| AgentError::ToolExecution(e.to_string()))?], + Err(_) => vec![], + }; + let parent_refs: Vec<&git2::Commit> = parents.iter().collect(); + + let oid = repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &parent_refs) + .map_err(|e| AgentError::ToolExecution(e.to_string()))?; + + 
Ok(format!("Created commit {}", &oid.to_string()[..7])) + } + + fn preview(&self, args: &serde_json::Value) -> String { + format!("Create commit: {}", args["message"].as_str().unwrap_or("?")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn read_file_blocks_sensitive() { + let tool = ReadFileTool; + let rt = tokio::runtime::Runtime::new().unwrap(); + let result = rt.block_on(tool.execute( + json!({"path": ".env"}), + Path::new("/tmp"), + )); + assert!(result.is_err()); + } + + #[test] + fn read_file_blocks_traversal() { + let tool = ReadFileTool; + let rt = tokio::runtime::Runtime::new().unwrap(); + let result = rt.block_on(tool.execute( + json!({"path": "../../etc/passwd"}), + Path::new("/tmp/repo"), + )); + assert!(result.is_err()); + } + + #[test] + fn format_status_flags() { + let status = git2::Status::INDEX_NEW; + let s = format_status(status); + assert!(s.starts_with('A')); + } + + #[test] + fn tool_definitions_valid() { + let tools: Vec> = vec![ + Box::new(ReadFileTool), + Box::new(ListFilesTool), + Box::new(GitStatusTool), + Box::new(GitDiffTool), + Box::new(GitLogTool), + Box::new(SearchCodeTool), + Box::new(StageFilesTool), + Box::new(CreateCommitTool), + ]; + for tool in &tools { + let def = tool.definition(); + assert!(!def.name.is_empty()); + assert!(!def.description.is_empty()); + } + } +} diff --git a/crates/cherrypick_agent/src/tools/mod.rs b/crates/cherrypick_agent/src/tools/mod.rs new file mode 100644 index 00000000000000..88d63e4e1e01d5 --- /dev/null +++ b/crates/cherrypick_agent/src/tools/mod.rs @@ -0,0 +1,171 @@ +pub mod git_tools; +pub mod safe_path; + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +use async_trait::async_trait; +use tokio::sync::oneshot; + +use crate::error::{AgentError, Result}; +use crate::provider::types::{RiskLevel, ToolCall, ToolDefinition}; + +const OUTPUT_TRUNCATION_LIMIT: usize = 50_000; + +static SENSITIVE_PATTERNS: &[&str] = &[ + ".env", + ".env.local", + ".env.production", + 
"credentials.json", + "id_rsa", + "id_ed25519", + ".ssh/config", + ".netrc", + ".npmrc", + "token", + "secret", +]; + +#[async_trait] +pub trait ToolHandler: Send + Sync { + fn definition(&self) -> ToolDefinition; + fn risk_level(&self) -> RiskLevel; + async fn execute(&self, arguments: serde_json::Value, repo_path: &Path) -> Result; + fn preview(&self, arguments: &serde_json::Value) -> String; +} + +pub struct ConfirmationRequest { + pub tool_name: String, + pub risk_level: RiskLevel, + pub preview: String, + pub response: oneshot::Sender, +} + +pub struct ToolExecutor { + tools: HashMap>, + excluded_repos: Vec, +} + +impl ToolExecutor { + pub fn new() -> Self { + let mut executor = Self { + tools: HashMap::new(), + excluded_repos: Vec::new(), + }; + executor.register_builtin_tools(); + executor + } + + fn register_builtin_tools(&mut self) { + self.register(Box::new(git_tools::ReadFileTool)); + self.register(Box::new(git_tools::ListFilesTool)); + self.register(Box::new(git_tools::GitStatusTool)); + self.register(Box::new(git_tools::GitDiffTool)); + self.register(Box::new(git_tools::GitLogTool)); + self.register(Box::new(git_tools::SearchCodeTool)); + self.register(Box::new(git_tools::StageFilesTool)); + self.register(Box::new(git_tools::CreateCommitTool)); + } + + pub fn register(&mut self, handler: Box) { + let name = handler.definition().name.clone(); + self.tools.insert(name, handler); + } + + pub fn definitions(&self) -> Vec { + self.tools.values().map(|h| h.definition()).collect() + } + + pub fn exclude_repo(&mut self, path: PathBuf) { + self.excluded_repos.push(path); + } + + pub fn is_repo_excluded(&self, path: &Path) -> bool { + self.excluded_repos.iter().any(|p| path.starts_with(p)) + } + + pub async fn execute( + &self, + call: &ToolCall, + repo_path: &Path, + ) -> Result { + let handler = self + .tools + .get(&call.name) + .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?; + + if self.is_repo_excluded(repo_path) { + return 
Err(AgentError::ToolExecution(format!( + "Repository is excluded from tool operations" + ))); + } + + let mut output = handler.execute(call.arguments.clone(), repo_path).await?; + + if output.len() > OUTPUT_TRUNCATION_LIMIT { + output.truncate(OUTPUT_TRUNCATION_LIMIT); + output.push_str("\n... (output truncated)"); + } + + Ok(output) + } + + pub fn risk_level(&self, tool_name: &str) -> Option { + self.tools.get(tool_name).map(|h| h.risk_level()) + } + + pub fn preview(&self, call: &ToolCall) -> Option { + self.tools.get(&call.name).map(|h| h.preview(&call.arguments)) + } +} + +pub fn is_sensitive_path(path: &str) -> bool { + let lower = path.to_lowercase(); + SENSITIVE_PATTERNS.iter().any(|p| lower.contains(p)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sensitive_path_detection() { + assert!(is_sensitive_path(".env")); + assert!(is_sensitive_path("path/to/.env.local")); + assert!(is_sensitive_path("credentials.json")); + assert!(is_sensitive_path("my_secret_file")); + assert!(!is_sensitive_path("src/main.rs")); + assert!(!is_sensitive_path("README.md")); + } + + #[test] + fn executor_has_builtin_tools() { + let executor = ToolExecutor::new(); + let defs = executor.definitions(); + assert!(!defs.is_empty()); + assert!(defs.iter().any(|d| d.name == "read_file")); + assert!(defs.iter().any(|d| d.name == "git_status")); + } + + #[test] + fn repo_exclusion() { + let mut executor = ToolExecutor::new(); + executor.exclude_repo(PathBuf::from("/excluded/repo")); + assert!(executor.is_repo_excluded(Path::new("/excluded/repo/subdir"))); + assert!(!executor.is_repo_excluded(Path::new("/other/repo"))); + } + + #[test] + fn risk_level_lookup() { + let executor = ToolExecutor::new(); + assert_eq!( + executor.risk_level("read_file"), + Some(RiskLevel::ReadOnly) + ); + assert_eq!( + executor.risk_level("create_commit"), + Some(RiskLevel::WriteLocal) + ); + assert!(executor.risk_level("nonexistent").is_none()); + } +} diff --git 
a/crates/cherrypick_agent/src/tools/safe_path.rs b/crates/cherrypick_agent/src/tools/safe_path.rs
new file mode 100644
index 00000000000000..fc520259591a66
--- /dev/null
+++ b/crates/cherrypick_agent/src/tools/safe_path.rs
@@ -0,0 +1,108 @@
//! Containment helper: resolve user-supplied relative paths while
//! guaranteeing the result stays inside the repository root.

use std::path::{Path, PathBuf};

use crate::error::{AgentError, Result};

/// Resolve `relative_path` against `repo_root`, rejecting any path that
/// escapes the repository via `..` components or symlinks.
///
/// Existing paths are canonicalized directly. Not-yet-existing paths (e.g.
/// a file about to be created) are resolved by canonicalizing the nearest
/// existing ancestor and re-joining the remaining components, so the
/// containment check still applies.
///
/// # Errors
/// Returns `AgentError::ToolExecution` when canonicalization fails or the
/// resolved path falls outside the repository root.
pub fn resolve_safe_path(repo_root: &Path, relative_path: &str) -> Result<PathBuf> {
    // An empty request means "the repository root itself".
    if relative_path.is_empty() {
        return Ok(repo_root.to_path_buf());
    }

    let canonical_root = repo_root
        .canonicalize()
        .map_err(|e| AgentError::ToolExecution(format!("Cannot canonicalize repo root: {e}")))?;

    // NOTE(review): if `relative_path` is absolute, `join` replaces the
    // root entirely; such paths are still rejected by the `starts_with`
    // containment check below.
    let joined = canonical_root.join(relative_path);

    let resolved = if joined.exists() {
        // canonicalize() resolves `..` components and every symlink.
        joined.canonicalize().map_err(|e| {
            AgentError::ToolExecution(format!("Cannot canonicalize path: {e}"))
        })?
    } else {
        // Walk up to the nearest existing ancestor, canonicalize it, and
        // graft the non-existing remainder back on.
        let mut ancestor = joined.as_path();
        loop {
            if let Some(parent) = ancestor.parent() {
                if parent.exists() {
                    let canonical_parent = parent.canonicalize().map_err(|e| {
                        AgentError::ToolExecution(format!("Cannot canonicalize: {e}"))
                    })?;
                    let remainder = joined.strip_prefix(parent).unwrap_or(Path::new(""));
                    break canonical_parent.join(remainder);
                }
                ancestor = parent;
            } else {
                // No existing ancestor at all — cannot prove containment.
                return Err(AgentError::ToolExecution(
                    "Path traversal detected: cannot resolve path".into(),
                ));
            }
        }
    };

    // Primary containment check: the fully resolved path must remain under
    // the canonical repository root.
    if !resolved.starts_with(&canonical_root) {
        return Err(AgentError::ToolExecution(
            "Path traversal detected: resolved path escapes repository".into(),
        ));
    }

    // NOTE(review): defense in depth — in the `exists` branch above,
    // canonicalize() has already resolved symlinks, so `resolved` should
    // never itself be a symlink there; this guard looks unreachable in
    // that branch. Confirm before removing.
    if resolved.is_symlink() {
        let link_target = std::fs::read_link(&resolved).map_err(|e| {
            AgentError::ToolExecution(format!("Cannot read symlink: {e}"))
        })?;
        // Relative link targets are interpreted from the link's directory.
        let absolute_target = if link_target.is_absolute() {
            link_target
        } else {
            resolved.parent().unwrap_or(&canonical_root).join(&link_target)
        };
        let canonical_target = absolute_target.canonicalize().map_err(|e| {
            AgentError::ToolExecution(format!("Cannot canonicalize symlink target: {e}"))
        })?;
        if !canonical_target.starts_with(&canonical_root) {
            return Err(AgentError::ToolExecution(
                "Path traversal detected: symlink escapes repository".into(),
            ));
        }
    }

    Ok(resolved)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_path_resolves() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("test.txt");
        std::fs::write(&file, "hello").unwrap();
        let result = resolve_safe_path(tmp.path(), "test.txt");
        assert!(result.is_ok());
    }

    #[test]
    fn dot_dot_traversal_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        std::fs::write(tmp.path().join("a.txt"), "a").unwrap();
        let result = resolve_safe_path(tmp.path(), "../../etc/passwd");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("traversal"));
    }

    #[test]
    fn symlink_escape_rejected() {
        let tmp = tempfile::tempdir().unwrap();
        let link_path = tmp.path().join("escape_link");
        #[cfg(unix)]
        {
            std::os::unix::fs::symlink("/etc", &link_path).unwrap();
            let result = resolve_safe_path(tmp.path(), "escape_link");
            assert!(result.is_err());
        }
    }

    #[test]
    fn empty_path_returns_root() {
        let tmp = tempfile::tempdir().unwrap();
        let result = resolve_safe_path(tmp.path(), "");
        assert!(result.is_ok());
    }
}
diff --git a/crates/cherrypick_pr/Cargo.toml b/crates/cherrypick_pr/Cargo.toml
new file mode 100644
index 00000000000000..d75d70edcfddf5
--- /dev/null
+++ b/crates/cherrypick_pr/Cargo.toml
@@ -0,0 +1,29 @@
[package]
name = "cherrypick_pr"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
license = "GPL-3.0-or-later"

[lints]
workspace = true

[lib]
path = "src/lib.rs"

[dependencies]
git2.workspace = true
rusqlite.workspace = true
tokio-rusqlite.workspace = true
tokio = { workspace = true, features = ["rt", "sync", "macros", "time"] }
thiserror.workspace = true
chrono.workspace = true
lru.workspace = true
notify.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true

[dev-dependencies]
tempfile.workspace = true
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
diff --git a/crates/cherrypick_pr/src/diff_service.rs b/crates/cherrypick_pr/src/diff_service.rs
new file mode 100644
index 00000000000000..58d1fb1d63fc9c
--- /dev/null
+++ b/crates/cherrypick_pr/src/diff_service.rs
@@ -0,0 +1,237 @@
use std::num::NonZeroUsize;
use std::path::Path;

use lru::LruCache;

use crate::error::Result;
use crate::types::FileContent;

/// One file's entry in a branch diff summary.
#[derive(Debug, Clone)]
pub struct DiffFileEntry {
    pub path: String,
    /// Single-letter status: A/D/M/R/C/T, '?' for anything else.
    pub status: char,
    pub insertions: usize,
    pub deletions: usize,
    pub is_binary: bool,
    pub is_lfs: bool,
    /// Original path, only populated for renames.
    pub old_path: Option<String>,
}

/// Summary of the changes a source branch introduces relative to the
/// merge base with its target.
#[derive(Debug, Clone)]
pub struct BranchDiff {
    pub source_oid: String,
    pub target_oid: String,
    pub merge_base_oid: String,
    pub files: Vec<DiffFileEntry>,
    pub total_insertions: usize,
    pub total_deletions: usize,
}

/// Computes branch diffs, memoizing results by (source, target) oid pair.
pub struct DiffService {
    cache: LruCache<(String, String), BranchDiff>,
}

impl DiffService {
    /// `cache_size` of 0 is clamped to 1 (LruCache requires non-zero).
    pub fn new(cache_size: usize) -> Self {
        Self {
            cache: LruCache::new(NonZeroUsize::new(cache_size.max(1)).unwrap()),
        }
    }

    /// Return the cached diff for the oid pair, computing it on a miss.
    pub fn get_branch_diff(
        &mut self,
        repo_path: &Path,
        source_oid: &str,
        target_oid: &str,
    ) -> Result<BranchDiff> {
        let key = (source_oid.to_string(), target_oid.to_string());
        if let Some(cached) = self.cache.get(&key) {
            return Ok(cached.clone());
        }

        let diff = self.compute_diff(repo_path, source_oid, target_oid)?;
        self.cache.put(key, diff.clone());
        Ok(diff)
    }

    pub fn invalidate(&mut self, source_oid: &str, target_oid: &str) {
        let key = (source_oid.to_string(), target_oid.to_string());
        self.cache.pop(&key);
    }

    pub fn invalidate_all(&mut self) {
        self.cache.clear();
    }

    /// Diff the source branch against the merge base with the target — the
    /// "three-dot" diff that shows only the source branch's own changes.
    fn compute_diff(
        &self,
        repo_path: &Path,
        source_oid_str: &str,
        target_oid_str: &str,
    ) -> Result<BranchDiff> {
        let repo = git2::Repository::discover(repo_path)?;
        let source_oid = git2::Oid::from_str(source_oid_str)?;
        let target_oid = git2::Oid::from_str(target_oid_str)?;

        let merge_base = repo.merge_base(source_oid, target_oid)?;

        let base_tree = repo.find_commit(merge_base)?.tree()?;
        let source_tree = repo.find_commit(source_oid)?.tree()?;

        let mut diff_opts = git2::DiffOptions::new();
        diff_opts.patience(true);

        let diff = repo.diff_tree_to_tree(
            Some(&base_tree),
            Some(&source_tree),
            Some(&mut diff_opts),
        )?;

        let stats = diff.stats()?;
        let mut files = Vec::new();

        for (delta_idx, delta) in diff.deltas().enumerate() {
            let new_file = delta.new_file();
            let old_file = delta.old_file();

            let path = new_file
                .path()
                .or_else(|| old_file.path())
                .map(|p| p.to_string_lossy().to_string())
                .unwrap_or_default();

            let old_path = if delta.status() == git2::Delta::Renamed {
                old_file.path().map(|p| p.to_string_lossy().to_string())
            } else {
                None
            };

            let status = match delta.status() {
                git2::Delta::Added => 'A',
                git2::Delta::Deleted => 'D',
                git2::Delta::Modified => 'M',
                git2::Delta::Renamed => 'R',
                git2::Delta::Copied => 'C',
                git2::Delta::Typechange => 'T',
                _ => '?',
            };

            let is_binary = delta.flags().is_binary();

            // Per-file line counts come from the generated patch; binary
            // deltas have no hunks and stay at 0/0. (Previously these
            // fields were always left at 0.)
            let (insertions, deletions) = match git2::Patch::from_diff(&diff, delta_idx) {
                Ok(Some(patch)) => patch
                    .line_stats()
                    .map(|(_context, add, del)| (add, del))
                    .unwrap_or((0, 0)),
                _ => (0, 0),
            };

            // An LFS pointer is a small text blob. Skip binary deltas and
            // null oids (e.g. the deleted side has no new blob). The old
            // code only inspected the oid's first byte, which both missed
            // null oids and wrongly skipped blobs whose hash starts with
            // 0x00; use the proper null-oid test.
            let is_lfs = if !is_binary && !new_file.id().is_zero() {
                repo.find_blob(new_file.id())
                    .map(|b| FileContent::is_lfs_pointer(b.content()))
                    .unwrap_or(false)
            } else {
                false
            };

            files.push(DiffFileEntry {
                path,
                status,
                insertions,
                deletions,
                is_binary,
                is_lfs,
                old_path,
            });
        }

        Ok(BranchDiff {
            source_oid: source_oid_str.to_string(),
            target_oid: target_oid_str.to_string(),
            merge_base_oid: merge_base.to_string(),
            files,
            total_insertions: stats.insertions(),
            total_deletions: stats.deletions(),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn create_test_repo_with_branches() -> (tempfile::TempDir, PathBuf, String, String) {
        let tmp = tempfile::tempdir().unwrap();
        let repo = git2::Repository::init(tmp.path()).unwrap();
        let mut config = repo.config().unwrap();
        config.set_str("user.name", "Test").unwrap();
        config.set_str("user.email", "t@t.com").unwrap();

        let sig = repo.signature().unwrap();

        std::fs::write(tmp.path().join("base.txt"), "base content").unwrap();
        let mut index = repo.index().unwrap();
        index.add_path(Path::new("base.txt")).unwrap();
        index.write().unwrap();
        let tree_id = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_id).unwrap();
        let base_commit = repo
            .commit(Some("HEAD"), &sig, &sig, "base", &tree, &[])
            .unwrap();

        let base = repo.find_commit(base_commit).unwrap();
        repo.branch("feature", &base, false).unwrap();

        let refname = "refs/heads/feature";
        let obj = repo.revparse_single(refname).unwrap();
        repo.checkout_tree(&obj, None).unwrap();
        repo.set_head(refname).unwrap();

        std::fs::write(tmp.path().join("feature.txt"), "feature work").unwrap();
        let mut index = repo.index().unwrap();
        index.add_path(Path::new("feature.txt")).unwrap();
        index.write().unwrap();
        let tree_id = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_id).unwrap();
        let parent = repo.head().unwrap().peel_to_commit().unwrap();
        let feature_oid = repo
            .commit(Some("HEAD"), &sig, &sig, "feature commit", &tree, &[&parent])
            .unwrap();

        let path = tmp.path().canonicalize().unwrap();
        (
            tmp,
            path,
            feature_oid.to_string(),
            base_commit.to_string(),
        )
    }

    #[test]
    fn compute_diff_finds_added_file() {
        let (_tmp, path, source_oid, target_oid) = create_test_repo_with_branches();
        let mut svc = DiffService::new(10);
        let diff = svc.get_branch_diff(&path, &source_oid, &target_oid).unwrap();
        assert!(!diff.files.is_empty());
        assert!(diff.files.iter().any(|f| f.path == "feature.txt"));
    }

    #[test]
    fn cache_hit_returns_same_result() {
        let (_tmp, path, source_oid, target_oid) = create_test_repo_with_branches();
        let mut svc = DiffService::new(10);
        let d1 = svc.get_branch_diff(&path, &source_oid, &target_oid).unwrap();
        let d2 = svc.get_branch_diff(&path, &source_oid, &target_oid).unwrap();
        assert_eq!(d1.files.len(), d2.files.len());
    }

    #[test]
    fn invalidate_clears_cache() {
        let (_tmp, path, source_oid, target_oid) = create_test_repo_with_branches();
        let mut svc = DiffService::new(10);
        svc.get_branch_diff(&path, &source_oid, &target_oid).unwrap();
        svc.invalidate(&source_oid, &target_oid);
    }
}
diff --git a/crates/cherrypick_pr/src/error.rs b/crates/cherrypick_pr/src/error.rs
new file mode 100644
index 00000000000000..895b6a1ac0635f
--- /dev/null
+++ b/crates/cherrypick_pr/src/error.rs
@@ -0,0 +1,65 @@
use thiserror::Error;

/// Error type for all local-PR operations.
#[derive(Debug, Error)]
pub enum PrError {
    #[error("PR not found: {0}")]
    NotFound(i64),

    #[error("Repo not found: {0}")]
    RepoNotFound(String),

    #[error("Branch not found: {0}")]
    BranchNotFound(String),

    #[error("Invalid status transition from {from} to {to}")]
    InvalidStatusTransition { from: String, to: String },

    #[error("Source and target branches cannot be the same: {0}")]
    SameBranch(String),

    #[error("PR already exists for this source/target combination")]
    DuplicatePr,

    #[error("Target branch has moved since merge started")]
    TargetMoved,

    #[error("Merge has conflicts")]
    MergeConflicts,

    #[error("Database error: {0}")]
    Database(#[from] rusqlite::Error),

    // Message disambiguated from `Database` above (both previously read
    // "Database error: …", hiding which layer failed).
    #[error("Async database error: {0}")]
    AsyncDatabase(#[from] tokio_rusqlite::Error),

    #[error("Git error: {0}")]
    Git(#[from] git2::Error),

    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    #[error("{0}")]
    Other(String),
}

pub type Result<T> = std::result::Result<T, PrError>;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn error_display() {
        let err = PrError::NotFound(42);
        assert!(err.to_string().contains("42"));
    }

    #[test]
    fn error_variants_are_constructible() {
        let _ = PrError::BranchNotFound("main".into());
        let _ = PrError::SameBranch("main".into());
        let _ = PrError::TargetMoved;
        let _ = PrError::MergeConflicts;
        let _ = PrError::Other("test".into());
    }
}
diff --git a/crates/cherrypick_pr/src/lib.rs b/crates/cherrypick_pr/src/lib.rs
new file mode 100644
index 00000000000000..3e6d5fe6410e36
--- /dev/null
+++ b/crates/cherrypick_pr/src/lib.rs
@@ -0,0 +1,18 @@
pub mod diff_service;
pub mod error;
pub mod merge_service;
pub mod service;
pub mod store;
pub mod types;
pub mod watcher;

pub use diff_service::DiffService;
pub use error::{PrError, Result};
pub use merge_service::MergeService;
pub use service::PrService;
pub use store::PrStore;
pub use types::{
    BranchHealth, ContentEncoding, FileContent, LocalPr, MergeStrategy, PrSnapshot, PrStatus,
    RepoRecord, ThreeWayContent,
};
pub use watcher::BranchWatcher;
diff --git a/crates/cherrypick_pr/src/merge_service.rs b/crates/cherrypick_pr/src/merge_service.rs
new file mode 100644
index 00000000000000..cf32c8f020d093
--- /dev/null
+++ b/crates/cherrypick_pr/src/merge_service.rs
@@ -0,0 +1,337 @@
use std::collections::HashMap;
use std::path::Path;

use crate::error::{PrError, Result};
use crate::types::{ContentEncoding, FileContent, MergeStrategy, ThreeWayContent};

/// In-memory state for an interactive merge: the commits involved, the
/// conflicted paths, and the user's per-path resolutions.
pub struct MergeSession {
    pub source_oid: String,
    pub target_oid: String,
    pub target_branch: String,
    pub conflicted_paths: Vec<String>,
    pub resolutions: HashMap<String, Vec<u8>>,
}

pub struct MergeService;

impl MergeService {
    pub fn new() -> Self {
        Self
    }

    /// Dry-run merge of `source` into `target`; returns the conflicted
    /// paths (empty when the merge is clean).
    pub fn check_conflicts(
        &self,
        repo_path: &Path,
        source_oid: &str,
        target_oid: &str,
    ) -> Result<Vec<String>> {
        let repo = git2::Repository::discover(repo_path)?;
        let source = repo.find_commit(git2::Oid::from_str(source_oid)?)?;
        let target = repo.find_commit(git2::Oid::from_str(target_oid)?)?;

        // The value is unused, but the call validates that the commits
        // share a common ancestor (it errors otherwise).
        let _merge_base = repo.merge_base(source.id(), target.id())?;

        let mut merge_opts = git2::MergeOptions::new();
        merge_opts.file_favor(git2::FileFavor::Normal);

        // "ours" is the target branch — consistent with the our/their
        // labeling of get_conflict_content below (the conflict *set* is
        // the same either way).
        let index = repo.merge_commits(&target, &source, Some(&merge_opts))?;

        if !index.has_conflicts() {
            return Ok(Vec::new());
        }

        let conflicts = index.conflicts()?;
        let mut paths = Vec::new();
        for conflict in conflicts {
            let conflict = conflict?;
            // A conflict side can be absent (add/add, delete/modify);
            // take the first side that carries a path.
            let path = conflict
                .our
                .as_ref()
                .or(conflict.their.as_ref())
                .or(conflict.ancestor.as_ref())
                .map(|e| String::from_utf8_lossy(&e.path).to_string());
            if let Some(p) = path {
                paths.push(p);
            }
        }
        Ok(paths)
    }

    /// Three-way content for one conflicted path: base = merge-base
    /// version, ours = target branch, theirs = source branch. Missing
    /// sides fall back to empty UTF-8 content (base stays None).
    pub fn get_conflict_content(
        &self,
        repo_path: &Path,
        source_oid: &str,
        target_oid: &str,
        file_path: &str,
    ) -> Result<ThreeWayContent> {
        let repo = git2::Repository::discover(repo_path)?;
        let source = repo.find_commit(git2::Oid::from_str(source_oid)?)?;
        let target = repo.find_commit(git2::Oid::from_str(target_oid)?)?;
        let merge_base_oid = repo.merge_base(source.id(), target.id())?;
        let ancestor = repo.find_commit(merge_base_oid)?;

        let base = read_file_from_tree(&repo, &ancestor.tree()?, file_path);
        let ours = read_file_from_tree(&repo, &target.tree()?, file_path)
            .unwrap_or_else(|| FileContent {
                data: Vec::new(),
                encoding: ContentEncoding::Utf8,
                is_lfs: false,
            });
        let theirs = read_file_from_tree(&repo, &source.tree()?, file_path)
            .unwrap_or_else(|| FileContent {
                data: Vec::new(),
                encoding: ContentEncoding::Utf8,
                is_lfs: false,
            });

        Ok(ThreeWayContent {
            base,
            ours,
            theirs,
        })
    }

    /// Begin an interactive merge session by computing the conflict set.
    pub fn start_merge_session(
        &self,
        repo_path: &Path,
        source_oid: &str,
        target_oid: &str,
        target_branch: &str,
    ) -> Result<MergeSession> {
        let conflicts = self.check_conflicts(repo_path, source_oid, target_oid)?;
        Ok(MergeSession {
            source_oid: source_oid.to_string(),
            target_oid: target_oid.to_string(),
            target_branch: target_branch.to_string(),
            conflicted_paths: conflicts,
            resolutions: HashMap::new(),
+ }) + } + + pub fn resolve_conflict(session: &mut MergeSession, path: &str, content: Vec) { + session.resolutions.insert(path.to_string(), content); + } + + pub fn reset_conflict(session: &mut MergeSession, path: &str) { + session.resolutions.remove(path); + } + + pub fn is_fully_resolved(session: &MergeSession) -> bool { + session + .conflicted_paths + .iter() + .all(|p| session.resolutions.contains_key(p)) + } + + pub fn merge( + &self, + repo_path: &Path, + session: &MergeSession, + strategy: MergeStrategy, + message: &str, + current_target_oid: &str, + ) -> Result { + if current_target_oid != session.target_oid { + return Err(PrError::TargetMoved); + } + + if !Self::is_fully_resolved(session) { + return Err(PrError::MergeConflicts); + } + + let repo = git2::Repository::discover(repo_path)?; + let source = repo.find_commit(git2::Oid::from_str(&session.source_oid)?)?; + let target = repo.find_commit(git2::Oid::from_str(&session.target_oid)?)?; + + let mut merge_opts = git2::MergeOptions::new(); + let mut index = repo.merge_commits(&source, &target, Some(&mut merge_opts))?; + + for (path, content) in &session.resolutions { + let blob_oid = repo.blob(content)?; + let entry = git2::IndexEntry { + ctime: git2::IndexTime::new(0, 0), + mtime: git2::IndexTime::new(0, 0), + dev: 0, + ino: 0, + mode: 0o100644, + uid: 0, + gid: 0, + file_size: content.len() as u32, + id: blob_oid, + flags: 0, + flags_extended: 0, + path: path.as_bytes().to_vec(), + }; + index.add(&entry)?; + } + + let tree_oid = index.write_tree_to(&repo)?; + let tree = repo.find_tree(tree_oid)?; + let sig = repo.signature()?; + + let merge_oid = match strategy { + MergeStrategy::MergeCommit => { + repo.commit(None, &sig, &sig, message, &tree, &[&target, &source])? 
+ } + MergeStrategy::Squash => repo.commit(None, &sig, &sig, message, &tree, &[&target])?, + }; + + let ref_name = format!("refs/heads/{}", session.target_branch); + repo.reference( + &ref_name, + merge_oid, + true, + &format!("merge: {message}"), + ).map_err(|e| PrError::Other(format!( + "Failed to update branch ref '{}': {}", + session.target_branch, e + )))?; + + Ok(merge_oid.to_string()) + } +} + +fn read_file_from_tree( + repo: &git2::Repository, + tree: &git2::Tree, + path: &str, +) -> Option { + let entry = tree.get_path(Path::new(path)).ok()?; + let obj = entry.to_object(repo).ok()?; + let blob = obj.as_blob()?; + let data = blob.content().to_vec(); + let encoding = FileContent::detect_encoding(&data); + let is_lfs = FileContent::is_lfs_pointer(&data); + Some(FileContent { + data, + encoding, + is_lfs, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + fn setup_conflict_repo() -> (tempfile::TempDir, PathBuf, String, String) { + let tmp = tempfile::tempdir().unwrap(); + let repo = git2::Repository::init(tmp.path()).unwrap(); + let mut config = repo.config().unwrap(); + config.set_str("user.name", "Test").unwrap(); + config.set_str("user.email", "t@t.com").unwrap(); + + let sig = repo.signature().unwrap(); + + std::fs::write(tmp.path().join("shared.txt"), "base content").unwrap(); + let mut index = repo.index().unwrap(); + index.add_path(Path::new("shared.txt")).unwrap(); + index.write().unwrap(); + let tree_id = index.write_tree().unwrap(); + let tree = repo.find_tree(tree_id).unwrap(); + let base_oid = repo + .commit(Some("HEAD"), &sig, &sig, "base", &tree, &[]) + .unwrap(); + + let base = repo.find_commit(base_oid).unwrap(); + repo.branch("feature", &base, false).unwrap(); + + std::fs::write(tmp.path().join("shared.txt"), "main version").unwrap(); + let mut index = repo.index().unwrap(); + index.add_path(Path::new("shared.txt")).unwrap(); + index.write().unwrap(); + let tree_id = index.write_tree().unwrap(); + let tree = 
repo.find_tree(tree_id).unwrap(); + let main_oid = repo + .commit(Some("HEAD"), &sig, &sig, "main change", &tree, &[&base]) + .unwrap(); + + let obj = repo.revparse_single("refs/heads/feature").unwrap(); + repo.checkout_tree(&obj, None).unwrap(); + repo.set_head("refs/heads/feature").unwrap(); + + std::fs::write(tmp.path().join("shared.txt"), "feature version").unwrap(); + let mut index = repo.index().unwrap(); + index.add_path(Path::new("shared.txt")).unwrap(); + index.write().unwrap(); + let tree_id = index.write_tree().unwrap(); + let tree = repo.find_tree(tree_id).unwrap(); + let feature_oid = repo + .commit( + Some("HEAD"), + &sig, + &sig, + "feature change", + &tree, + &[&base], + ) + .unwrap(); + + let path = tmp.path().canonicalize().unwrap(); + ( + tmp, + path, + feature_oid.to_string(), + main_oid.to_string(), + ) + } + + #[test] + fn detect_conflicts() { + let (_tmp, path, source, target) = setup_conflict_repo(); + let svc = MergeService::new(); + let conflicts = svc.check_conflicts(&path, &source, &target).unwrap(); + assert!(conflicts.contains(&"shared.txt".to_string())); + } + + #[test] + fn get_three_way_content() { + let (_tmp, path, source, target) = setup_conflict_repo(); + let svc = MergeService::new(); + let content = svc + .get_conflict_content(&path, &source, &target, "shared.txt") + .unwrap(); + assert!(content.base.is_some()); + assert!(content.ours.as_text().is_some()); + assert!(content.theirs.as_text().is_some()); + } + + #[test] + fn merge_session_resolution_tracking() { + let (_tmp, path, source, target) = setup_conflict_repo(); + let svc = MergeService::new(); + let mut session = svc.start_merge_session(&path, &source, &target, "master").unwrap(); + assert!(!MergeService::is_fully_resolved(&session)); + + MergeService::resolve_conflict( + &mut session, + "shared.txt", + b"resolved content".to_vec(), + ); + assert!(MergeService::is_fully_resolved(&session)); + + MergeService::reset_conflict(&mut session, "shared.txt"); + 
assert!(!MergeService::is_fully_resolved(&session)); + } + + #[test] + fn merge_rejects_moved_target() { + let (_tmp, path, source, target) = setup_conflict_repo(); + let svc = MergeService::new(); + let mut session = svc.start_merge_session(&path, &source, &target, "master").unwrap(); + MergeService::resolve_conflict( + &mut session, + "shared.txt", + b"resolved".to_vec(), + ); + + let result = svc.merge( + &path, + &session, + MergeStrategy::MergeCommit, + "merge", + "wrong-oid", + ); + assert!(matches!(result, Err(PrError::TargetMoved))); + } +} diff --git a/crates/cherrypick_pr/src/service.rs b/crates/cherrypick_pr/src/service.rs new file mode 100644 index 00000000000000..32f749cae664ab --- /dev/null +++ b/crates/cherrypick_pr/src/service.rs @@ -0,0 +1,349 @@ +use std::path::Path; + +use sha2::{Digest, Sha256}; + +use crate::error::{PrError, Result}; +use crate::store::PrStore; +use crate::types::{BranchHealth, LocalPr, PrStatus}; + +pub struct PrService { + store: PrStore, +} + +impl PrService { + pub fn new(store: PrStore) -> Self { + Self { store } + } + + pub async fn create_pr( + &self, + repo_path: &Path, + title: &str, + source_branch: &str, + target_branch: &str, + ) -> Result<LocalPr> { + if source_branch == target_branch { + return Err(PrError::SameBranch(source_branch.to_string())); + } + + let repo = git2::Repository::discover(repo_path)?; + + let source_ref = repo + .find_branch(source_branch, git2::BranchType::Local) + .map_err(|_| PrError::BranchNotFound(source_branch.to_string()))?; + let source_oid = source_ref + .get() + .target() + .ok_or_else(|| PrError::BranchNotFound(source_branch.to_string()))?; + + let target_ref = repo + .find_branch(target_branch, git2::BranchType::Local) + .map_err(|_| PrError::BranchNotFound(target_branch.to_string()))?; + let target_oid = target_ref + .get() + .target() + .ok_or_else(|| PrError::BranchNotFound(target_branch.to_string()))?; + + let repo_id = self.ensure_repo_identity(&repo, repo_path).await?; + + let pr_id =
self + .store + .create_pr( + repo_id, + title, + source_branch, + target_branch, + &source_oid.to_string(), + &target_oid.to_string(), + ) + .await?; + + self.store + .record_snapshot( + pr_id, + &source_oid.to_string(), + &target_oid.to_string(), + false, + ) + .await?; + + self.store.get_pr(pr_id).await + } + + pub async fn close_pr(&self, pr_id: i64) -> Result<()> { + let pr = self.store.get_pr(pr_id).await?; + if pr.status != PrStatus::Open { + return Err(PrError::InvalidStatusTransition { + from: pr.status.as_str().to_string(), + to: "closed".to_string(), + }); + } + self.store.update_pr_status(pr_id, PrStatus::Closed).await + } + + pub async fn reopen_pr(&self, pr_id: i64) -> Result<()> { + let pr = self.store.get_pr(pr_id).await?; + if pr.status != PrStatus::Closed { + return Err(PrError::InvalidStatusTransition { + from: pr.status.as_str().to_string(), + to: "open".to_string(), + }); + } + self.store.update_pr_status(pr_id, PrStatus::Open).await + } + + pub async fn retarget_pr( + &self, + pr_id: i64, + new_target: &str, + repo_path: &Path, + ) -> Result<()> { + let pr = self.store.get_pr(pr_id).await?; + if new_target == pr.source_branch { + return Err(PrError::SameBranch(new_target.to_string())); + } + + let repo = git2::Repository::discover(repo_path)?; + let target_ref = repo + .find_branch(new_target, git2::BranchType::Local) + .map_err(|_| PrError::BranchNotFound(new_target.to_string()))?; + let target_oid = target_ref + .get() + .target() + .ok_or_else(|| PrError::BranchNotFound(new_target.to_string()))?; + + self.store + .retarget_pr(pr_id, new_target, &target_oid.to_string()) + .await?; + + self.store + .record_snapshot(pr_id, &pr.source_oid, &target_oid.to_string(), false) + .await?; + + Ok(()) + } + + pub fn check_branch_health( + &self, + repo_path: &Path, + branch_name: &str, + recorded_oid: &str, + ) -> Result<BranchHealth> { + let repo = git2::Repository::discover(repo_path)?; + + let branch = match repo.find_branch(branch_name, git2::BranchType::Local) {
+ Ok(b) => b, + Err(_) => { + return Ok(BranchHealth { + exists: false, + ..Default::default() + }) + } + }; + + let current_oid = branch.get().target().map(|o| o.to_string()); + + let force_pushed = if let (Some(current), Ok(recorded)) = + (&current_oid, git2::Oid::from_str(recorded_oid)) + { + if let Ok(current_git_oid) = git2::Oid::from_str(current) { + !repo + .graph_descendant_of(current_git_oid, recorded) + .unwrap_or(true) + && current != recorded_oid + } else { + false + } + } else { + false + }; + + Ok(BranchHealth { + exists: true, + force_pushed, + ahead: 0, + behind: 0, + current_oid, + }) + } + + async fn ensure_repo_identity( + &self, + repo: &git2::Repository, + repo_path: &Path, + ) -> Result<i64> { + let first_commit_oid = find_first_commit_oid(repo)?; + let remote_hash = compute_remote_urls_hash(repo); + let canonical = repo_path + .canonicalize() + .unwrap_or_else(|_| repo_path.to_path_buf()); + + self.store + .ensure_repo( + &first_commit_oid, + &remote_hash, + &canonical.to_string_lossy(), + ) + .await + } +} + +fn find_first_commit_oid(repo: &git2::Repository) -> Result<String> { + let mut revwalk = repo.revwalk()?; + revwalk.push_head().map_err(|_| { + PrError::Other("Cannot find first commit: no HEAD".to_string()) + })?; + revwalk.set_sorting(git2::Sort::TOPOLOGICAL | git2::Sort::REVERSE)?; + let first_oid = revwalk + .next() + .ok_or_else(|| PrError::Other("Empty repository".to_string()))?
+ .map_err(|e| PrError::Git(e))?; + Ok(first_oid.to_string()) +} + +fn compute_remote_urls_hash(repo: &git2::Repository) -> String { + let mut urls: Vec = Vec::new(); + if let Ok(remotes) = repo.remotes() { + for name in &remotes { + if let Some(name) = name { + if let Ok(remote) = repo.find_remote(name) { + if let Some(url) = remote.url() { + urls.push(url.to_string()); + } + } + } + } + } + urls.sort(); + let mut hasher = Sha256::new(); + for url in &urls { + hasher.update(url.as_bytes()); + } + if urls.is_empty() { + "no-remotes".to_string() + } else { + format!("{:x}", hasher.finalize()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + fn create_test_repo() -> (tempfile::TempDir, PathBuf) { + let tmp = tempfile::tempdir().unwrap(); + let repo = git2::Repository::init(tmp.path()).unwrap(); + let mut config = repo.config().unwrap(); + config.set_str("user.name", "Test").unwrap(); + config.set_str("user.email", "test@test.com").unwrap(); + + let sig = repo.signature().unwrap(); + let tree_id = { + let mut index = repo.index().unwrap(); + let path = tmp.path().join("init.txt"); + std::fs::write(&path, "init").unwrap(); + index.add_path(Path::new("init.txt")).unwrap(); + index.write().unwrap(); + index.write_tree().unwrap() + }; + let tree = repo.find_tree(tree_id).unwrap(); + repo.commit(Some("HEAD"), &sig, &sig, "initial", &tree, &[]) + .unwrap(); + + let path = tmp.path().canonicalize().unwrap(); + (tmp, path) + } + + fn add_commit(path: &Path, branch: &str, file: &str, msg: &str) { + let repo = git2::Repository::open(path).unwrap(); + let sig = repo.signature().unwrap(); + + let parent = repo.head().unwrap().peel_to_commit().unwrap(); + let file_path = repo.workdir().unwrap().join(file); + std::fs::write(&file_path, msg).unwrap(); + let mut index = repo.index().unwrap(); + index.add_path(Path::new(file)).unwrap(); + index.write().unwrap(); + let tree_id = index.write_tree().unwrap(); + let tree = repo.find_tree(tree_id).unwrap(); 
+ repo.commit(Some("HEAD"), &sig, &sig, msg, &tree, &[&parent]) + .unwrap(); + } + + #[tokio::test] + async fn create_pr_validates_branches() { + let (_tmp, path) = create_test_repo(); + let repo = git2::Repository::open(&path).unwrap(); + let head = repo.head().unwrap().peel_to_commit().unwrap(); + repo.branch("feature", &head, false).unwrap(); + + let store = PrStore::open_in_memory().await.unwrap(); + let service = PrService::new(store); + + let pr = service + .create_pr(&path, "My PR", "feature", "master") + .await; + assert!(pr.is_ok() || pr.is_err()); + } + + #[tokio::test] + async fn create_pr_rejects_same_branch() { + let (_tmp, path) = create_test_repo(); + let store = PrStore::open_in_memory().await.unwrap(); + let service = PrService::new(store); + + let result = service + .create_pr(&path, "Bad PR", "main", "main") + .await; + assert!(matches!(result, Err(PrError::SameBranch(_)))); + } + + #[tokio::test] + async fn close_and_reopen_pr() { + let (_tmp, path) = create_test_repo(); + let repo = git2::Repository::open(&path).unwrap(); + let head = repo.head().unwrap().peel_to_commit().unwrap(); + repo.branch("feature", &head, false).unwrap(); + + let store = PrStore::open_in_memory().await.unwrap(); + let service = PrService::new(store); + + let default_branch = { + let branches = repo.branches(Some(git2::BranchType::Local)).unwrap(); + branches + .filter_map(|b| b.ok()) + .find(|(b, _)| b.name().ok().flatten() != Some("feature")) + .map(|(b, _)| b.name().unwrap().unwrap().to_string()) + .unwrap() + }; + + let pr = service + .create_pr(&path, "PR", "feature", &default_branch) + .await + .unwrap(); + + service.close_pr(pr.id).await.unwrap(); + service.reopen_pr(pr.id).await.unwrap(); + } + + #[test] + fn check_branch_health_nonexistent() { + let (_tmp, path) = create_test_repo(); + let store_runtime = tokio::runtime::Runtime::new().unwrap(); + let store = store_runtime.block_on(PrStore::open_in_memory()).unwrap(); + let service = PrService::new(store); + let 
health = service + .check_branch_health(&path, "nonexistent", "abc") + .unwrap(); + assert!(!health.exists); + } + + #[test] + fn compute_remote_hash_no_remotes() { + let (_tmp, path) = create_test_repo(); + let repo = git2::Repository::open(&path).unwrap(); + let hash = compute_remote_urls_hash(&repo); + assert_eq!(hash, "no-remotes"); + } +} diff --git a/crates/cherrypick_pr/src/store.rs b/crates/cherrypick_pr/src/store.rs new file mode 100644 index 00000000000000..5557b7a4427564 --- /dev/null +++ b/crates/cherrypick_pr/src/store.rs @@ -0,0 +1,436 @@ +use chrono::Utc; +use rusqlite::params; +use tokio_rusqlite::Connection; + +use crate::error::{PrError, Result}; +use crate::types::{LocalPr, MergeStrategy, PrSnapshot, PrStatus}; + +const SCHEMA: &str = r#" +CREATE TABLE IF NOT EXISTS repos ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + first_commit_oid TEXT NOT NULL, + remote_urls_hash TEXT NOT NULL, + canonical_path TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + UNIQUE(first_commit_oid, remote_urls_hash) +); + +CREATE TABLE IF NOT EXISTS prs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repo_id INTEGER NOT NULL REFERENCES repos(id), + title TEXT NOT NULL, + source_branch TEXT NOT NULL, + target_branch TEXT NOT NULL, + source_oid TEXT NOT NULL, + target_oid TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'open', + merge_strategy TEXT, + merged_oid TEXT, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE TABLE IF NOT EXISTS pr_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + pr_id INTEGER NOT NULL REFERENCES prs(id) ON DELETE CASCADE, + source_oid TEXT NOT NULL, + target_oid TEXT NOT NULL, + is_force_push INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE INDEX IF NOT EXISTS idx_prs_repo_id ON prs(repo_id); +CREATE INDEX IF NOT EXISTS idx_prs_status ON prs(status); +CREATE INDEX IF NOT EXISTS idx_snapshots_pr_id ON 
pr_snapshots(pr_id); +"#; + +pub struct PrStore { + conn: Connection, +} + +impl PrStore { + pub async fn open(path: &str) -> Result { + let conn = Connection::open(path).await?; + conn.call(|conn| { + conn.execute_batch( + "PRAGMA journal_mode=WAL; + PRAGMA foreign_keys=ON; + PRAGMA busy_timeout=5000;", + )?; + conn.execute_batch(SCHEMA)?; + Ok(()) + }) + .await?; + Ok(Self { conn }) + } + + pub async fn open_in_memory() -> Result { + let conn = Connection::open_in_memory().await?; + conn.call(|conn| { + conn.execute_batch("PRAGMA foreign_keys=ON;")?; + conn.execute_batch(SCHEMA)?; + Ok(()) + }) + .await?; + Ok(Self { conn }) + } + + pub async fn ensure_repo( + &self, + first_commit_oid: &str, + remote_urls_hash: &str, + canonical_path: &str, + ) -> Result { + let fco = first_commit_oid.to_string(); + let ruh = remote_urls_hash.to_string(); + let cp = canonical_path.to_string(); + + self.conn + .call(move |conn| { + conn.execute( + "INSERT OR IGNORE INTO repos (first_commit_oid, remote_urls_hash, canonical_path) + VALUES (?1, ?2, ?3)", + params![fco, ruh, cp], + )?; + let id: i64 = conn.query_row( + "SELECT id FROM repos WHERE first_commit_oid = ?1 AND remote_urls_hash = ?2", + params![fco, ruh], + |row| row.get(0), + )?; + Ok(id) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn create_pr( + &self, + repo_id: i64, + title: &str, + source_branch: &str, + target_branch: &str, + source_oid: &str, + target_oid: &str, + ) -> Result { + let title = title.to_string(); + let source = source_branch.to_string(); + let target = target_branch.to_string(); + let s_oid = source_oid.to_string(); + let t_oid = target_oid.to_string(); + let now = Utc::now().to_rfc3339(); + + self.conn + .call(move |conn| { + conn.execute( + "INSERT INTO prs (repo_id, title, source_branch, target_branch, source_oid, target_oid, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?7)", + params![repo_id, title, source, target, s_oid, t_oid, now], + )?; + 
Ok(conn.last_insert_rowid()) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn get_pr(&self, pr_id: i64) -> Result { + self.conn + .call(move |conn| { + let pr = conn.query_row( + "SELECT id, repo_id, title, source_branch, target_branch, source_oid, target_oid, + status, merge_strategy, merged_oid, created_at, updated_at + FROM prs WHERE id = ?1", + params![pr_id], + |row| { + Ok(LocalPr { + id: row.get(0)?, + repo_id: row.get(1)?, + title: row.get(2)?, + source_branch: row.get(3)?, + target_branch: row.get(4)?, + source_oid: row.get(5)?, + target_oid: row.get(6)?, + status: PrStatus::from_str(&row.get::<_, String>(7)?) + .unwrap_or(PrStatus::Open), + merge_strategy: row + .get::<_, Option>(8)? + .and_then(|s| MergeStrategy::from_str(&s)), + merged_oid: row.get(9)?, + created_at: row + .get::<_, String>(10)? + .parse() + .unwrap_or_else(|_| Utc::now()), + updated_at: row + .get::<_, String>(11)? + .parse() + .unwrap_or_else(|_| Utc::now()), + }) + }, + )?; + Ok(pr) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn list_prs(&self, repo_id: i64, status: Option) -> Result> { + let status_str = status.map(|s| s.as_str().to_string()); + + self.conn + .call(move |conn| { + let mut query = String::from( + "SELECT id, repo_id, title, source_branch, target_branch, source_oid, target_oid, + status, merge_strategy, merged_oid, created_at, updated_at + FROM prs WHERE repo_id = ?1", + ); + if status_str.is_some() { + query.push_str(" AND status = ?2"); + } + query.push_str(" ORDER BY updated_at DESC"); + + let mut stmt = conn.prepare(&query)?; + + let rows = if let Some(ref s) = status_str { + stmt.query_map(params![repo_id, s], map_pr_row)? + } else { + stmt.query_map(params![repo_id], map_pr_row)? 
+ }; + + let mut prs = Vec::new(); + for row in rows { + prs.push(row?); + } + Ok(prs) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn update_pr_status(&self, pr_id: i64, status: PrStatus) -> Result<()> { + let status_str = status.as_str().to_string(); + let now = Utc::now().to_rfc3339(); + + self.conn + .call(move |conn| { + conn.execute( + "UPDATE prs SET status = ?1, updated_at = ?2 WHERE id = ?3", + params![status_str, now, pr_id], + )?; + Ok(()) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn retarget_pr( + &self, + pr_id: i64, + new_target_branch: &str, + new_target_oid: &str, + ) -> Result<()> { + let branch = new_target_branch.to_string(); + let oid = new_target_oid.to_string(); + let now = Utc::now().to_rfc3339(); + + self.conn + .call(move |conn| { + conn.execute( + "UPDATE prs SET target_branch = ?1, target_oid = ?2, updated_at = ?3 WHERE id = ?4", + params![branch, oid, now, pr_id], + )?; + Ok(()) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn update_branch_tips( + &self, + pr_id: i64, + source_oid: &str, + target_oid: &str, + ) -> Result<()> { + let s_oid = source_oid.to_string(); + let t_oid = target_oid.to_string(); + let now = Utc::now().to_rfc3339(); + + self.conn + .call(move |conn| { + conn.execute( + "UPDATE prs SET source_oid = ?1, target_oid = ?2, updated_at = ?3 WHERE id = ?4", + params![s_oid, t_oid, now, pr_id], + )?; + Ok(()) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn record_snapshot( + &self, + pr_id: i64, + source_oid: &str, + target_oid: &str, + is_force_push: bool, + ) -> Result { + let s_oid = source_oid.to_string(); + let t_oid = target_oid.to_string(); + let now = Utc::now().to_rfc3339(); + + self.conn + .call(move |conn| { + conn.execute( + "INSERT INTO pr_snapshots (pr_id, source_oid, target_oid, is_force_push, created_at) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![pr_id, s_oid, t_oid, is_force_push as i32, now], + )?; + 
Ok(conn.last_insert_rowid()) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn list_snapshots(&self, pr_id: i64) -> Result> { + self.conn + .call(move |conn| { + let mut stmt = conn.prepare( + "SELECT id, pr_id, source_oid, target_oid, is_force_push, created_at + FROM pr_snapshots WHERE pr_id = ?1 ORDER BY created_at DESC", + )?; + let rows = stmt.query_map(params![pr_id], |row| { + Ok(PrSnapshot { + id: row.get(0)?, + pr_id: row.get(1)?, + source_oid: row.get(2)?, + target_oid: row.get(3)?, + is_force_push: row.get::<_, i32>(4)? != 0, + created_at: row + .get::<_, String>(5)? + .parse() + .unwrap_or_else(|_| Utc::now()), + }) + })?; + let mut snapshots = Vec::new(); + for row in rows { + snapshots.push(row?); + } + Ok(snapshots) + }) + .await + .map_err(PrError::AsyncDatabase) + } + + pub async fn delete_pr(&self, pr_id: i64) -> Result<()> { + self.conn + .call(move |conn| { + conn.execute("DELETE FROM prs WHERE id = ?1", params![pr_id])?; + Ok(()) + }) + .await + .map_err(PrError::AsyncDatabase) + } +} + +fn map_pr_row(row: &rusqlite::Row) -> rusqlite::Result { + Ok(LocalPr { + id: row.get(0)?, + repo_id: row.get(1)?, + title: row.get(2)?, + source_branch: row.get(3)?, + target_branch: row.get(4)?, + source_oid: row.get(5)?, + target_oid: row.get(6)?, + status: PrStatus::from_str(&row.get::<_, String>(7)?).unwrap_or(PrStatus::Open), + merge_strategy: row + .get::<_, Option>(8)? + .and_then(|s| MergeStrategy::from_str(&s)), + merged_oid: row.get(9)?, + created_at: row + .get::<_, String>(10)? + .parse() + .unwrap_or_else(|_| Utc::now()), + updated_at: row + .get::<_, String>(11)? 
+ .parse() + .unwrap_or_else(|_| Utc::now()), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn create_and_retrieve_pr() { + let store = PrStore::open_in_memory().await.unwrap(); + let repo_id = store.ensure_repo("abc123", "hash1", "/tmp/repo").await.unwrap(); + let pr_id = store + .create_pr(repo_id, "Test PR", "feature", "main", "oid1", "oid2") + .await + .unwrap(); + let pr = store.get_pr(pr_id).await.unwrap(); + assert_eq!(pr.title, "Test PR"); + assert_eq!(pr.source_branch, "feature"); + assert_eq!(pr.target_branch, "main"); + assert_eq!(pr.status, PrStatus::Open); + } + + #[tokio::test] + async fn list_prs_by_status() { + let store = PrStore::open_in_memory().await.unwrap(); + let repo_id = store.ensure_repo("abc", "hash", "/tmp").await.unwrap(); + store.create_pr(repo_id, "PR1", "a", "main", "o1", "o2").await.unwrap(); + let pr2_id = store.create_pr(repo_id, "PR2", "b", "main", "o3", "o4").await.unwrap(); + store.update_pr_status(pr2_id, PrStatus::Closed).await.unwrap(); + + let open = store.list_prs(repo_id, Some(PrStatus::Open)).await.unwrap(); + assert_eq!(open.len(), 1); + assert_eq!(open[0].title, "PR1"); + + let all = store.list_prs(repo_id, None).await.unwrap(); + assert_eq!(all.len(), 2); + } + + #[tokio::test] + async fn snapshot_recording() { + let store = PrStore::open_in_memory().await.unwrap(); + let repo_id = store.ensure_repo("abc", "hash", "/tmp").await.unwrap(); + let pr_id = store.create_pr(repo_id, "PR", "f", "m", "o1", "o2").await.unwrap(); + store.record_snapshot(pr_id, "o1", "o2", false).await.unwrap(); + store.record_snapshot(pr_id, "o3", "o2", true).await.unwrap(); + + let snapshots = store.list_snapshots(pr_id).await.unwrap(); + assert_eq!(snapshots.len(), 2); + assert!(snapshots[0].is_force_push); + } + + #[tokio::test] + async fn delete_pr_cascades_snapshots() { + let store = PrStore::open_in_memory().await.unwrap(); + let repo_id = store.ensure_repo("abc", "hash", "/tmp").await.unwrap(); + let 
pr_id = store.create_pr(repo_id, "PR", "f", "m", "o1", "o2").await.unwrap(); + store.record_snapshot(pr_id, "o1", "o2", false).await.unwrap(); + store.delete_pr(pr_id).await.unwrap(); + + let snapshots = store.list_snapshots(pr_id).await.unwrap(); + assert!(snapshots.is_empty()); + } + + #[tokio::test] + async fn ensure_repo_is_idempotent() { + let store = PrStore::open_in_memory().await.unwrap(); + let id1 = store.ensure_repo("abc", "hash", "/tmp/a").await.unwrap(); + let id2 = store.ensure_repo("abc", "hash", "/tmp/b").await.unwrap(); + assert_eq!(id1, id2); + } + + #[tokio::test] + async fn update_branch_tips() { + let store = PrStore::open_in_memory().await.unwrap(); + let repo_id = store.ensure_repo("abc", "hash", "/tmp").await.unwrap(); + let pr_id = store.create_pr(repo_id, "PR", "f", "m", "old1", "old2").await.unwrap(); + store.update_branch_tips(pr_id, "new1", "new2").await.unwrap(); + let pr = store.get_pr(pr_id).await.unwrap(); + assert_eq!(pr.source_oid, "new1"); + assert_eq!(pr.target_oid, "new2"); + } +} diff --git a/crates/cherrypick_pr/src/types.rs b/crates/cherrypick_pr/src/types.rs new file mode 100644 index 00000000000000..d01efe99d85122 --- /dev/null +++ b/crates/cherrypick_pr/src/types.rs @@ -0,0 +1,227 @@ + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum PrStatus { + Open, + Merged, + Closed, +} + +impl PrStatus { + pub fn as_str(&self) -> &'static str { + match self { + Self::Open => "open", + Self::Merged => "merged", + Self::Closed => "closed", + } + } + + pub fn from_str(s: &str) -> Option { + match s { + "open" => Some(Self::Open), + "merged" => Some(Self::Merged), + "closed" => Some(Self::Closed), + _ => None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum MergeStrategy { + MergeCommit, + Squash, +} + +impl MergeStrategy { + pub fn as_str(&self) -> &'static str { + match self { + 
Self::MergeCommit => "merge", + Self::Squash => "squash", + } + } + + pub fn from_str(s: &str) -> Option<Self> { + match s { + "merge" => Some(Self::MergeCommit), + "squash" => Some(Self::Squash), + _ => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LocalPr { + pub id: i64, + pub repo_id: i64, + pub title: String, + pub source_branch: String, + pub target_branch: String, + pub source_oid: String, + pub target_oid: String, + pub status: PrStatus, + pub merge_strategy: Option<MergeStrategy>, + pub merged_oid: Option<String>, + pub created_at: DateTime<Utc>, + pub updated_at: DateTime<Utc>, +} + +#[derive(Debug, Clone)] +pub struct BranchHealth { + pub exists: bool, + pub force_pushed: bool, + pub ahead: u32, + pub behind: u32, + pub current_oid: Option<String>, +} + +impl Default for BranchHealth { + fn default() -> Self { + Self { + exists: true, + force_pushed: false, + ahead: 0, + behind: 0, + current_oid: None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ContentEncoding { + Utf8, + Latin1, + Binary, +} + +#[derive(Debug, Clone)] +pub struct FileContent { + pub data: Vec<u8>, + pub encoding: ContentEncoding, + pub is_lfs: bool, +} + +impl FileContent { + pub fn as_text(&self) -> Option<String> { + match self.encoding { + ContentEncoding::Utf8 => String::from_utf8(self.data.clone()).ok(), + ContentEncoding::Latin1 => { + Some(self.data.iter().map(|&b| b as char).collect()) + } + ContentEncoding::Binary => None, + } + } + + pub fn detect_encoding(data: &[u8]) -> ContentEncoding { + if data.is_empty() { + return ContentEncoding::Utf8; + } + if std::str::from_utf8(data).is_ok() { + ContentEncoding::Utf8 + } else if data.contains(&0) { + ContentEncoding::Binary + } else { + ContentEncoding::Latin1 + } + } + + pub fn is_lfs_pointer(data: &[u8]) -> bool { + data.starts_with(b"version https://git-lfs.github.com/spec/v1") + } +} + +#[derive(Debug, Clone)] +pub struct ThreeWayContent { + pub base: Option<FileContent>, + pub ours: FileContent, + pub theirs: FileContent, +} +
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PrSnapshot { + pub id: i64, + pub pr_id: i64, + pub source_oid: String, + pub target_oid: String, + pub is_force_push: bool, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RepoRecord { + pub id: i64, + pub first_commit_oid: String, + pub remote_urls_hash: String, + pub canonical_path: String, + pub created_at: DateTime, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pr_status_round_trip() { + for status in [PrStatus::Open, PrStatus::Merged, PrStatus::Closed] { + let s = status.as_str(); + assert_eq!(PrStatus::from_str(s), Some(status)); + } + } + + #[test] + fn merge_strategy_round_trip() { + for strategy in [MergeStrategy::MergeCommit, MergeStrategy::Squash] { + let s = strategy.as_str(); + assert_eq!(MergeStrategy::from_str(s), Some(strategy)); + } + } + + #[test] + fn file_content_encoding_detection() { + assert_eq!( + FileContent::detect_encoding(b"hello world"), + ContentEncoding::Utf8 + ); + assert_eq!( + FileContent::detect_encoding(b"\xff\xfe\x00\x01"), + ContentEncoding::Binary + ); + assert_eq!(FileContent::detect_encoding(b""), ContentEncoding::Utf8); + } + + #[test] + fn lfs_pointer_detection() { + assert!(FileContent::is_lfs_pointer( + b"version https://git-lfs.github.com/spec/v1\noid sha256:abc" + )); + assert!(!FileContent::is_lfs_pointer(b"regular file content")); + } + + #[test] + fn file_content_as_text() { + let utf8 = FileContent { + data: b"hello".to_vec(), + encoding: ContentEncoding::Utf8, + is_lfs: false, + }; + assert_eq!(utf8.as_text(), Some("hello".to_string())); + + let binary = FileContent { + data: vec![0, 1, 2], + encoding: ContentEncoding::Binary, + is_lfs: false, + }; + assert!(binary.as_text().is_none()); + } + + #[test] + fn branch_health_defaults() { + let h = BranchHealth::default(); + assert!(h.exists); + assert!(!h.force_pushed); + assert_eq!(h.ahead, 0); + assert_eq!(h.behind, 0); + } +} diff --git 
a/crates/cherrypick_pr/src/watcher.rs b/crates/cherrypick_pr/src/watcher.rs new file mode 100644 index 00000000000000..ba8a972593fe4c --- /dev/null +++ b/crates/cherrypick_pr/src/watcher.rs @@ -0,0 +1,188 @@ +use std::path::Path; + +use notify::{Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher}; +use tokio::sync::mpsc; + +#[derive(Debug, Clone)] +pub enum BranchEvent { + BranchTipChanged { + branch: String, + old_oid: Option, + new_oid: String, + }, + ForceDetected { + branch: String, + old_oid: String, + new_oid: String, + }, + BranchDeleted { + branch: String, + }, + WatcherError(String), +} + +pub struct BranchWatcher { + _watcher: Option, + receiver: mpsc::UnboundedReceiver, +} + +impl BranchWatcher { + pub fn start(repo_path: &Path) -> Self { + let (tx, rx) = mpsc::unbounded_channel(); + + let git_dir = if repo_path.join(".git").is_dir() { + repo_path.join(".git") + } else { + repo_path.to_path_buf() + }; + + let refs_dir = git_dir.join("refs").join("heads"); + let head_file = git_dir.join("HEAD"); + + let tx_clone = tx.clone(); + let watcher_result = notify::recommended_watcher(move |res: std::result::Result| { + match res { + Ok(event) => { + if matches!( + event.kind, + EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_) + ) { + for path in &event.paths { + if let Some(branch) = extract_branch_name(path, &refs_dir) { + match event.kind { + EventKind::Remove(_) => { + let _ = tx_clone.send(BranchEvent::BranchDeleted { + branch, + }); + } + _ => { + let new_oid = std::fs::read_to_string(path) + .unwrap_or_default() + .trim() + .to_string(); + let _ = tx_clone.send(BranchEvent::BranchTipChanged { + branch, + old_oid: None, + new_oid, + }); + } + } + } + } + } + } + Err(e) => { + let _ = tx_clone.send(BranchEvent::WatcherError(e.to_string())); + } + } + }); + + let watcher = match watcher_result { + Ok(mut w) => { + let refs_path = git_dir.join("refs").join("heads"); + if refs_path.exists() { + let _ = w.watch(&refs_path, 
RecursiveMode::Recursive); + } + let _ = w.watch(&head_file, RecursiveMode::NonRecursive); + Some(w) + } + Err(e) => { + let _ = tx.send(BranchEvent::WatcherError(format!( + "Failed to create watcher: {e}" + ))); + None + } + }; + + Self { + _watcher: watcher, + receiver: rx, + } + } + + pub async fn next_event(&mut self) -> Option { + self.receiver.recv().await + } + + pub fn try_next_event(&mut self) -> Option { + self.receiver.try_recv().ok() + } +} + +fn extract_branch_name(path: &Path, refs_dir: &Path) -> Option { + path.strip_prefix(refs_dir) + .ok() + .map(|rel| rel.to_string_lossy().to_string()) +} + +pub fn detect_force_push( + repo_path: &Path, + old_oid: &str, + new_oid: &str, +) -> bool { + let repo = match git2::Repository::discover(repo_path) { + Ok(r) => r, + Err(_) => return false, + }; + + let old = match git2::Oid::from_str(old_oid) { + Ok(o) => o, + Err(_) => return false, + }; + + let new = match git2::Oid::from_str(new_oid) { + Ok(o) => o, + Err(_) => return false, + }; + + if old == new { + return false; + } + + !repo.graph_descendant_of(new, old).unwrap_or(true) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn extract_branch_name_from_path() { + let refs_dir = PathBuf::from("/repo/.git/refs/heads"); + let path = PathBuf::from("/repo/.git/refs/heads/main"); + assert_eq!( + extract_branch_name(&path, &refs_dir), + Some("main".to_string()) + ); + } + + #[test] + fn extract_nested_branch_name() { + let refs_dir = PathBuf::from("/repo/.git/refs/heads"); + let path = PathBuf::from("/repo/.git/refs/heads/feature/login"); + assert_eq!( + extract_branch_name(&path, &refs_dir), + Some("feature/login".to_string()) + ); + } + + #[test] + fn extract_returns_none_for_unrelated_path() { + let refs_dir = PathBuf::from("/repo/.git/refs/heads"); + let path = PathBuf::from("/other/path"); + assert_eq!(extract_branch_name(&path, &refs_dir), None); + } + + #[test] + fn force_push_detection_on_nonexistent_repo() { + 
assert!(!detect_force_push(Path::new("/nonexistent"), "abc", "def")); + } + + #[test] + fn force_push_same_oid_is_not_force() { + let tmp = tempfile::tempdir().unwrap(); + git2::Repository::init(tmp.path()).unwrap(); + assert!(!detect_force_push(tmp.path(), "abc", "abc")); + } +} diff --git a/crates/cherrypick_ui/Cargo.toml b/crates/cherrypick_ui/Cargo.toml new file mode 100644 index 00000000000000..f64d61222872a0 --- /dev/null +++ b/crates/cherrypick_ui/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "cherrypick_ui" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +name = "cherrypick_ui" +path = "src/lib.rs" + +[dependencies] +gpui.workspace = true +workspace.workspace = true +git.workspace = true +git_ui.workspace = true +project.workspace = true +ui.workspace = true +settings.workspace = true +askpass.workspace = true +anyhow.workspace = true +cherrypick_pr.workspace = true +cherrypick_agent.workspace = true +git_graph.workspace = true +log.workspace = true +sha2.workspace = true +git2.workspace = true +tokio = { workspace = true, features = ["sync"] } + +[dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/cherrypick_ui/src/branch_list.rs b/crates/cherrypick_ui/src/branch_list.rs new file mode 100644 index 00000000000000..92dd493195454c --- /dev/null +++ b/crates/cherrypick_ui/src/branch_list.rs @@ -0,0 +1,177 @@ +use git::repository::Branch; +use gpui::{ + AnyElement, App, Context, Div, Entity, FocusHandle, Focusable, FontWeight, IntoElement, Render, + SharedString, Window, div, px, +}; +use project::git_store::Repository; +use ui::prelude::*; + +pub struct BranchList { + focus_handle: FocusHandle, + repository: Option>, + branches: Vec, + show_local: bool, + show_remote: bool, +} + +impl BranchList { + pub fn new(repository: Option>, cx: &mut Context) -> Self { + Self { + focus_handle: cx.focus_handle(), + repository, + 
branches: Vec::new(), + show_local: true, + show_remote: true, + } + } + + pub fn set_repository(&mut self, repo: Option>, cx: &mut Context) { + self.repository = repo; + cx.notify(); + } + + pub fn set_branches(&mut self, branches: Vec, cx: &mut Context) { + self.branches = branches; + self.branches.sort_by(|a, b| { + b.priority_key() + .partial_cmp(&a.priority_key()) + .unwrap_or(std::cmp::Ordering::Equal) + }); + cx.notify(); + } + + pub fn local_branch_names(&self) -> Vec { + self.branches + .iter() + .filter(|b| !b.is_remote()) + .map(|b| b.name().to_string()) + .collect() + } + + fn local_branches(&self) -> Vec<&Branch> { + self.branches.iter().filter(|b| !b.is_remote()).collect() + } + + fn remote_branches(&self) -> Vec<&Branch> { + self.branches.iter().filter(|b| b.is_remote()).collect() + } + + fn checkout_branch(&self, branch_name: String, cx: &mut Context) { + let Some(repo) = self.repository.clone() else { + return; + }; + let rx = repo.update(cx, |repo, _cx| repo.change_branch(branch_name)); + cx.spawn(async move |_this, _cx| { + let _ = rx.await; + }) + .detach(); + } + + fn render_branch_item( + &self, + branch: &Branch, + cx: &mut Context, + ) -> AnyElement { + let name = branch.name().to_string(); + let is_head = branch.is_head; + let tracking = branch.tracking_status(); + let checkout_name = name.clone(); + + let mut row = div() + .id(SharedString::from(format!("branch-{}", &name))) + .flex() + .items_center() + .gap_1() + .px_2() + .py(px(3.0)) + .rounded_sm() + .cursor_pointer() + .hover(|style| style.bg(cx.theme().colors().ghost_element_hover)) + .active(|style| style.bg(cx.theme().colors().ghost_element_active)) + .on_click(cx.listener(move |this, _event, _window, cx| { + this.checkout_branch(checkout_name.clone(), cx); + })); + + if is_head { + row = row.child( + div() + .text_xs() + .text_color(cx.theme().colors().text_accent) + .child("●"), + ); + } + + let mut name_el = div() + .flex_grow() + .text_sm() + .overflow_x_hidden() + 
.whitespace_nowrap(); + if is_head { + name_el = name_el.font_weight(FontWeight::SEMIBOLD); + } + name_el = name_el.child(name); + row = row.child(name_el); + + if let Some(status) = tracking { + row = row.child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(format!("↑{} ↓{}", status.ahead, status.behind)), + ); + } + + row.into_any_element() + } + + fn render_section_items( + &self, + label: &str, + branches: &[&Branch], + expanded: bool, + cx: &mut Context, + ) -> Div { + if !expanded || branches.is_empty() { + return div(); + } + + let header = div() + .text_xs() + .font_weight(FontWeight::SEMIBOLD) + .text_color(cx.theme().colors().text_muted) + .px_2() + .py_1() + .child(format!("{} ({})", label, branches.len())); + + let mut container = div().flex().flex_col().child(header); + for branch in branches { + container = container.child(self.render_branch_item(branch, cx)); + } + container + } +} + +impl Focusable for BranchList { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for BranchList { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let local = self.local_branches(); + let remote = self.remote_branches(); + + let mut root = div() + .id("branch-list") + .flex() + .flex_col() + .size_full() + .overflow_y_scroll() + .py_1(); + + root = root.child(self.render_section_items("Local", &local, self.show_local, cx)); + root = root.child(self.render_section_items("Remote", &remote, self.show_remote, cx)); + root + } +} diff --git a/crates/cherrypick_ui/src/cherrypick_sidebar.rs b/crates/cherrypick_ui/src/cherrypick_sidebar.rs new file mode 100644 index 00000000000000..3bc09dda232e7f --- /dev/null +++ b/crates/cherrypick_ui/src/cherrypick_sidebar.rs @@ -0,0 +1,445 @@ +use cherrypick_pr::LocalPr; +use gpui::{ + Action, App, Context, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, IntoElement, + Pixels, Render, SharedString, 
Subscription, WeakEntity, Window, actions, div, px, +}; +use project::Project; +use project::git_store::{GitStoreEvent, Repository, RepositoryEvent}; +use ui::prelude::*; +use workspace::{ + Workspace, + dock::{DockPosition, Panel, PanelEvent}, +}; + +use crate::branch_list::BranchList; +use crate::cherrypick_view::CherryPickView; +use crate::commit_graph_embed; +use crate::create_pr_form::{CreatePrForm, CreatePrFormEvent}; +use crate::pr_list::{PrList, PrListEvent}; +use crate::pr_state::PrState; + +actions!(cherrypick, [ToggleSidebar, OpenCherryPick]); + +const CHERRYPICK_SIDEBAR_KEY: &str = "CherryPickSidebar"; +const DEFAULT_WIDTH: f32 = 260.0; + +pub struct CherryPickSidebar { + focus_handle: FocusHandle, + workspace: WeakEntity, + project: Entity, + active_repository: Option>, + branch_list: Entity, + pr_list: Entity, + create_pr_form: Entity, + pr_state: PrState, + width: Option, + _subscriptions: Vec, +} + +impl CherryPickSidebar { + pub async fn load( + workspace: gpui::WeakEntity, + mut cx: gpui::AsyncWindowContext, + ) -> anyhow::Result> { + workspace.update_in(&mut cx, |workspace, window, cx| { + CherryPickSidebar::new(workspace, window, cx) + }) + } + + pub fn new( + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context, + ) -> Entity { + let project = workspace.project().clone(); + let git_store = project.read(cx).git_store().clone(); + let active_repository = project.read(cx).active_repository(cx); + + let weak_workspace = workspace.weak_handle(); + + cx.new(|cx| { + let branch_list = cx.new(|cx| BranchList::new(active_repository.clone(), cx)); + let pr_list = cx.new(|cx| PrList::new(cx)); + let create_pr_form = cx.new(|cx| CreatePrForm::new(cx)); + + let mut pr_state = PrState::new(cx); + if let Some(repo) = &active_repository { + let path = repo.read(cx).work_directory_abs_path.clone(); + pr_state.set_repo_path(path.to_path_buf()); + pr_state.initialize().detach(); + } + + let git_sub = cx.subscribe_in( + &git_store, + window, + 
|this: &mut Self, _git_store, event, window, cx| match event { + GitStoreEvent::ActiveRepositoryChanged(_) => { + this.update_active_repository(window, cx); + } + GitStoreEvent::RepositoryUpdated( + _, + RepositoryEvent::BranchChanged | RepositoryEvent::StatusesChanged, + true, + ) => { + this.refresh_branches(window, cx); + this.refresh_prs(cx); + } + GitStoreEvent::RepositoryAdded | GitStoreEvent::RepositoryRemoved(_) => { + this.update_active_repository(window, cx); + } + _ => {} + }, + ); + + let form_sub = cx.subscribe( + &create_pr_form, + |this: &mut Self, _form, event: &CreatePrFormEvent, cx| match event { + CreatePrFormEvent::Submit { + title, + source_branch, + target_branch, + } => { + this.handle_create_pr( + title.clone(), + source_branch.clone(), + target_branch.clone(), + cx, + ); + } + }, + ); + + let pr_list_sub = cx.subscribe_in( + &pr_list, + window, + |this: &mut Self, _list, event: &PrListEvent, window, cx| match event { + PrListEvent::Selected(pr) => { + this.open_pr_view(pr.clone(), window, cx); + } + }, + ); + + let mut sidebar = Self { + focus_handle: cx.focus_handle(), + workspace: weak_workspace, + project, + active_repository, + branch_list, + pr_list, + create_pr_form, + pr_state, + width: None, + _subscriptions: vec![git_sub, form_sub, pr_list_sub], + }; + + sidebar.refresh_prs(cx); + sidebar + }) + } + + fn update_active_repository(&mut self, window: &mut Window, cx: &mut Context) { + self.active_repository = self.project.read(cx).active_repository(cx); + self.branch_list.update(cx, |list, cx| { + list.set_repository(self.active_repository.clone(), cx); + }); + + if let Some(repo) = &self.active_repository { + let path = repo.read(cx).work_directory_abs_path.clone(); + self.pr_state.set_repo_path(path.to_path_buf()); + self.pr_state.initialize().detach(); + } + + self.refresh_branches(window, cx); + self.refresh_prs(cx); + } + + fn refresh_branches(&mut self, _window: &mut Window, cx: &mut Context) { + let Some(repo) = 
self.active_repository.clone() else { + self.branch_list.update(cx, |list, cx| { + list.set_branches(Vec::new(), cx); + }); + return; + }; + + let rx = repo.update(cx, |repo, _cx| repo.branches()); + let branch_list = self.branch_list.clone(); + let create_form = self.create_pr_form.clone(); + let current_branch = self.current_branch_name(cx); + cx.spawn(async move |_this, cx| { + if let Ok(Ok(branches)) = rx.await { + let mut names = Vec::new(); + let _ = branch_list.update(cx, |list, cx| { + list.set_branches(branches, cx); + names = list.local_branch_names(); + }); + if !names.is_empty() { + let _ = create_form.update(cx, |form, cx| { + form.set_branches(names, current_branch.as_deref(), cx); + }); + } + } + }) + .detach(); + } + + fn refresh_prs(&mut self, cx: &mut Context) { + if !self.pr_state.is_initialized() { + return; + } + + let task = self.pr_state.list_open_prs(); + let pr_list = self.pr_list.clone(); + cx.spawn(async move |_this, cx| { + if let Ok(prs) = task.await { + let _ = pr_list.update(cx, |list, cx| { + list.set_prs(prs, cx); + }); + } + }) + .detach(); + } + + fn toggle_create_form(&mut self, cx: &mut Context) { + self.create_pr_form.update(cx, |form, cx| { + form.toggle_visible(cx); + }); + } + + fn handle_create_pr(&mut self, title: String, source: String, target: String, cx: &mut Context) { + log::info!("cherrypick: handle_create_pr called: '{}' {} → {}", title, source, target); + + if !self.pr_state.is_initialized() { + log::error!("cherrypick: PR state not initialized, attempting init"); + self.pr_state.initialize().detach(); + return; + } + + let task = self.pr_state.create_pr(title.clone(), source.clone(), target.clone()); + let pr_list = self.pr_list.clone(); + let create_form = self.create_pr_form.clone(); + + cx.spawn(async move |this, cx| { + log::info!("cherrypick: creating PR async..."); + match task.await { + Ok(pr) => { + log::info!("cherrypick: PR created successfully: id={}, '{}'", pr.id, pr.title); + let _ = 
create_form.update(cx, |form, cx| { + form.set_visible(false, cx); + }); + let _ = this.update(cx, |this, cx| { + this.refresh_prs(cx); + }); + } + Err(e) => { + log::error!("cherrypick: failed to create PR: {}", e); + } + } + }) + .detach(); + } + + fn open_pr_view(&mut self, pr: LocalPr, window: &mut Window, cx: &mut Context) { + log::info!("cherrypick: opening PR view for: {}", pr.title); + let workspace = self.workspace.clone(); + let _ = workspace.update(cx, |workspace, cx| { + CherryPickView::deploy_with_pr(workspace, pr, window, cx); + }); + } + + fn open_cherrypick_view(&self, window: &mut Window, cx: &mut Context) { + let workspace = self.workspace.clone(); + let _ = workspace.update(cx, |workspace, cx| { + CherryPickView::deploy(workspace, window, cx); + }); + } + + fn open_commit_graph(&self, window: &mut Window, cx: &mut Context) { + let workspace = self.workspace.clone(); + let _ = workspace.update(cx, |workspace, cx| { + commit_graph_embed::open_git_graph(workspace, window, cx); + }); + } + + fn current_branch_name(&self, cx: &App) -> Option { + self.active_repository.as_ref().and_then(|repo| { + repo.read(cx).branch.as_ref().map(|b| b.name().to_string()) + }) + } +} + +pub fn register(workspace: &mut Workspace) { + workspace.register_action(|workspace, _: &ToggleSidebar, window, cx| { + workspace.toggle_panel_focus::(window, cx); + }); + workspace.register_action(|workspace, _: &OpenCherryPick, window, cx| { + CherryPickView::deploy(workspace, window, cx); + }); +} + +impl Focusable for CherryPickSidebar { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for CherryPickSidebar {} + +impl Render for CherryPickSidebar { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let current_branch = self.current_branch_name(cx); + + div() + .flex() + .flex_col() + .key_context("CherryPickSidebar") + .track_focus(&self.focus_handle) + .size_full() + .child( + div() + 
.flex() + .items_center() + .justify_between() + .px_2() + .py_1() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + div() + .text_sm() + .font_weight(FontWeight::SEMIBOLD) + .child("CherryPick"), + ) + .child( + div() + .flex() + .items_center() + .gap_1() + .child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child( + current_branch + .unwrap_or_else(|| "no branch".to_string()), + ), + ) + .child( + div() + .id("open-graph-btn") + .text_xs() + .cursor_pointer() + .px_1() + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| { + s.bg(cx.theme().colors().ghost_element_hover) + }) + .on_click(cx.listener(|this, _event, window, cx| { + this.open_commit_graph(window, cx); + })) + .child("Graph"), + ) + .child( + div() + .id("open-view-btn") + .text_xs() + .cursor_pointer() + .px_1() + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| { + s.bg(cx.theme().colors().ghost_element_hover) + }) + .on_click(cx.listener(|this, _event, window, cx| { + this.open_cherrypick_view(window, cx); + })) + .child("Open"), + ), + ), + ) + .child(self.create_pr_form.clone()) + .child(self.branch_list.clone()) + .child( + div() + .border_t_1() + .border_color(cx.theme().colors().border) + .child( + div() + .flex() + .items_center() + .justify_between() + .px_2() + .py_1() + .child(self.pr_list.clone()) + .child( + div() + .id("new-pr-btn") + .text_xs() + .cursor_pointer() + .px_1() + .rounded_sm() + .text_color(cx.theme().colors().text_accent) + .hover(|s| { + s.bg(cx.theme().colors().ghost_element_hover) + }) + .on_click(cx.listener(|this, _event, _window, cx| { + this.toggle_create_form(cx); + })) + .child("+ New PR"), + ), + ), + ) + } +} + +impl Panel for CherryPickSidebar { + fn persistent_name() -> &'static str { + CHERRYPICK_SIDEBAR_KEY + } + + fn panel_key() -> &'static str { + CHERRYPICK_SIDEBAR_KEY + } + + fn position(&self, _window: &Window, _cx: &App) -> DockPosition { + 
DockPosition::Right + } + + fn position_is_valid(&self, position: DockPosition) -> bool { + matches!(position, DockPosition::Left | DockPosition::Right) + } + + fn set_position( + &mut self, + _position: DockPosition, + _window: &mut Window, + cx: &mut Context, + ) { + cx.notify(); + } + + fn size(&self, _window: &Window, _cx: &App) -> Pixels { + self.width.unwrap_or(px(DEFAULT_WIDTH)) + } + + fn set_size(&mut self, size: Option, _window: &mut Window, cx: &mut Context) { + self.width = size; + cx.notify(); + } + + fn icon(&self, _window: &Window, _cx: &App) -> Option { + Some(ui::IconName::GitBranchAlt) + } + + fn icon_tooltip(&self, _window: &Window, _cx: &App) -> Option<&'static str> { + Some("CherryPick Sidebar") + } + + fn toggle_action(&self) -> Box { + Box::new(ToggleSidebar) + } + + fn activation_priority(&self) -> u32 { + 3 + } +} diff --git a/crates/cherrypick_ui/src/cherrypick_view.rs b/crates/cherrypick_ui/src/cherrypick_view.rs new file mode 100644 index 00000000000000..1dcc0589125c92 --- /dev/null +++ b/crates/cherrypick_ui/src/cherrypick_view.rs @@ -0,0 +1,487 @@ +use cherrypick_pr::diff_service::BranchDiff; +use cherrypick_pr::LocalPr; +use gpui::{ + App, Context, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, IntoElement, + SharedString, Subscription, WeakEntity, Window, div, +}; +use project::Project; +use project::git_store::{GitStoreEvent, Repository, RepositoryEvent}; +use ui::prelude::*; +use workspace::{Item, Workspace, item::ItemEvent}; + +use crate::diff_file_list::DiffFileList; +use crate::pr_state::PrState; +use crate::staging_view::StagingView; + +pub enum ViewMode { + Staging, + PrDetail(LocalPr), + BranchCompare { + source_branch: String, + target_branch: String, + source_oid: String, + target_oid: String, + }, +} + +pub enum CherryPickViewEvent { + UpdateTab, +} + +pub struct CherryPickView { + focus_handle: FocusHandle, + workspace: WeakEntity, + project: Entity, + active_repository: Option>, + mode: ViewMode, + 
staging_view: Entity, + diff_file_list: Entity, + unified_diff_text: Option, + pr_state: PrState, + _subscriptions: Vec, +} + +impl CherryPickView { + pub fn new( + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let project = workspace.project().clone(); + let git_store = project.read(cx).git_store().clone(); + let active_repository = project.read(cx).active_repository(cx); + + let staging_view = cx.new(|cx| StagingView::new(active_repository.clone(), cx)); + let diff_file_list = cx.new(|cx| DiffFileList::new(cx)); + + let mut pr_state = PrState::new(cx); + if let Some(repo) = &active_repository { + let path = repo.read(cx).work_directory_abs_path.clone(); + pr_state.set_repo_path(path.to_path_buf()); + pr_state.initialize().detach(); + } + + let git_sub = cx.subscribe_in( + &git_store, + window, + |this: &mut Self, _git_store, event, _window, cx| match event { + GitStoreEvent::ActiveRepositoryChanged(_) => { + this.update_active_repository(cx); + } + GitStoreEvent::RepositoryUpdated( + _, + RepositoryEvent::StatusesChanged | RepositoryEvent::BranchChanged, + true, + ) => { + this.refresh(cx); + } + _ => {} + }, + ); + + Self { + focus_handle: cx.focus_handle(), + workspace: workspace.weak_handle(), + project, + active_repository, + mode: ViewMode::Staging, + staging_view, + diff_file_list, + unified_diff_text: None, + pr_state, + _subscriptions: vec![git_sub], + } + } + + pub fn deploy( + workspace: &mut Workspace, + window: &mut Window, + cx: &mut Context, + ) { + let existing = workspace + .active_pane() + .read(cx) + .items() + .find_map(|item| item.downcast::()); + + if let Some(existing) = existing { + workspace.activate_item(&existing, true, true, window, cx); + } else { + let view = cx.new(|cx| CherryPickView::new(workspace, window, cx)); + workspace.add_item_to_active_pane(Box::new(view), None, true, window, cx); + } + } + + pub fn deploy_with_pr( + workspace: &mut Workspace, + pr: LocalPr, + window: &mut Window, + 
cx: &mut Context, + ) { + let view = cx.new(|cx| { + let mut v = CherryPickView::new(workspace, window, cx); + v.show_pr(pr, cx); + v + }); + workspace.add_item_to_active_pane(Box::new(view), None, true, window, cx); + } + + pub fn deploy_compare( + workspace: &mut Workspace, + source_branch: String, + target_branch: String, + source_oid: String, + target_oid: String, + window: &mut Window, + cx: &mut Context, + ) { + let view = cx.new(|cx| { + let mut v = CherryPickView::new(workspace, window, cx); + v.show_compare(source_branch, target_branch, source_oid, target_oid, cx); + v + }); + workspace.add_item_to_active_pane(Box::new(view), None, true, window, cx); + } + + fn show_pr(&mut self, pr: LocalPr, cx: &mut Context) { + let source_oid = pr.source_oid.clone(); + let target_oid = pr.target_oid.clone(); + self.mode = ViewMode::PrDetail(pr); + self.load_diff(source_oid, target_oid, cx); + cx.emit(CherryPickViewEvent::UpdateTab); + cx.notify(); + } + + fn show_compare( + &mut self, + source_branch: String, + target_branch: String, + source_oid: String, + target_oid: String, + cx: &mut Context, + ) { + self.mode = ViewMode::BranchCompare { + source_branch, + target_branch, + source_oid: source_oid.clone(), + target_oid: target_oid.clone(), + }; + self.load_diff(source_oid, target_oid, cx); + cx.emit(CherryPickViewEvent::UpdateTab); + cx.notify(); + } + + fn load_diff(&mut self, source_oid: String, target_oid: String, cx: &mut Context) { + let file_task = self.pr_state.get_branch_diff(source_oid.clone(), target_oid.clone()); + let text_task = self.pr_state.get_unified_diff(source_oid, target_oid); + let diff_list = self.diff_file_list.clone(); + + cx.spawn(async move |this, cx| { + if let Ok(diff) = file_task.await { + let _ = diff_list.update(cx, |list, cx| { + list.set_diff(Some(diff), cx); + }); + } + if let Ok(text) = text_task.await { + let _ = this.update(cx, |this, cx| { + this.unified_diff_text = Some(text); + cx.notify(); + }); + } + }) + .detach(); + } + + fn 
update_active_repository(&mut self, cx: &mut Context) { + self.active_repository = self.project.read(cx).active_repository(cx); + self.staging_view.update(cx, |view, cx| { + view.set_repository(self.active_repository.clone(), cx); + }); + + if let Some(repo) = &self.active_repository { + let path = repo.read(cx).work_directory_abs_path.clone(); + self.pr_state.set_repo_path(path.to_path_buf()); + self.pr_state.initialize().detach(); + } + cx.notify(); + } + + fn refresh(&mut self, cx: &mut Context) { + self.staging_view.update(cx, |view, cx| { + view.set_repository(self.active_repository.clone(), cx); + }); + cx.notify(); + } + + fn tab_title(&self) -> String { + match &self.mode { + ViewMode::Staging => "CherryPick".to_string(), + ViewMode::PrDetail(pr) => format!("PR: {}", pr.title), + ViewMode::BranchCompare { + source_branch, + target_branch, + .. + } => format!("{} → {}", source_branch, target_branch), + } + } + + fn render_header(&self, cx: &App) -> gpui::Div { + match &self.mode { + ViewMode::Staging => self.render_staging_header(cx), + ViewMode::PrDetail(pr) => self.render_pr_header(pr, cx), + ViewMode::BranchCompare { + source_branch, + target_branch, + .. 
+ } => self.render_compare_header(source_branch, target_branch, cx), + } + } + + fn render_staging_header(&self, cx: &App) -> gpui::Div { + let Some(repo) = self.active_repository.as_ref() else { + return div() + .p_2() + .text_color(cx.theme().colors().text_muted) + .child("No repository"); + }; + + let snapshot = repo.read(cx); + let branch_name = snapshot + .branch + .as_ref() + .map(|b| b.name().to_string()) + .unwrap_or_else(|| "detached HEAD".to_string()); + + let head_info = snapshot + .head_commit + .as_ref() + .map(|c| { + format!( + "{} {}", + &c.sha[..7.min(c.sha.len())], + c.message.lines().next().unwrap_or("") + ) + }) + .unwrap_or_else(|| "no commits".to_string()); + + div() + .flex() + .flex_col() + .gap_1() + .p_2() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + div() + .text_sm() + .font_weight(FontWeight::SEMIBOLD) + .child(branch_name), + ) + .child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(head_info), + ) + } + + fn render_pr_header(&self, pr: &LocalPr, cx: &App) -> gpui::Div { + let status_label = pr.status.as_str(); + let status_color = match pr.status { + cherrypick_pr::PrStatus::Open => cx.theme().colors().version_control_added, + cherrypick_pr::PrStatus::Merged => cx.theme().colors().version_control_modified, + cherrypick_pr::PrStatus::Closed => cx.theme().colors().version_control_deleted, + }; + + div() + .flex() + .flex_col() + .gap_1() + .p_2() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + div() + .flex() + .items_center() + .gap_2() + .child( + div() + .text_sm() + .font_weight(FontWeight::SEMIBOLD) + .child(pr.title.clone()), + ) + .child( + div() + .text_xs() + .px_1() + .rounded_sm() + .text_color(status_color) + .child(status_label), + ), + ) + .child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(format!( + "{} → {}", + pr.source_branch, pr.target_branch + )), + ) + } + + fn render_compare_header( + &self, + source: 
&str, + target: &str, + cx: &App, + ) -> gpui::Div { + div() + .flex() + .items_center() + .gap_2() + .p_2() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + div() + .text_sm() + .font_weight(FontWeight::SEMIBOLD) + .child("Comparing"), + ) + .child( + div() + .text_sm() + .text_color(cx.theme().colors().text_accent) + .child(source.to_string()), + ) + .child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child("→"), + ) + .child( + div() + .text_sm() + .text_color(cx.theme().colors().text_accent) + .child(target.to_string()), + ) + } + + fn render_unified_diff(&self, cx: &mut Context) -> gpui::Stateful { + let Some(diff_text) = &self.unified_diff_text else { + return div() + .id("unified-diff-empty") + .p_2() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child("Loading diff..."); + }; + + let added_color = cx.theme().colors().version_control_added; + let deleted_color = cx.theme().colors().version_control_deleted; + let muted = cx.theme().colors().text_muted; + let text_color = cx.theme().colors().text; + + let mut container = div() + .id("unified-diff") + .flex() + .flex_col() + .flex_grow() + .overflow_y_scroll() + .p_1() + .text_xs() + .font_family("monospace"); + + for line in diff_text.lines() { + let color = if line.starts_with('+') && !line.starts_with("+++") { + added_color + } else if line.starts_with('-') && !line.starts_with("---") { + deleted_color + } else if line.starts_with("@@") { + muted + } else if line.starts_with("diff ") || line.starts_with("index ") { + muted + } else { + text_color + }; + + let bg = if line.starts_with('+') && !line.starts_with("+++") { + Some(cx.theme().colors().version_control_added.opacity(0.1)) + } else if line.starts_with('-') && !line.starts_with("---") { + Some(cx.theme().colors().version_control_deleted.opacity(0.1)) + } else { + None + }; + + let mut line_div = div() + .text_color(color) + .px_1() + .child(line.to_string()); + + if let Some(bg_color) = 
bg { + line_div = line_div.bg(bg_color); + } + + container = container.child(line_div); + } + + container + } +} + +impl Focusable for CherryPickView { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl EventEmitter for CherryPickView {} + +impl Render for CherryPickView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let header = self.render_header(cx); + + let mut root = div() + .key_context("CherryPickView") + .track_focus(&self.focus_handle) + .size_full() + .flex() + .flex_col() + .child(header); + + match &self.mode { + ViewMode::Staging => { + root = root.child( + div().flex_grow().child(self.staging_view.clone()), + ); + } + ViewMode::PrDetail(_) | ViewMode::BranchCompare { .. } => { + root = root.child(self.diff_file_list.clone()); + root = root.child(self.render_unified_diff(cx)); + } + } + + root + } +} + +impl Item for CherryPickView { + type Event = CherryPickViewEvent; + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + SharedString::from(self.tab_title()) + } + + fn tab_tooltip_text(&self, _cx: &App) -> Option { + Some("CherryPick Git Client".into()) + } + + fn to_item_events(event: &Self::Event, f: &mut dyn FnMut(ItemEvent)) { + match event { + CherryPickViewEvent::UpdateTab => f(ItemEvent::UpdateTab), + } + } +} diff --git a/crates/cherrypick_ui/src/commit_graph_embed.rs b/crates/cherrypick_ui/src/commit_graph_embed.rs new file mode 100644 index 00000000000000..2f9d2a14ad2260 --- /dev/null +++ b/crates/cherrypick_ui/src/commit_graph_embed.rs @@ -0,0 +1,21 @@ +use git_graph::GitGraph; +use gpui::AppContext as _; +use workspace::Workspace; + +pub fn open_git_graph(workspace: &mut Workspace, window: &mut gpui::Window, cx: &mut gpui::Context) { + let existing = workspace.items_of_type::(cx).next(); + if let Some(existing) = existing { + workspace.activate_item(&existing, true, true, window, cx); + return; + } + + let project = 
workspace.project().clone(); + if project.read(cx).active_repository(cx).is_none() { + log::warn!("cherrypick: no active repository for git graph"); + return; + } + + let workspace_handle = workspace.weak_handle(); + let git_graph = cx.new(|cx| GitGraph::new(project, workspace_handle, window, cx)); + workspace.add_item_to_active_pane(Box::new(git_graph), None, true, window, cx); +} diff --git a/crates/cherrypick_ui/src/create_pr_form.rs b/crates/cherrypick_ui/src/create_pr_form.rs new file mode 100644 index 00000000000000..a4a7ac5bcf7078 --- /dev/null +++ b/crates/cherrypick_ui/src/create_pr_form.rs @@ -0,0 +1,258 @@ +use gpui::{ + App, Context, EventEmitter, FocusHandle, Focusable, FontWeight, IntoElement, Render, + SharedString, Window, div, px, +}; +use ui::prelude::*; + +pub enum CreatePrFormEvent { + Submit { + title: String, + source_branch: String, + target_branch: String, + }, +} + +pub struct CreatePrForm { + focus_handle: FocusHandle, + visible: bool, + branch_names: Vec, + selected_source: usize, + selected_target: usize, + title: String, +} + +impl CreatePrForm { + pub fn new(cx: &mut Context) -> Self { + Self { + focus_handle: cx.focus_handle(), + visible: false, + branch_names: Vec::new(), + selected_source: 0, + selected_target: 0, + title: String::new(), + } + } + + pub fn toggle_visible(&mut self, cx: &mut Context) { + self.visible = !self.visible; + cx.notify(); + } + + pub fn set_visible(&mut self, visible: bool, cx: &mut Context) { + self.visible = visible; + cx.notify(); + } + + pub fn set_branches(&mut self, names: Vec, current_branch: Option<&str>, cx: &mut Context) { + self.branch_names = names; + if self.branch_names.is_empty() { + cx.notify(); + return; + } + + self.selected_source = 0; + self.selected_target = 0; + + if let Some(current) = current_branch { + if let Some(idx) = self.branch_names.iter().position(|b| b == current) { + self.selected_source = idx; + } + } + + if let Some(idx) = self.branch_names.iter().position(|b| b == 
"main") { + self.selected_target = idx; + } else if let Some(idx) = self.branch_names.iter().position(|b| b == "master") { + self.selected_target = idx; + } + + if self.selected_source == self.selected_target && self.branch_names.len() > 1 { + for (i, name) in self.branch_names.iter().enumerate() { + if i != self.selected_source { + self.selected_target = i; + break; + } + } + } + + self.title = self.branch_names.get(self.selected_source) + .cloned() + .unwrap_or_default(); + + cx.notify(); + } + + fn source_name(&self) -> &str { + self.branch_names + .get(self.selected_source) + .map(|s| s.as_str()) + .unwrap_or("(none)") + } + + fn target_name(&self) -> &str { + self.branch_names + .get(self.selected_target) + .map(|s| s.as_str()) + .unwrap_or("(none)") + } + + fn cycle_source(&mut self, cx: &mut Context) { + if !self.branch_names.is_empty() { + self.selected_source = (self.selected_source + 1) % self.branch_names.len(); + cx.notify(); + } + } + + fn cycle_target(&mut self, cx: &mut Context) { + if !self.branch_names.is_empty() { + self.selected_target = (self.selected_target + 1) % self.branch_names.len(); + cx.notify(); + } + } + + fn submit(&mut self, cx: &mut Context) { + let Some(source) = self.branch_names.get(self.selected_source).cloned() else { + log::warn!("cherrypick: no source branch selected"); + return; + }; + let Some(target) = self.branch_names.get(self.selected_target).cloned() else { + log::warn!("cherrypick: no target branch selected"); + return; + }; + if source == target { + log::warn!("cherrypick: source and target are the same: {}", source); + return; + } + let title = if self.title.is_empty() { + source.clone() + } else { + self.title.clone() + }; + log::info!("cherrypick: submitting PR '{}': {} → {}", title, source, target); + cx.emit(CreatePrFormEvent::Submit { + title, + source_branch: source, + target_branch: target, + }); + } +} + +impl EventEmitter for CreatePrForm {} + +impl Focusable for CreatePrForm { + fn focus_handle(&self, _cx: 
&App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for CreatePrForm { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + if !self.visible { + return div(); + } + + let source = self.source_name().to_string(); + let target = self.target_name().to_string(); + + div() + .flex() + .flex_col() + .gap_1() + .px_2() + .py_2() + .bg(cx.theme().colors().surface_background) + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + div() + .text_xs() + .font_weight(FontWeight::SEMIBOLD) + .text_color(cx.theme().colors().text_muted) + .child("Create Local PR"), + ) + .child( + div() + .flex() + .items_center() + .gap_1() + .child( + div().text_xs().text_color(cx.theme().colors().text_muted).child("From:"), + ) + .child( + div() + .id("source-branch-picker") + .text_xs() + .cursor_pointer() + .px_1() + .py(px(2.0)) + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, _window, cx| { + this.cycle_source(cx); + })) + .child(source), + ), + ) + .child( + div() + .flex() + .items_center() + .gap_1() + .child( + div().text_xs().text_color(cx.theme().colors().text_muted).child("Into:"), + ) + .child( + div() + .id("target-branch-picker") + .text_xs() + .cursor_pointer() + .px_1() + .py(px(2.0)) + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, _window, cx| { + this.cycle_target(cx); + })) + .child(target), + ), + ) + .child( + div() + .flex() + .items_center() + .gap_1() + .pt_1() + .child( + div() + .id("create-pr-submit") + .text_xs() + .cursor_pointer() + .px_2() + .py(px(3.0)) + .rounded_sm() + .bg(cx.theme().colors().element_background) + .hover(|s| s.bg(cx.theme().colors().element_hover)) + .on_click(cx.listener(|this, _event, _window, cx| { + this.submit(cx); + })) + 
.child("Create PR"), + ) + .child( + div() + .id("create-pr-cancel") + .text_xs() + .cursor_pointer() + .px_2() + .py(px(3.0)) + .rounded_sm() + .text_color(cx.theme().colors().text_muted) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, _window, cx| { + this.set_visible(false, cx); + })) + .child("Cancel"), + ), + ) + } +} diff --git a/crates/cherrypick_ui/src/diff_file_list.rs b/crates/cherrypick_ui/src/diff_file_list.rs new file mode 100644 index 00000000000000..444c474d73d6ac --- /dev/null +++ b/crates/cherrypick_ui/src/diff_file_list.rs @@ -0,0 +1,166 @@ +use cherrypick_pr::diff_service::{BranchDiff, DiffFileEntry}; +use gpui::{ + AnyElement, App, Context, FocusHandle, Focusable, FontWeight, IntoElement, Render, SharedString, + Window, div, px, +}; +use ui::prelude::*; + +pub struct DiffFileList { + focus_handle: FocusHandle, + diff: Option, +} + +impl DiffFileList { + pub fn new(cx: &mut Context) -> Self { + Self { + focus_handle: cx.focus_handle(), + diff: None, + } + } + + pub fn set_diff(&mut self, diff: Option, cx: &mut Context) { + self.diff = diff; + cx.notify(); + } + + fn render_summary(&self, diff: &BranchDiff, cx: &mut Context) -> gpui::Div { + div() + .flex() + .items_center() + .gap_2() + .px_2() + .py_1() + .border_b_1() + .border_color(cx.theme().colors().border) + .child( + div() + .text_xs() + .font_weight(FontWeight::SEMIBOLD) + .child(format!("{} files changed", diff.files.len())), + ) + .child( + div() + .text_xs() + .text_color(cx.theme().colors().version_control_added) + .child(format!("+{}", diff.total_insertions)), + ) + .child( + div() + .text_xs() + .text_color(cx.theme().colors().version_control_deleted) + .child(format!("-{}", diff.total_deletions)), + ) + } + + fn render_file_entry(&self, entry: &DiffFileEntry, cx: &mut Context) -> AnyElement { + let status_color = match entry.status { + 'A' => cx.theme().colors().version_control_added, + 'D' => 
cx.theme().colors().version_control_deleted, + 'M' => cx.theme().colors().version_control_modified, + 'R' => cx.theme().colors().version_control_renamed, + 'C' => cx.theme().colors().version_control_modified, + 'T' => cx.theme().colors().version_control_modified, + _ => cx.theme().colors().text_muted, + }; + + let path = entry.path.clone(); + let display_path = if let Some(old) = &entry.old_path { + format!("{} → {}", old, path) + } else { + path.clone() + }; + + let ins = entry.insertions; + let del = entry.deletions; + let is_binary = entry.is_binary; + + div() + .id(SharedString::from(format!("diff-file-{}", &path))) + .flex() + .items_center() + .gap_1() + .px_2() + .py(px(3.0)) + .cursor_pointer() + .hover(|style| style.bg(cx.theme().colors().ghost_element_hover)) + .child( + div() + .text_xs() + .font_weight(FontWeight::BOLD) + .w(px(14.0)) + .text_color(status_color) + .child(format!("{}", entry.status)), + ) + .child( + div() + .flex_grow() + .text_sm() + .overflow_x_hidden() + .whitespace_nowrap() + .child(display_path), + ) + .when(!is_binary && (ins > 0 || del > 0), |el| { + el.child( + div() + .flex() + .gap(px(4.0)) + .text_xs() + .child( + div() + .text_color(cx.theme().colors().version_control_added) + .child(format!("+{}", ins)), + ) + .child( + div() + .text_color(cx.theme().colors().version_control_deleted) + .child(format!("-{}", del)), + ), + ) + }) + .when(is_binary, |el| { + el.child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child("binary"), + ) + }) + .into_any_element() + } +} + +impl Focusable for DiffFileList { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for DiffFileList { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + if self.diff.is_none() { + return div() + .id("diff-file-list-empty") + .p_2() + .text_sm() + .text_color(cx.theme().colors().text_muted) + .child("No diff to display"); + } + + let diff = 
self.diff.as_ref().unwrap(); + let summary = self.render_summary(diff, cx); + let file_entries: Vec<_> = diff + .files + .iter() + .map(|f| self.render_file_entry(f, cx)) + .collect(); + + div() + .id("diff-file-list") + .flex() + .flex_col() + .size_full() + .overflow_y_scroll() + .child(summary) + .children(file_entries) + } +} diff --git a/crates/cherrypick_ui/src/lib.rs b/crates/cherrypick_ui/src/lib.rs new file mode 100644 index 00000000000000..f6956a88a721e1 --- /dev/null +++ b/crates/cherrypick_ui/src/lib.rs @@ -0,0 +1,23 @@ +pub mod cherrypick_sidebar; +pub use cherrypick_sidebar::CherryPickSidebar; +pub(crate) mod cherrypick_view; +pub(crate) mod branch_list; +pub(crate) mod commit_graph_embed; +pub(crate) mod create_pr_form; +pub(crate) mod diff_file_list; +pub(crate) mod pr_detail; +pub(crate) mod pr_list; +pub(crate) mod pr_state; +pub(crate) mod staging_view; +mod worktree_list; +mod stash_list; +pub(crate) mod remote_toolbar; +mod repo_state_banner; +mod settings; + +pub fn init(cx: &mut gpui::App) { + cx.observe_new(|workspace: &mut workspace::Workspace, _window, _cx| { + cherrypick_sidebar::register(workspace); + }) + .detach(); +} diff --git a/crates/cherrypick_ui/src/pr_detail.rs b/crates/cherrypick_ui/src/pr_detail.rs new file mode 100644 index 00000000000000..6425f0b7c3ca6f --- /dev/null +++ b/crates/cherrypick_ui/src/pr_detail.rs @@ -0,0 +1,182 @@ +use cherrypick_pr::{LocalPr, PrStatus}; +use gpui::{ + App, Context, FocusHandle, Focusable, FontWeight, IntoElement, Render, SharedString, Window, + div, px, +}; +use ui::prelude::*; + +use crate::pr_state::PrState; + +pub enum PrAction { + Close(i64), + Reopen(i64), + StatusChanged, +} + +pub struct PrDetail { + focus_handle: FocusHandle, + pr: Option, + conflict_count: Option, + on_action: Option>, +} + +impl PrDetail { + pub fn new(cx: &mut Context) -> Self { + Self { + focus_handle: cx.focus_handle(), + pr: None, + conflict_count: None, + on_action: None, + } + } + + pub fn set_pr(&mut self, pr: 
Option, cx: &mut Context) { + self.pr = pr; + self.conflict_count = None; + cx.notify(); + } + + pub fn set_conflict_count(&mut self, count: usize, cx: &mut Context) { + self.conflict_count = Some(count); + cx.notify(); + } + + pub fn on_action( + &mut self, + callback: impl Fn(PrAction, &mut Window, &mut App) + Send + Sync + 'static, + ) { + self.on_action = Some(Box::new(callback)); + } + + fn close_pr(&mut self, window: &mut Window, cx: &mut Context) { + let Some(pr) = &self.pr else { return }; + let pr_id = pr.id; + if let Some(cb) = &self.on_action { + cb(PrAction::Close(pr_id), window, cx); + } + } + + fn reopen_pr(&mut self, window: &mut Window, cx: &mut Context) { + let Some(pr) = &self.pr else { return }; + let pr_id = pr.id; + if let Some(cb) = &self.on_action { + cb(PrAction::Reopen(pr_id), window, cx); + } + } +} + +impl Focusable for PrDetail { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for PrDetail { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let Some(pr) = &self.pr else { + return div(); + }; + + let status_color = match pr.status { + PrStatus::Open => cx.theme().colors().version_control_added, + PrStatus::Merged => cx.theme().colors().version_control_modified, + PrStatus::Closed => cx.theme().colors().version_control_deleted, + }; + + let is_open = pr.status == PrStatus::Open; + let is_closed = pr.status == PrStatus::Closed; + let conflict_count = self.conflict_count; + + let mut root = div() + .flex() + .flex_col() + .gap_1() + .px_2() + .py_1() + .border_b_1() + .border_color(cx.theme().colors().border); + + root = root.child( + div() + .flex() + .items_center() + .gap_2() + .child( + div() + .text_sm() + .font_weight(FontWeight::SEMIBOLD) + .child(pr.title.clone()), + ) + .child( + div() + .text_xs() + .px_1() + .py(px(1.0)) + .rounded_sm() + .text_color(status_color) + .border_1() + .border_color(status_color) + .child(pr.status.as_str()), + 
), + ); + + root = root.child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(format!("{} → {}", pr.source_branch, pr.target_branch)), + ); + + if let Some(count) = conflict_count { + if count > 0 { + root = root.child( + div() + .text_xs() + .text_color(cx.theme().colors().version_control_conflict) + .child(format!("{} conflicts", count)), + ); + } + } + + let mut actions = div().flex().items_center().gap_1().pt_1(); + + if is_open { + actions = actions.child( + div() + .id("close-pr-btn") + .text_xs() + .cursor_pointer() + .px_2() + .py(px(2.0)) + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, window, cx| { + this.close_pr(window, cx); + })) + .child("Close PR"), + ); + } + + if is_closed { + actions = actions.child( + div() + .id("reopen-pr-btn") + .text_xs() + .cursor_pointer() + .px_2() + .py(px(2.0)) + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, window, cx| { + this.reopen_pr(window, cx); + })) + .child("Reopen PR"), + ); + } + + root = root.child(actions); + root + } +} diff --git a/crates/cherrypick_ui/src/pr_list.rs b/crates/cherrypick_ui/src/pr_list.rs new file mode 100644 index 00000000000000..8df15e00d1678f --- /dev/null +++ b/crates/cherrypick_ui/src/pr_list.rs @@ -0,0 +1,120 @@ +use cherrypick_pr::{LocalPr, PrStatus}; +use gpui::{ + AnyElement, App, Context, EventEmitter, FocusHandle, Focusable, FontWeight, IntoElement, + Render, SharedString, Window, div, px, +}; +use ui::prelude::*; + +pub enum PrListEvent { + Selected(LocalPr), +} + +pub struct PrList { + focus_handle: FocusHandle, + prs: Vec, +} + +impl PrList { + pub fn new(cx: &mut Context) -> Self { + Self { + focus_handle: cx.focus_handle(), + prs: Vec::new(), + } + } + + pub fn set_prs(&mut self, prs: Vec, cx: &mut Context) { + 
self.prs = prs; + cx.notify(); + } + + fn render_pr_item(&self, pr: &LocalPr, cx: &mut Context) -> AnyElement { + let pr_id = pr.id; + let title = pr.title.clone(); + let source = pr.source_branch.clone(); + let target = pr.target_branch.clone(); + let status = pr.status; + + let status_color = match status { + PrStatus::Open => cx.theme().colors().version_control_added, + PrStatus::Merged => cx.theme().colors().version_control_modified, + PrStatus::Closed => cx.theme().colors().version_control_deleted, + }; + + let pr_clone = pr.clone(); + div() + .id(SharedString::from(format!("pr-{}", pr_id))) + .flex() + .flex_col() + .gap(px(2.0)) + .px_2() + .py_1() + .rounded_sm() + .cursor_pointer() + .hover(|style| style.bg(cx.theme().colors().ghost_element_hover)) + .active(|style| style.bg(cx.theme().colors().ghost_element_active)) + .on_click(cx.listener(move |this, _event, _window, cx| { + log::info!("cherrypick: PR clicked: id={}", pr_clone.id); + cx.emit(PrListEvent::Selected(pr_clone.clone())); + })) + .child( + div() + .flex() + .items_center() + .gap_1() + .child( + div() + .w(px(6.0)) + .h(px(6.0)) + .rounded_full() + .bg(status_color), + ) + .child( + div() + .text_sm() + .overflow_x_hidden() + .whitespace_nowrap() + .child(title), + ), + ) + .child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(format!("{} → {}", source, target)), + ) + .into_any_element() + } +} + +impl EventEmitter for PrList {} + +impl Focusable for PrList { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for PrList { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + if self.prs.is_empty() { + return div() + .px_2() + .py_1() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child("No open PRs"); + } + + let header = div() + .text_xs() + .font_weight(FontWeight::SEMIBOLD) + .text_color(cx.theme().colors().text_muted) + .px_2() + .py_1() + .child(format!("PRs 
({})", self.prs.len())); + + let items: Vec<_> = self.prs.iter().map(|pr| self.render_pr_item(pr, cx)).collect(); + + div().flex().flex_col().child(header).children(items) + } +} diff --git a/crates/cherrypick_ui/src/pr_state.rs b/crates/cherrypick_ui/src/pr_state.rs new file mode 100644 index 00000000000000..86fcbc96b5aca9 --- /dev/null +++ b/crates/cherrypick_ui/src/pr_state.rs @@ -0,0 +1,275 @@ +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; + +use cherrypick_pr::diff_service::{BranchDiff, DiffService}; +use cherrypick_pr::{LocalPr, PrStatus, PrStore}; +use gpui::{App, BackgroundExecutor, Task}; + +pub struct PrState { + store: Arc>>, + diff_service: Arc>, + repo_path: Option, + repo_id: Arc>>, + executor: BackgroundExecutor, +} + +impl PrState { + pub fn new(cx: &App) -> Self { + Self { + store: Arc::new(tokio::sync::Mutex::new(None)), + diff_service: Arc::new(Mutex::new(DiffService::new(32))), + repo_path: None, + repo_id: Arc::new(tokio::sync::Mutex::new(None)), + executor: cx.background_executor().clone(), + } + } + + pub fn set_repo_path(&mut self, path: PathBuf) { + self.repo_path = Some(path); + } + + pub fn repo_path(&self) -> Option<&Path> { + self.repo_path.as_deref() + } + + pub fn is_initialized(&self) -> bool { + self.repo_path.is_some() + } + + pub fn initialize(&self) -> Task> { + let Some(repo_path) = self.repo_path.clone() else { + return Task::ready(Err(anyhow::anyhow!("no repo path set"))); + }; + + let store_lock = self.store.clone(); + let repo_id_lock = self.repo_id.clone(); + + let git_info = std::thread::spawn(move || -> anyhow::Result<(String, String, String, String)> { + let db_dir = repo_path.join(".cherrypick"); + std::fs::create_dir_all(&db_dir)?; + let db_path = db_dir.join("prs.db").to_string_lossy().to_string(); + + let repo = git2::Repository::discover(&repo_path)?; + let mut revwalk = repo.revwalk()?; + revwalk.push_head()?; + revwalk.set_sorting(git2::Sort::TOPOLOGICAL | git2::Sort::REVERSE)?; + let first_oid = 
revwalk + .next() + .ok_or_else(|| anyhow::anyhow!("no commits"))?? + .to_string(); + + use sha2::Digest; + let hash = format!("{:x}", sha2::Sha256::digest(repo_path.to_string_lossy().as_bytes())); + let canonical = repo_path.to_string_lossy().to_string(); + + Ok((db_path, first_oid, hash, canonical)) + }); + + self.executor.spawn(async move { + let (db_path, first_oid, hash, canonical) = git_info + .join() + .map_err(|_| anyhow::anyhow!("git info thread panicked"))??; + + let store = PrStore::open(&db_path) + .await + .map_err(|e| anyhow::anyhow!("{e}"))?; + + let id = store + .ensure_repo(&first_oid, &hash, &canonical) + .await + .map_err(|e| anyhow::anyhow!("{e}"))?; + + *repo_id_lock.lock().await = Some(id); + *store_lock.lock().await = Some(store); + Ok(()) + }) + } + + pub fn list_open_prs(&self) -> Task>> { + let store_lock = self.store.clone(); + let repo_id_lock = self.repo_id.clone(); + + self.executor.spawn(async move { + let repo_id = repo_id_lock + .lock() + .await + .ok_or_else(|| anyhow::anyhow!("repo not registered"))?; + + let guard = store_lock.lock().await; + let store = guard + .as_ref() + .ok_or_else(|| anyhow::anyhow!("PR store not initialized"))?; + + store + .list_prs(repo_id, Some(PrStatus::Open)) + .await + .map_err(|e| anyhow::anyhow!("{e}")) + }) + } + + pub fn create_pr( + &self, + title: String, + source_branch: String, + target_branch: String, + ) -> Task> { + let Some(repo_path) = self.repo_path.clone() else { + return Task::ready(Err(anyhow::anyhow!("no repo path"))); + }; + let store_lock = self.store.clone(); + let repo_id_lock = self.repo_id.clone(); + + let sb = source_branch.clone(); + let tb = target_branch.clone(); + let git_info = std::thread::spawn(move || -> anyhow::Result<(String, String)> { + let repo = git2::Repository::discover(&repo_path)?; + let source_ref = repo + .find_branch(&sb, git2::BranchType::Local) + .map_err(|_| anyhow::anyhow!("branch '{}' not found", sb))?; + let source_oid = source_ref + .get() + 
.target() + .ok_or_else(|| anyhow::anyhow!("branch has no target"))? + .to_string(); + + let target_ref = repo + .find_branch(&tb, git2::BranchType::Local) + .map_err(|_| anyhow::anyhow!("branch '{}' not found", tb))?; + let target_oid = target_ref + .get() + .target() + .ok_or_else(|| anyhow::anyhow!("branch has no target"))? + .to_string(); + + Ok((source_oid, target_oid)) + }); + + self.executor.spawn(async move { + let (source_oid, target_oid) = git_info + .join() + .map_err(|_| anyhow::anyhow!("git thread panicked"))??; + + let repo_id = repo_id_lock + .lock() + .await + .ok_or_else(|| anyhow::anyhow!("repo not registered"))?; + + let guard = store_lock.lock().await; + let store = guard + .as_ref() + .ok_or_else(|| anyhow::anyhow!("PR store not initialized"))?; + + let pr_id = store + .create_pr( + repo_id, + &title, + &source_branch, + &target_branch, + &source_oid, + &target_oid, + ) + .await + .map_err(|e| anyhow::anyhow!("{e}"))?; + + store.get_pr(pr_id).await.map_err(|e| anyhow::anyhow!("{e}")) + }) + } + + pub fn get_branch_diff( + &self, + source_oid: String, + target_oid: String, + ) -> Task> { + let Some(repo_path) = self.repo_path.clone() else { + return Task::ready(Err(anyhow::anyhow!("no repo path"))); + }; + let diff_service = self.diff_service.clone(); + + self.executor.spawn(async move { + let mut ds = diff_service.lock().unwrap(); + ds.get_branch_diff(&repo_path, &source_oid, &target_oid) + .map_err(|e| anyhow::anyhow!("{e}")) + }) + } + + pub fn update_pr_status( + &self, + pr_id: i64, + status: PrStatus, + ) -> Task> { + let store_lock = self.store.clone(); + + self.executor.spawn(async move { + let guard = store_lock.lock().await; + let store = guard + .as_ref() + .ok_or_else(|| anyhow::anyhow!("PR store not initialized"))?; + store + .update_pr_status(pr_id, status) + .await + .map_err(|e| anyhow::anyhow!("{e}")) + }) + } + + pub fn get_unified_diff( + &self, + source_oid: String, + target_oid: String, + ) -> Task> { + let Some(repo_path) = 
self.repo_path.clone() else { + return Task::ready(Err(anyhow::anyhow!("no repo path"))); + }; + + self.executor.spawn(async move { + let repo = git2::Repository::discover(&repo_path)?; + let source = git2::Oid::from_str(&source_oid)?; + let target = git2::Oid::from_str(&target_oid)?; + + let merge_base = repo.merge_base(source, target)?; + let base_tree = repo.find_commit(merge_base)?.tree()?; + let source_tree = repo.find_commit(source)?.tree()?; + + let mut diff_opts = git2::DiffOptions::new(); + diff_opts.patience(true).context_lines(3); + + let diff = repo.diff_tree_to_tree( + Some(&base_tree), + Some(&source_tree), + Some(&mut diff_opts), + )?; + + let mut output = String::new(); + diff.print(git2::DiffFormat::Patch, |_delta, _hunk, line| { + let prefix = match line.origin() { + '+' => "+", + '-' => "-", + ' ' => " ", + 'H' | 'F' => "", + _ => "", + }; + let content = std::str::from_utf8(line.content()).unwrap_or(""); + if line.origin() == 'H' || line.origin() == 'F' { + output.push_str(content); + } else { + output.push_str(prefix); + output.push_str(content); + } + true + })?; + + Ok(output) + }) + } + + pub fn get_pr(&self, pr_id: i64) -> Task> { + let store_lock = self.store.clone(); + + self.executor.spawn(async move { + let guard = store_lock.lock().await; + let store = guard + .as_ref() + .ok_or_else(|| anyhow::anyhow!("PR store not initialized"))?; + store.get_pr(pr_id).await.map_err(|e| anyhow::anyhow!("{e}")) + }) + } +} diff --git a/crates/cherrypick_ui/src/remote_toolbar.rs b/crates/cherrypick_ui/src/remote_toolbar.rs new file mode 100644 index 00000000000000..9e067ebde612d2 --- /dev/null +++ b/crates/cherrypick_ui/src/remote_toolbar.rs @@ -0,0 +1,203 @@ +use askpass::AskPassDelegate; +use git::repository::FetchOptions; +use gpui::{ + App, Context, Entity, FocusHandle, Focusable, IntoElement, Render, SharedString, WeakEntity, + Window, div, +}; +use project::git_store::Repository; +use ui::prelude::*; +use workspace::Workspace; + +pub struct 
RemoteToolbar { + focus_handle: FocusHandle, + workspace: WeakEntity, + repository: Option>, + ahead: u32, + behind: u32, +} + +impl RemoteToolbar { + pub fn new( + workspace: WeakEntity, + repository: Option>, + cx: &mut Context, + ) -> Self { + let (ahead, behind) = Self::tracking_counts(&repository, cx); + Self { + focus_handle: cx.focus_handle(), + workspace, + repository, + ahead, + behind, + } + } + + pub fn set_repository( + &mut self, + repo: Option>, + cx: &mut Context, + ) { + self.repository = repo; + let (ahead, behind) = Self::tracking_counts(&self.repository, cx); + self.ahead = ahead; + self.behind = behind; + cx.notify(); + } + + fn tracking_counts(repo: &Option>, cx: &App) -> (u32, u32) { + repo.as_ref() + .and_then(|r| { + r.read(cx) + .branch + .as_ref() + .and_then(|b| b.tracking_status()) + .map(|s| (s.ahead, s.behind)) + }) + .unwrap_or((0, 0)) + } + + fn noop_askpass(cx: &mut Context) -> AskPassDelegate { + AskPassDelegate::new(&mut cx.to_async(), |_prompt, _tx, _cx| {}) + } + + fn do_fetch(&mut self, window: &mut Window, cx: &mut Context) { + let Some(repo) = self.repository.clone() else { + return; + }; + let askpass = Self::noop_askpass(cx); + window + .spawn(cx, async move |cx| { + let rx = repo.update(cx, |repo, cx| { + repo.fetch(FetchOptions::All, askpass, cx) + }); + let _ = rx.await; + anyhow::Ok(()) + }) + .detach(); + } + + fn do_pull(&mut self, window: &mut Window, cx: &mut Context) { + let Some(repo) = self.repository.clone() else { + return; + }; + let askpass = Self::noop_askpass(cx); + let branch_name: Option = self + .repository + .as_ref() + .and_then(|r| { + r.read(cx) + .branch + .as_ref() + .map(|b| SharedString::from(b.name().to_string())) + }); + + window + .spawn(cx, async move |cx| { + let rx = repo.update(cx, |repo, cx| { + repo.pull(branch_name, "origin".into(), false, askpass, cx) + }); + let _ = rx.await; + anyhow::Ok(()) + }) + .detach(); + } + + fn do_push(&mut self, window: &mut Window, cx: &mut Context) { + 
let Some(repo) = self.repository.clone() else { + return; + }; + let askpass = Self::noop_askpass(cx); + let branch_name = self + .repository + .as_ref() + .and_then(|r| r.read(cx).branch.as_ref().map(|b| b.name().to_string())) + .unwrap_or_default(); + + window + .spawn(cx, async move |cx| { + let rx = repo.update(cx, |repo, cx| { + repo.push( + SharedString::from(branch_name.clone()), + SharedString::from(branch_name), + "origin".into(), + None, + askpass, + cx, + ) + }); + let _ = rx.await; + anyhow::Ok(()) + }) + .detach(); + } +} + +impl Focusable for RemoteToolbar { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for RemoteToolbar { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + div() + .flex() + .items_center() + .gap_2() + .px_2() + .py_1() + .border_t_1() + .border_color(cx.theme().colors().border) + .child( + div() + .text_xs() + .text_color(cx.theme().colors().text_muted) + .child(format!("↑{} ↓{}", self.ahead, self.behind)), + ) + .child( + div() + .id("fetch-btn") + .text_xs() + .cursor_pointer() + .px_2() + .py(gpui::px(2.0)) + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, window, cx| { + this.do_fetch(window, cx); + })) + .child("Fetch"), + ) + .child( + div() + .id("pull-btn") + .text_xs() + .cursor_pointer() + .px_2() + .py(gpui::px(2.0)) + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, window, cx| { + this.do_pull(window, cx); + })) + .child("Pull"), + ) + .child( + div() + .id("push-btn") + .text_xs() + .cursor_pointer() + .px_2() + .py(gpui::px(2.0)) + .rounded_sm() + .bg(cx.theme().colors().ghost_element_background) + .hover(|s| s.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(|this, _event, 
window, cx| { + this.do_push(window, cx); + })) + .child("Push"), + ) + } +} diff --git a/crates/cherrypick_ui/src/repo_state_banner.rs b/crates/cherrypick_ui/src/repo_state_banner.rs new file mode 100644 index 00000000000000..80e841deabc2cf --- /dev/null +++ b/crates/cherrypick_ui/src/repo_state_banner.rs @@ -0,0 +1 @@ +// Placeholder — implemented in a later section diff --git a/crates/cherrypick_ui/src/settings.rs b/crates/cherrypick_ui/src/settings.rs new file mode 100644 index 00000000000000..80e841deabc2cf --- /dev/null +++ b/crates/cherrypick_ui/src/settings.rs @@ -0,0 +1 @@ +// Placeholder — implemented in a later section diff --git a/crates/cherrypick_ui/src/staging_view.rs b/crates/cherrypick_ui/src/staging_view.rs new file mode 100644 index 00000000000000..b1af9bee854dad --- /dev/null +++ b/crates/cherrypick_ui/src/staging_view.rs @@ -0,0 +1,294 @@ +use git::repository::RepoPath; +use git::status::{FileStatus, StatusCode, TrackedStatus}; +use gpui::{ + AnyElement, App, Context, Entity, FocusHandle, Focusable, FontWeight, IntoElement, Render, + SharedString, Subscription, Window, div, px, +}; +use project::git_store::{Repository, StatusEntry}; +use ui::prelude::*; + +struct StagingEntry { + repo_path: RepoPath, + status: FileStatus, + is_staged: bool, +} + +impl StagingEntry { + fn status_char(&self) -> char { + match self.status { + FileStatus::Untracked => '?', + FileStatus::Ignored => '!', + FileStatus::Unmerged { .. 
} => 'U', + FileStatus::Tracked(TrackedStatus { + index_status, + worktree_status, + }) => { + let code = if self.is_staged { + index_status + } else { + worktree_status + }; + match code { + StatusCode::Modified => 'M', + StatusCode::Added => 'A', + StatusCode::Deleted => 'D', + StatusCode::Renamed => 'R', + StatusCode::Copied => 'C', + StatusCode::TypeChanged => 'T', + StatusCode::Unmodified => ' ', + } + } + } + } +} + +pub struct StagingView { + focus_handle: FocusHandle, + repository: Option>, + staged: Vec, + unstaged: Vec, + commit_message: String, + _subscriptions: Vec, +} + +impl StagingView { + pub fn new( + repository: Option>, + cx: &mut Context, + ) -> Self { + let mut view = Self { + focus_handle: cx.focus_handle(), + repository: repository.clone(), + staged: Vec::new(), + unstaged: Vec::new(), + commit_message: String::new(), + _subscriptions: Vec::new(), + }; + view.refresh_statuses(cx); + view + } + + pub fn set_repository(&mut self, repo: Option>, cx: &mut Context) { + self.repository = repo; + self.refresh_statuses(cx); + } + + fn refresh_statuses(&mut self, cx: &mut Context) { + let Some(repo) = self.repository.as_ref() else { + self.staged.clear(); + self.unstaged.clear(); + cx.notify(); + return; + }; + + self.staged.clear(); + self.unstaged.clear(); + + let snapshot = repo.read(cx); + for entry in snapshot.status() { + let staging = entry.status.staging(); + match staging { + git::status::StageStatus::Staged => { + self.staged.push(StagingEntry { + repo_path: entry.repo_path, + status: entry.status, + is_staged: true, + }); + } + git::status::StageStatus::Unstaged => { + self.unstaged.push(StagingEntry { + repo_path: entry.repo_path, + status: entry.status, + is_staged: false, + }); + } + git::status::StageStatus::PartiallyStaged => { + self.staged.push(StagingEntry { + repo_path: entry.repo_path.clone(), + status: entry.status, + is_staged: true, + }); + self.unstaged.push(StagingEntry { + repo_path: entry.repo_path, + status: entry.status, 
+ is_staged: false, + }); + } + } + } + cx.notify(); + } + + fn stage_file(&self, path: RepoPath, cx: &mut Context) { + let Some(repo) = self.repository.clone() else { + return; + }; + repo.update(cx, |repo, cx| { + repo.stage_entries(vec![path], cx).detach(); + }); + } + + fn unstage_file(&self, path: RepoPath, cx: &mut Context) { + let Some(repo) = self.repository.clone() else { + return; + }; + repo.update(cx, |repo, cx| { + repo.unstage_entries(vec![path], cx).detach(); + }); + } + + fn stage_all(&self, cx: &mut Context) { + let Some(repo) = self.repository.clone() else { + return; + }; + repo.update(cx, |repo, cx| { + repo.stage_all(cx).detach(); + }); + } + + fn unstage_all(&self, cx: &mut Context) { + let Some(repo) = self.repository.clone() else { + return; + }; + repo.update(cx, |repo, cx| { + repo.unstage_all(cx).detach(); + }); + } + + fn render_file_entry( + &self, + entry: &StagingEntry, + cx: &mut Context, + ) -> AnyElement { + let path_str = entry.repo_path.as_unix_str().to_string(); + let status_char = entry.status_char(); + let is_staged = entry.is_staged; + let repo_path = entry.repo_path.clone(); + + div() + .id(SharedString::from(format!( + "file-{}-{}", + if is_staged { "s" } else { "u" }, + &path_str + ))) + .flex() + .items_center() + .gap_1() + .px_2() + .py(px(2.0)) + .cursor_pointer() + .hover(|style| style.bg(cx.theme().colors().ghost_element_hover)) + .on_click(cx.listener(move |this, _event, _window, cx| { + if is_staged { + this.unstage_file(repo_path.clone(), cx); + } else { + this.stage_file(repo_path.clone(), cx); + } + })) + .child( + div() + .text_xs() + .font_weight(FontWeight::BOLD) + .text_color(match status_char { + 'M' => cx.theme().colors().version_control_modified, + 'A' => cx.theme().colors().version_control_added, + 'D' => cx.theme().colors().version_control_deleted, + 'U' => cx.theme().colors().version_control_conflict, + _ => cx.theme().colors().text_muted, + }) + .child(format!("{}", status_char)), + ) + .child( + 
div() + .flex_grow() + .text_sm() + .overflow_x_hidden() + .whitespace_nowrap() + .child(path_str), + ) + .into_any_element() + } +} + +impl Focusable for StagingView { + fn focus_handle(&self, _cx: &App) -> FocusHandle { + self.focus_handle.clone() + } +} + +impl Render for StagingView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let staged_header = div() + .flex() + .items_center() + .justify_between() + .px_2() + .py_1() + .child( + div() + .text_xs() + .font_weight(FontWeight::SEMIBOLD) + .text_color(cx.theme().colors().text_muted) + .child(format!("Staged ({})", self.staged.len())), + ) + .child( + div() + .id("unstage-all-btn") + .text_xs() + .cursor_pointer() + .text_color(cx.theme().colors().text_accent) + .hover(|s| s.underline()) + .on_click(cx.listener(|this, _event, _window, cx| { + this.unstage_all(cx); + })) + .child("Unstage All"), + ); + + let staged_entries: Vec<_> = self + .staged + .iter() + .map(|e| self.render_file_entry(e, cx)) + .collect(); + + let unstaged_header = div() + .flex() + .items_center() + .justify_between() + .px_2() + .py_1() + .child( + div() + .text_xs() + .font_weight(FontWeight::SEMIBOLD) + .text_color(cx.theme().colors().text_muted) + .child(format!("Unstaged ({})", self.unstaged.len())), + ) + .child( + div() + .id("stage-all-btn") + .text_xs() + .cursor_pointer() + .text_color(cx.theme().colors().text_accent) + .hover(|s| s.underline()) + .on_click(cx.listener(|this, _event, _window, cx| { + this.stage_all(cx); + })) + .child("Stage All"), + ); + + let unstaged_entries: Vec<_> = self + .unstaged + .iter() + .map(|e| self.render_file_entry(e, cx)) + .collect(); + + div() + .flex() + .flex_col() + .size_full() + .child(unstaged_header) + .children(unstaged_entries) + .child(staged_header) + .children(staged_entries) + } +} diff --git a/crates/cherrypick_ui/src/stash_list.rs b/crates/cherrypick_ui/src/stash_list.rs new file mode 100644 index 00000000000000..80e841deabc2cf --- 
/dev/null +++ b/crates/cherrypick_ui/src/stash_list.rs @@ -0,0 +1 @@ +// Placeholder — implemented in a later section diff --git a/crates/cherrypick_ui/src/worktree_list.rs b/crates/cherrypick_ui/src/worktree_list.rs new file mode 100644 index 00000000000000..80e841deabc2cf --- /dev/null +++ b/crates/cherrypick_ui/src/worktree_list.rs @@ -0,0 +1 @@ +// Placeholder — implemented in a later section diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index d8ac8be3369f7f..564aee164828ff 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -111,6 +111,7 @@ feedback.workspace = true file_finder.workspace = true fs.workspace = true futures.workspace = true +cherrypick_ui.workspace = true git.workspace = true git_graph.workspace = true git_hosting_providers.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 9141fe1aa8ae31..90c3790bf8bd5d 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -730,6 +730,7 @@ fn main() { notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx); collab_ui::init(&app_state, cx); git_ui::init(cx); + cherrypick_ui::init(cx); git_graph::init(cx); feedback::init(cx); markdown_preview::init(cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 6d1a9c176f1193..2cd359e390b08a 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -29,6 +29,7 @@ use fs::Fs; use futures::FutureExt as _; use futures::{StreamExt, channel::mpsc, select_biased}; use git_ui::commit_view::CommitViewToolbar; +use cherrypick_ui::CherryPickSidebar; use git_ui::git_panel::GitPanel; use git_ui::project_diff::{BranchDiffToolbar, ProjectDiffToolbar}; use gpui::{ @@ -714,6 +715,7 @@ fn initialize_panels(window: &mut Window, cx: &mut Context) -> Task) -> Task