diff --git a/Cargo.toml b/Cargo.toml index da2a7bb4..8c19bcf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ unicode-width = "0.2" pulldown-cmark = { version = "0.13", default-features = false } tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } cron = "0.16.0" -chrono = "0.4.44" +chrono = { version = "0.4.44", features = ["serde"] } chrono-tz = "0.10.4" [target.'cfg(unix)'.dependencies] diff --git a/docs/slash-commands.md b/docs/slash-commands.md index 6d24a63a..040838e5 100644 --- a/docs/slash-commands.md +++ b/docs/slash-commands.md @@ -10,6 +10,7 @@ OpenAB registers Discord slash commands for session control. These work in both | `/agents` | Select the agent mode via dropdown menu | Yes | | `/cancel` | Cancel the current in-flight operation | Yes | | `/reset` | Reset the conversation session (clear history, start fresh) | Yes | +| `/remind` | Set a one-shot delayed reminder to mention users/roles | No | All responses are **ephemeral** — only the user who invoked the command sees the reply. @@ -74,3 +75,45 @@ In addition to slash commands, you can pass built-in CLI commands directly after ``` These are forwarded as-is to the ACP session as a prompt. Any command the underlying CLI supports in its interactive mode works here. This is the recommended workaround for agents that don't expose `configOptions`. + +## `/remind` + +Set a one-shot delayed reminder that mentions users or roles in the channel after a specified delay. + +**Syntax:** +``` +/remind targets:<@user @role ...> message: delay: +``` + +**Parameters:** + +| Parameter | Required | Description | +|-----------|----------|-------------| +| `targets` | Yes | Space-separated @mentions (users and/or roles) | +| `message` | Yes | Reminder text | +| `delay` | Yes | Duration before firing: `1m` to `30d` (supports `m`, `h`, `d` and combinations like `1h30m`) | + +**Constraints:** +- Only humans can use `/remind` (bots are rejected) +- Minimum delay: 1 minute +- Maximum delay: 30 days +- Maximum message length: 1800 characters +- Maximum 5 active reminders per user +- Maximum 10 mention targets per reminder (use a @role for larger groups) +- `@everyone` and `@here` in messages are automatically neutralized (will not trigger mass mentions) +- One-shot only (fires once, then removed) +- Reminders persist across bot restarts (stored in `$HOME/.openab/reminders.json`) + +**Examples:** +``` +/remind targets:@Alice @Bob message:Review PR #42 delay:2h +/remind targets:@Reviewers message:Stand-up time delay:30m +/remind targets:@Charlie message:Check deployment delay:1d +``` + +**When fired, the bot posts:** +``` +⏰ Reminder from @sender: +"Review PR #42" +cc @Alice @Bob +``` diff --git a/src/discord.rs b/src/discord.rs index ce2b5e11..5452e6a9 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -5,15 +5,16 @@ use crate::bot_turns::{BotTurnTracker, TurnAction, TurnSeverity}; use crate::config::{AllowBots, AllowUsers, SttConfig}; use crate::format; use crate::media; +use crate::remind::{self, ReminderStore}; use async_trait::async_trait; use serenity::builder::{ - CreateActionRow, CreateButton, CreateCommand, CreateInteractionResponse, + CreateActionRow, CreateButton, CreateCommand, CreateCommandOption, CreateInteractionResponse, CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind, CreateSelectMenuOption, CreateThread, EditMessage, }; use serenity::http::Http; use serenity::model::application::ButtonStyle; -use serenity::model::application::{Command, ComponentInteractionDataKind, Interaction}; +use serenity::model::application::{Command, CommandOptionType, ComponentInteractionDataKind, Interaction}; use serenity::model::channel::{AutoArchiveDuration, Message, MessageType, ReactionType}; use serenity::model::gateway::Ready; use serenity::model::id::{ChannelId, MessageId, UserId}; @@ -207,6 +208,10 @@ pub struct Handler { pub allow_dm: bool, /// Per-thread dispatcher (Message mode uses cap=1 for FIFO; Thread/Lane use configured cap). pub dispatcher: Arc, + /// Reminder store for /remind slash command. + pub reminder_store: ReminderStore, + /// Track scheduled reminder IDs to prevent duplicate scheduling on reconnect. + pub scheduled_ids: tokio::sync::Mutex>, } impl Handler { @@ -815,6 +820,23 @@ impl EventHandler for Handler { CreateCommand::new("cancel-all") .description("Cancel current operation and drop all buffered messages"), CreateCommand::new("reset").description("Reset the conversation session"), + CreateCommand::new("remind") + .description("Set a one-shot reminder to mention users/roles after a delay") + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "targets", + "Users/roles to mention (e.g. @user1 @role1)", + ).required(true)) + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "message", + "Reminder message", + ).required(true)) + .add_option(CreateCommandOption::new( + CommandOptionType::String, + "delay", + "Delay before firing (e.g. 30m, 2h, 1d)", + ).required(true)), ]; // Register global commands (works in DMs + all guilds after propagation). @@ -833,6 +855,22 @@ impl EventHandler for Handler { info!(%guild_id, "registered guild slash commands"); } } + + // Re-schedule any pending reminders that survived a restart. + let pending = self.reminder_store.pending().await; + if !pending.is_empty() { + let mut scheduled = self.scheduled_ids.lock().await; + let mut count = 0; + for r in pending { + if scheduled.insert(r.id.clone()) { + remind::schedule_reminder(ctx.http.clone(), self.reminder_store.clone(), r); + count += 1; + } + } + if count > 0 { + info!(count, "re-scheduled pending reminders"); + } + } } async fn interaction_create(&self, ctx: Context, interaction: Interaction) { @@ -854,6 +892,9 @@ impl EventHandler for Handler { Interaction::Command(cmd) if cmd.data.name == "reset" => { self.handle_reset_command(&ctx, &cmd).await; } + Interaction::Command(cmd) if cmd.data.name == "remind" => { + self.handle_remind_command(&ctx, &cmd).await; + } Interaction::Component(comp) if comp.data.custom_id.starts_with("acp_config_") => { self.handle_config_select(&ctx, &comp).await; } @@ -1116,6 +1157,145 @@ impl Handler { } } + async fn handle_remind_command( + &self, + ctx: &Context, + cmd: &serenity::model::application::CommandInteraction, + ) { + // Only humans can use /remind + if cmd.user.bot { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ Only humans can set reminders.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // Extract options + let opts = &cmd.data.options; + let targets_raw = opts.iter() + .find(|o| o.name == "targets") + .and_then(|o| o.value.as_str()) + .unwrap_or(""); + let message = opts.iter() + .find(|o| o.name == "message") + .and_then(|o| o.value.as_str()) + .unwrap_or(""); + let delay_raw = opts.iter() + .find(|o| o.name == "delay") + .and_then(|o| o.value.as_str()) + .unwrap_or(""); + + if targets_raw.is_empty() || message.is_empty() || delay_raw.is_empty() { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ All fields (targets, message, delay) are required.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // Parse delay + let delay_secs = match remind::parse_delay(delay_raw) { + Ok(s) => s, + Err(e) => { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ Invalid delay: {e}")) + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + }; + + if let Err(e) = remind::validate_message(message) { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ {e}")) + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // Strip @everyone / @here to prevent unintended mass pings. + let message = remind::sanitize_message(message); + + // Extract mention strings from targets (keep raw — Discord renders them) + let targets: Vec = targets_raw + .split_whitespace() + .filter(|t| t.starts_with("<@") && t.ends_with('>')) + .map(|t| t.to_string()) + .collect(); + + if targets.is_empty() { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ No valid mentions found in targets. Use @user or @role.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + if targets.len() > remind::MAX_TARGETS { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!("⚠️ Too many targets (max {}). Use a @role instead.", remind::MAX_TARGETS)) + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + // F4: Per-user rate limit (max 5 active reminders) + let user_id = cmd.user.id.get(); + let pending = self.reminder_store.pending().await; + let user_count = pending.iter().filter(|r| r.sender_id == user_id).count(); + if user_count >= 5 { + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("⚠️ You already have 5 active reminders. Wait for some to fire before adding more.") + .ephemeral(true), + ); + let _ = cmd.create_response(&ctx.http, response).await; + return; + } + + let fire_at = chrono::Utc::now() + chrono::Duration::seconds(delay_secs as i64); + let reminder = remind::Reminder { + id: uuid::Uuid::new_v4().to_string(), + channel_id: cmd.channel_id.get(), + sender_id: cmd.user.id.get(), + targets: targets.clone(), + message: message.clone(), + fire_at, + created_at: chrono::Utc::now(), + }; + + // Persist and schedule + self.reminder_store.add(reminder.clone()).await; + self.scheduled_ids.lock().await.insert(reminder.id.clone()); + remind::schedule_reminder(ctx.http.clone(), self.reminder_store.clone(), reminder); + + let delay_str = remind::format_delay(delay_secs); + let response = CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!( + "⏰ Reminder set! Will fire in **{delay_str}** and mention {}", + targets.join(" ") + )) + .ephemeral(true), + ); + if let Err(e) = cmd.create_response(&ctx.http, response).await { + tracing::error!(error = %e, "failed to respond to /remind command"); + } + } + async fn handle_config_select( &self, ctx: &Context, diff --git a/src/main.rs b/src/main.rs index 706079b6..413eb114 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ mod gateway; mod markdown; mod media; mod reactions; +mod remind; mod setup; mod slack; mod stt; @@ -403,6 +404,14 @@ async fn main() -> anyhow::Result<()> { )); dispatchers.lock().unwrap().push(discord_dispatcher.clone()); + // Initialize reminder store (persists to $HOME/.openab/reminders.json) + let reminder_path = std::env::var("HOME") + .map(std::path::PathBuf::from) + .unwrap_or_default() + .join(".openab") + .join("reminders.json"); + let reminder_store = remind::ReminderStore::load(reminder_path); + let handler = discord::Handler { router, allow_all_channels, @@ -424,6 +433,8 @@ async fn main() -> anyhow::Result<()> { )), allow_dm: discord_cfg.allow_dm, dispatcher: discord_dispatcher, + reminder_store: reminder_store.clone(), + scheduled_ids: tokio::sync::Mutex::new(std::collections::HashSet::new()), }; let intents = GatewayIntents::GUILD_MESSAGES diff --git a/src/remind.rs b/src/remind.rs new file mode 100644 index 00000000..471b08ff --- /dev/null +++ b/src/remind.rs @@ -0,0 +1,399 @@ +//! One-shot `/remind` slash command — schedules a delayed mention in a Discord channel. +//! +//! Persistence: reminders are stored in `reminders.json` and reloaded on startup. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serenity::http::Http; +use serenity::model::id::ChannelId; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::Mutex; +use tracing::{error, info, warn}; + +/// A single pending reminder. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Reminder { + pub id: String, + pub channel_id: u64, + pub sender_id: u64, + /// Raw mention strings (e.g. "<@123>", "<@&456>") + pub targets: Vec, + pub message: String, + pub fire_at: DateTime, + pub created_at: DateTime, +} + +/// Shared reminder store with file persistence. +#[derive(Clone)] +pub struct ReminderStore { + reminders: Arc>>, + path: PathBuf, +} + +impl ReminderStore { + /// Load or create the reminder store from the given path. + pub fn load(path: PathBuf) -> Self { + let reminders = match std::fs::read_to_string(&path) { + Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { + warn!(error = %e, "failed to parse reminders.json, starting empty"); + Vec::new() + }), + Err(_) => Vec::new(), + }; + info!(count = reminders.len(), path = %path.display(), "loaded reminders"); + Self { + reminders: Arc::new(Mutex::new(reminders)), + path, + } + } + + /// Add a reminder and persist to disk. + pub async fn add(&self, reminder: Reminder) { + let snapshot = { + let mut reminders = self.reminders.lock().await; + reminders.push(reminder); + reminders.clone() + }; + self.persist(&snapshot); + } + + /// Remove a reminder by ID and persist. + pub async fn remove(&self, id: &str) { + let snapshot = { + let mut reminders = self.reminders.lock().await; + reminders.retain(|r| r.id != id); + reminders.clone() + }; + self.persist(&snapshot); + } + + /// Get all pending reminders (for startup re-scheduling). + pub async fn pending(&self) -> Vec { + self.reminders.lock().await.clone() + } + + fn persist(&self, reminders: &[Reminder]) { + match serde_json::to_string_pretty(reminders) { + Ok(data) => { + if let Some(parent) = self.path.parent() { + if let Err(e) = std::fs::create_dir_all(parent) { + error!(error = %e, "failed to create reminders directory"); + return; + } + } + if let Err(e) = std::fs::write(&self.path, data) { + error!(error = %e, "failed to persist reminders.json"); + } + } + Err(e) => { + error!(error = %e, "failed to serialize reminders, skipping persist"); + } + } + } +} + +/// Maximum allowed message length for reminders. +pub const MAX_MESSAGE_LEN: usize = 1800; + +/// Maximum number of mention targets per reminder. +pub const MAX_TARGETS: usize = 10; + +/// Sanitize reminder message: neutralize @everyone/@here. +pub fn sanitize_message(msg: &str) -> String { + msg.replace("@everyone", "@\u{200b}everyone") + .replace("@here", "@\u{200b}here") +} + +/// Validate reminder message length. +pub fn validate_message(msg: &str) -> Result<(), String> { + if msg.len() > MAX_MESSAGE_LEN { + Err(format!("message too long (max {MAX_MESSAGE_LEN} characters)")) + } else { + Ok(()) + } +} + +/// Parse a human delay string like "30m", "2h", "7d" into seconds. +/// Supports combinations: "1h30m", "2d12h". +/// Range: 1m (60s) to 30d (2_592_000s). +pub fn parse_delay(input: &str) -> Result { + let s = input.trim().to_lowercase(); + if s.is_empty() { + return Err("empty delay".into()); + } + + let mut total_secs: u64 = 0; + let mut num_buf = String::new(); + + for ch in s.chars() { + if ch.is_ascii_digit() { + num_buf.push(ch); + } else { + let n: u64 = num_buf.parse().map_err(|_| format!("invalid number in delay: {input}"))?; + num_buf.clear(); + let multiplier = match ch { + 'm' => 60, + 'h' => 3600, + 'd' => 86400, + _ => return Err(format!("unknown unit '{ch}' in delay (use m/h/d)")), + }; + total_secs += n * multiplier; + } + } + + // Handle bare number (default to minutes) + if !num_buf.is_empty() { + let n: u64 = num_buf.parse().map_err(|_| format!("invalid number in delay: {input}"))?; + total_secs += n * 60; // default unit = minutes + } + + if total_secs < 60 { + return Err("minimum delay is 1m".into()); + } + if total_secs > 2_592_000 { + return Err("maximum delay is 30d".into()); + } + + Ok(total_secs) +} + +/// Format seconds into a human-readable string like "2h 30m". +pub fn format_delay(secs: u64) -> String { + let d = secs / 86400; + let h = (secs % 86400) / 3600; + let m = (secs % 3600) / 60; + let mut parts = Vec::new(); + if d > 0 { parts.push(format!("{d}d")); } + if h > 0 { parts.push(format!("{h}h")); } + if m > 0 { parts.push(format!("{m}m")); } + if parts.is_empty() { "< 1m".into() } else { parts.join(" ") } +} + +/// Spawn a tokio task that fires the reminder after the delay. +pub fn schedule_reminder( + http: Arc, + store: ReminderStore, + reminder: Reminder, +) { + let now = Utc::now(); + let delay = if reminder.fire_at > now { + (reminder.fire_at - now).to_std().unwrap_or_default() + } else { + std::time::Duration::ZERO + }; + + let id = reminder.id.clone(); + tokio::spawn(async move { + tokio::time::sleep(delay).await; + + let targets_str = reminder.targets.join(" "); + let content = format!( + "⏰ **Reminder** from <@{}>:\n\"{}\"\ncc {}", + reminder.sender_id, reminder.message, targets_str + ); + + let channel = ChannelId::new(reminder.channel_id); + match channel.say(&http, &content).await { + Ok(_) => { + info!(id = %id, channel = reminder.channel_id, "reminder fired"); + store.remove(&id).await; + } + Err(e) => { + error!(error = %e, id = %id, "failed to send reminder — keeping for retry on next restart"); + } + } + }); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_delay_minutes() { + assert_eq!(parse_delay("5m").unwrap(), 300); + assert_eq!(parse_delay("1m").unwrap(), 60); + } + + #[test] + fn test_parse_delay_hours() { + assert_eq!(parse_delay("2h").unwrap(), 7200); + } + + #[test] + fn test_parse_delay_days() { + assert_eq!(parse_delay("1d").unwrap(), 86400); + assert_eq!(parse_delay("30d").unwrap(), 2_592_000); + } + + #[test] + fn test_parse_delay_combined() { + assert_eq!(parse_delay("1h30m").unwrap(), 5400); + assert_eq!(parse_delay("1d12h").unwrap(), 129_600); + } + + #[test] + fn test_parse_delay_bare_number_defaults_to_minutes() { + assert_eq!(parse_delay("10").unwrap(), 600); + } + + #[test] + fn test_parse_delay_too_short() { + assert!(parse_delay("0m").is_err()); + assert!(parse_delay("0h").is_err()); + } + + #[test] + fn test_parse_delay_too_long() { + assert!(parse_delay("31d").is_err()); + } + + #[test] + fn test_format_delay() { + assert_eq!(format_delay(3600), "1h"); + assert_eq!(format_delay(5400), "1h 30m"); + assert_eq!(format_delay(90000), "1d 1h"); + } + + #[test] + fn test_parse_delay_empty() { + assert!(parse_delay("").is_err()); + assert!(parse_delay(" ").is_err()); + } + + #[test] + fn test_parse_delay_invalid_unit() { + assert!(parse_delay("2x").is_err()); + assert!(parse_delay("abc").is_err()); + assert!(parse_delay("5s").is_err()); + } + + #[test] + fn test_parse_delay_case_insensitive() { + assert_eq!(parse_delay("2H").unwrap(), 7200); + assert_eq!(parse_delay("1D30M").unwrap(), 88200); + } + + #[test] + fn test_parse_delay_whitespace_trimmed() { + assert_eq!(parse_delay(" 5m ").unwrap(), 300); + } + + #[test] + fn test_parse_delay_bare_number_boundary() { + assert_eq!(parse_delay("1").unwrap(), 60); // 1 min + assert_eq!(parse_delay("30").unwrap(), 1800); // 30 min + } + + #[test] + fn test_parse_delay_exact_boundaries() { + // Exactly 1m (minimum) + assert_eq!(parse_delay("1m").unwrap(), 60); + // Exactly 30d (maximum) + assert_eq!(parse_delay("30d").unwrap(), 2_592_000); + // Just over 30d + assert!(parse_delay("30d1m").is_err()); + } + + #[test] + fn test_format_delay_zero() { + assert_eq!(format_delay(0), "< 1m"); + } + + #[test] + fn test_format_delay_pure_units() { + assert_eq!(format_delay(86400), "1d"); + assert_eq!(format_delay(120), "2m"); + assert_eq!(format_delay(7200), "2h"); + } + + #[tokio::test] + async fn test_reminder_store_add_remove() { + let dir = std::env::temp_dir().join(format!("remind_test_{}", std::process::id())); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("reminders.json"); + + let store = ReminderStore::load(path.clone()); + assert_eq!(store.pending().await.len(), 0); + + let r = Reminder { + id: "test-1".into(), + channel_id: 123, + sender_id: 456, + targets: vec!["<@789>".into()], + message: "hello".into(), + fire_at: Utc::now() + chrono::Duration::hours(1), + created_at: Utc::now(), + }; + + store.add(r).await; + assert_eq!(store.pending().await.len(), 1); + + store.remove("test-1").await; + assert_eq!(store.pending().await.len(), 0); + + // Verify persistence + let store2 = ReminderStore::load(path.clone()); + assert_eq!(store2.pending().await.len(), 0); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[tokio::test] + async fn test_reminder_store_persists_across_reload() { + let dir = std::env::temp_dir().join(format!("remind_test2_{}", std::process::id())); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("reminders.json"); + + let store = ReminderStore::load(path.clone()); + let r = Reminder { + id: "persist-1".into(), + channel_id: 100, + sender_id: 200, + targets: vec!["<@300>".into()], + message: "persist test".into(), + fire_at: Utc::now() + chrono::Duration::hours(2), + created_at: Utc::now(), + }; + store.add(r).await; + + // Reload from disk + let store2 = ReminderStore::load(path.clone()); + let pending = store2.pending().await; + assert_eq!(pending.len(), 1); + assert_eq!(pending[0].id, "persist-1"); + assert_eq!(pending[0].message, "persist test"); + + std::fs::remove_dir_all(&dir).ok(); + } + + #[test] + fn test_sanitize_message_strips_everyone_here() { + assert_eq!(sanitize_message("hello @everyone"), "hello @\u{200b}everyone"); + assert_eq!(sanitize_message("hey @here check"), "hey @\u{200b}here check"); + assert_eq!(sanitize_message("@everyone @here"), "@\u{200b}everyone @\u{200b}here"); + } + + #[test] + fn test_sanitize_message_no_change() { + assert_eq!(sanitize_message("normal message"), "normal message"); + assert_eq!(sanitize_message("<@123> hello"), "<@123> hello"); + } + + #[test] + fn test_validate_message_ok() { + assert!(validate_message("short message").is_ok()); + assert!(validate_message(&"a".repeat(1800)).is_ok()); + } + + #[test] + fn test_validate_message_too_long() { + assert!(validate_message(&"a".repeat(1801)).is_err()); + } + + #[test] + fn test_max_targets_constant() { + assert_eq!(MAX_TARGETS, 10); + } +}