From fc8ce9987cff997d8a8aa7e8d1265ebfd87a7428 Mon Sep 17 00:00:00 2001 From: zdevito Date: Wed, 19 Nov 2025 16:43:20 -0800 Subject: [PATCH] remove rust_local_mesh Differential Revision: [D87496994](https://our.internmc.facebook.com/intern/diff/D87496994/) [ghstack-poisoned] --- .github/workflows/test-gpu-rust.yml | 1 - Cargo.toml | 2 - controller/Cargo.toml | 43 - controller/build.rs | 25 - controller/src/bootstrap.rs | 656 ------ controller/src/history.rs | 842 ------- controller/src/lib.rs | 1944 ----------------- controller/src/main.rs | 88 - monarch_extension/Cargo.toml | 4 +- monarch_extension/src/lib.rs | 9 - monarch_extension/src/simulator_client.rs | 121 - monarch_simulator/Cargo.toml | 35 - monarch_simulator/src/bootstrap.rs | 161 -- .../src/collective_coordinator.rs | 205 -- monarch_simulator/src/controller.rs | 290 --- monarch_simulator/src/lib.rs | 36 - monarch_simulator/src/simulator.rs | 165 -- monarch_simulator/src/worker.rs | 941 -------- .../monarch_extension/simulator_client.pyi | 56 - python/monarch/_testing.py | 33 - python/monarch/common/constants.py | 3 - python/monarch/mesh_controller.py | 4 - python/monarch/rust_local_mesh.py | 1401 ------------ python/monarch/sim_mesh.py | 350 --- python/monarch/timer/README.md | 5 +- python/monarch/timer/example_monarch.py | 39 +- python/tests/test_controller.py | 835 ------- python/tests/test_fault_tolerance.py | 383 ---- python/tests/test_sim_backend.py | 51 - scripts/build_monarch_for_docs.sh | 3 +- scripts/generate_cargo_deps_graph.py | 343 +++ 31 files changed, 366 insertions(+), 8708 deletions(-) delete mode 100644 controller/Cargo.toml delete mode 100644 controller/build.rs delete mode 100644 controller/src/bootstrap.rs delete mode 100644 controller/src/history.rs delete mode 100644 controller/src/lib.rs delete mode 100644 controller/src/main.rs delete mode 100644 monarch_extension/src/simulator_client.rs delete mode 100644 monarch_simulator/Cargo.toml delete mode 100644 monarch_simulator/src/bootstrap.rs delete mode 100644 monarch_simulator/src/collective_coordinator.rs delete mode 100644 monarch_simulator/src/controller.rs delete mode 100644 monarch_simulator/src/lib.rs delete mode 100644 monarch_simulator/src/simulator.rs delete mode 100644 monarch_simulator/src/worker.rs delete mode 100644 python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi delete mode 100644 python/monarch/rust_local_mesh.py delete mode 100644 python/monarch/sim_mesh.py delete mode 100644 python/tests/test_controller.py delete mode 100644 python/tests/test_fault_tolerance.py delete mode 100644 python/tests/test_sim_backend.py create mode 100644 scripts/generate_cargo_deps_graph.py diff --git a/.github/workflows/test-gpu-rust.yml b/.github/workflows/test-gpu-rust.yml index bd6243ece..878e9167f 100644 --- a/.github/workflows/test-gpu-rust.yml +++ b/.github/workflows/test-gpu-rust.yml @@ -66,7 +66,6 @@ jobs: timeout 12m cargo nextest run --workspace --profile ci \ --exclude monarch_messages \ --exclude monarch_tensor_worker \ - --exclude monarch_simulator_lib \ --exclude torch-sys \ --exclude torch-sys-cuda # Copy the test results to the expected location diff --git a/Cargo.toml b/Cargo.toml index b037643a5..1b755b781 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ resolver = "2" members = [ "build_utils", - "controller", "cuda-sys", "erased_lifetime", "hyper", @@ -18,7 +17,6 @@ members = [ "monarch_messages", "monarch_perfetto_trace", "monarch_rdma", - "monarch_simulator", "monarch_tensor_worker", "monarch_types", "nccl-sys",
diff --git a/controller/Cargo.toml b/controller/Cargo.toml deleted file mode 100644 index 4b066da1c..000000000 --- a/controller/Cargo.toml +++ /dev/null @@ -1,43 +0,0 @@ -# @generated by autocargo from //monarch/controller:[controller,controller-bin] - -[package] -name = "controller" -version = "0.0.0" -authors = ["Meta"] -edition = "2021" -license = "BSD-3-Clause" - -[lib] -edition = "2024" - -[[bin]] -name = "controller_bin" -path = "src/main.rs" -edition = "2024" - -[dependencies] -anyhow = "1.0.98" -async-trait = "0.1.86" -bincode = "1.3.3" -clap = { version = "4.5.42", features = ["derive", "env", "string", "unicode", "wrap_help"] } -const_format = "0.2" -hyperactor = { version = "0.0.0", path = "../hyperactor" } -hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" } -hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" } -monarch_messages = { version = "0.0.0", path = "../monarch_messages" } -nccl-sys = { path = "../nccl-sys" } -ndslice = { version = "0.0.0", path = "../ndslice" } -pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods", "py-clone"] } -serde = { version = "1.0.219", features = ["derive", "rc"] } -serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "raw_value", "unbounded_depth"] } -tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] } -torch-sys = { path = "../torch-sys" } -tracing = { version = "0.1.41", features = ["attributes", "valuable"] } - -[dev-dependencies] -monarch_types = { version = "0.0.0", path = "../monarch_types" } -timed_test = { version = "0.0.0", path = "../timed_test" } -torch-sys = { version = "0.0.0", path = "../torch-sys" } - -[lints] -rust = { unexpected_cfgs = { check-cfg = ["cfg(fbcode_build)"], level = "warn" } } diff --git a/controller/build.rs b/controller/build.rs deleted file mode 100644 index 38d9aeaca..000000000 --- a/controller/build.rs +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -// This is needed due to the controller being built with torch/nccl deps due to monarch_messages. - -fn main() { - // `torch-sys` will set this env var through Cargo `links` metadata. - let lib_path = std::env::var("DEP_TORCH_LIB_PATH").expect("DEP_TORCH_LIB_PATH to be set"); - // Set the rpath so that the dynamic linker can find libtorch and friends. - println!("cargo::rustc-link-arg=-Wl,-rpath,{lib_path}"); - - if let Ok(path) = std::env::var("DEP_NCCL_LIB_PATH") { - println!("cargo::rustc-link-arg=-Wl,-rpath,{path}"); - } - - // Disable new dtags, as conda envs generally use `RPATH` over `RUNPATH`. - println!("cargo::rustc-link-arg=-Wl,--disable-new-dtags"); - - println!("cargo:rustc-link-lib=lzma"); -} diff --git a/controller/src/bootstrap.rs b/controller/src/bootstrap.rs deleted file mode 100644 index f737c5cff..000000000 --- a/controller/src/bootstrap.rs +++ /dev/null @@ -1,656 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -use std::collections::HashMap; -use std::collections::HashSet; -use std::time::Duration; - -use anyhow::Result; -use anyhow::anyhow; -use clap::Args; -use clap::Subcommand; -use const_format::concatcp; -use hyperactor::GangRef; -use hyperactor::actor::ActorHandle; -use hyperactor::channel::ChannelAddr; -use hyperactor::clock::Clock; -use hyperactor::clock::RealClock; -use hyperactor::context; -use hyperactor::mailbox::open_port; -use hyperactor::reference::ActorId; -use hyperactor::reference::ActorRef; -use hyperactor::reference::GangId; -use hyperactor::reference::Index; -use hyperactor::reference::WorldId; -use hyperactor_mesh::comm::CommActor; -use hyperactor_multiprocess::System; -use hyperactor_multiprocess::proc_actor::Environment; -use hyperactor_multiprocess::proc_actor::ProcActor; -use hyperactor_multiprocess::proc_actor::ProcMessageClient; -use hyperactor_multiprocess::system_actor; -use hyperactor_multiprocess::system_actor::ProcLifecycleMode; -use hyperactor_multiprocess::system_actor::Shape; -use hyperactor_multiprocess::system_actor::SystemMessageClient; -use monarch_messages::worker::WorkerParams; -use pyo3::prelude::*; -use pyo3::types::PyType; -use serde::Deserialize; -use serde::Serialize; -use tokio::task::JoinHandle; - -use crate::ControllerActor; -use crate::ControllerParams; - -/// Domain name for all monarch reserved labels. -pub static MONARCH_LABEL_PREFIX: &str = "monarch.meta.com/"; -/// Prefix for all monarch reserved labels for procs. -static WORKER_LABEL_PREFIX: &str = concatcp!("proc.", MONARCH_LABEL_PREFIX); -/// Labels suffix indicating the role of a proc. -static LABEL_NAME_ROLE: &str = concatcp!(WORKER_LABEL_PREFIX, "role"); -/// Label value indicating proc role is controller. -static LABEL_VALUE_ROLE_CONTROLLER: &str = "controller"; -/// Label value indicating proc role is host. -static LABEL_VALUE_ROLE_HOST: &str = "host"; -/// Label indicating the worker world for a given controller to allow -/// for backreferencing. -static LABEL_NAME_WORKER_WORLD: &str = concatcp!(WORKER_LABEL_PREFIX, "workerWorld"); -/// The global name used for comm actors. -static COMM_ACTOR_NAME: &str = "comm"; - -/// Prefix for all monarch reserved labels for worlds. -pub static WORLD_LABEL_PREFIX: &str = concatcp!("world.", MONARCH_LABEL_PREFIX); -/// Label indicating if a given world is a worker world. A value of "1" indicates -/// a worker world. This allows us to query all worker worlds in the system. -static LABEL_NAME_WORKER: &str = concatcp!(WORLD_LABEL_PREFIX, "worker"); -/// Label indicating the controller actor id for a given worker world. This allows -/// to query all worker worlds and communcicate with their controllers. -static LABEL_NAME_CONTROLLER_ACTOR_ID: &str = concatcp!(WORLD_LABEL_PREFIX, "controllerActorId"); - -#[derive(Clone, Debug, Serialize, Deserialize, Args)] -#[pyclass(module = "monarch._rust_bindings.controller.bootstrap")] -pub struct ControllerCommand { - /// The worker world to create - #[arg(long)] - pub worker_world: String, - - /// The system address to bootstrap with. - #[arg(long)] - pub system_addr: String, - - /// The controller actor id to give to. - #[arg(long, default_value_t = String::from("controller[0].root"))] - pub controller_actor_id: String, - - // Global world size for this job - #[arg(long)] - pub world_size: usize, - - /// The number of processes per host. - #[arg(long, default_value_t = 8)] - pub num_procs_per_host: usize, - - /// The worker name. 
- #[arg(long, default_value_t = String::from("worker"))] - pub worker_name: String, - - /// The worker program to execute for each process. It is not needed if worker procs - /// are directly launched without management from host actors. - #[arg(long)] - pub program: Option, - - /// The supervision check interval in seconds. It indicates how often the controller - /// will poll system actor to check the status of all procs/actors in a world. This - /// decides how fast the client could observe a failure in the system. - #[arg(long, default_value_t = 2)] - pub supervision_query_interval_in_sec: u64, - - /// The supervision update interval in seconds, it indiciates how often the controller - /// proc should report its supervision status to the system. - #[arg(long, default_value_t = 2)] - pub supervision_update_interval_in_sec: u64, - - /// The worker progress check interval in seconds, it indicates how often the controller - /// will check that progress is being made. - #[arg(long, default_value_t = 10)] - pub worker_progress_check_interval_in_sec: u64, - - /// The operation timeout duration interval in seconds, it indicates how long we will allow - /// progress to stall for before letting the client know that worker(s) may be stuck. - #[arg(long, default_value_t = 120)] - pub operation_timeout_in_sec: u64, - - /// The number of operations invoked before we proactively check worker progress. If a large number - /// of operations are invoked all at once, it is expected that it will take a while for all operations - /// to complete so we want to inject progress requests at a higher frequency to check if we are making progress - #[arg(long, default_value_t = 100)] - pub operations_per_worker_progress_request: u64, - - /// If the controller should propagate a failure to the client if the workers become stuck. - #[arg(long, default_value_t = false)] - pub fail_on_worker_timeout: bool, - - /// If to launch the workers for CPU-only devices. - #[arg(long, default_value_t = false)] - pub is_cpu_worker: bool, - - /// Proc metadata which will be available through system. - #[arg(long, value_parser=parse_key_val)] - pub extra_proc_labels: Option>, -} - -#[pymethods] -impl ControllerCommand { - #[new] - #[pyo3(signature = (*, worker_world, system_addr, controller_actor_id, world_size, num_procs_per_host, worker_name, program, supervision_query_interval_in_sec, supervision_update_interval_in_sec, worker_progress_check_interval_in_sec, operation_timeout_in_sec, operations_per_worker_progress_request, fail_on_worker_timeout, is_cpu_worker, extra_proc_labels))] - fn new( - worker_world: String, - system_addr: String, - controller_actor_id: String, - world_size: usize, - num_procs_per_host: usize, - worker_name: String, - program: Option, - supervision_query_interval_in_sec: u64, - supervision_update_interval_in_sec: u64, - worker_progress_check_interval_in_sec: u64, - operation_timeout_in_sec: u64, - operations_per_worker_progress_request: u64, - fail_on_worker_timeout: bool, - is_cpu_worker: bool, - extra_proc_labels: Option>, - ) -> Self { - Self { - worker_world, - system_addr, - controller_actor_id, - world_size, - num_procs_per_host, - worker_name, - program, - supervision_query_interval_in_sec, - supervision_update_interval_in_sec, - worker_progress_check_interval_in_sec, - operation_timeout_in_sec, - operations_per_worker_progress_request, - fail_on_worker_timeout, - is_cpu_worker, - extra_proc_labels, - } - } -} - -/// The different types of hyperactor to launch based on the subcommands. 
-/// The ones for System / Host should probably be moved to the hyperactor -/// multiprocess crate. -#[derive(Clone, Debug, Serialize, Deserialize, Subcommand)] -#[pyclass(module = "monarch._rust_bindings.controller.bootstrap")] -pub enum RunCommand { - System { - /// The system address to bootstrap with. - #[arg(long)] - system_addr: String, - - /// The supervision update timeout in seconds. A proc is considered dead if system - /// doesn't get any supervision update from it within this timeout. - #[arg(long, default_value_t = 20)] - supervision_update_timeout_in_sec: u64, - - /// Evict a world if it has been unhealthy for this many seconds. - #[arg(long, default_value_t = 10)] - world_eviction_timeout_in_sec: u64, - }, - - Host { - /// The system address to bootstrap with. - #[arg(long)] - system_addr: String, - - /// The host world to create. - #[arg(long)] - host_world: String, - - /// The host rank; i.e., the index of the host in the world. - #[arg(long)] - host_rank: Index, - - /// The supervision update interval in seconds, it indiciates how often a proc should - /// report its supervision status to the system. - #[arg(long, default_value_t = 2)] - supervision_update_interval_in_sec: u64, - }, - - Controller(ControllerCommand), -} - -#[pyclass(frozen, module = "monarch._rust_bindings.controller.bootstrap")] -#[derive(Debug, Serialize, Deserialize)] -pub enum ControllerServerRequest { - Run(RunCommand), - Exit(), -} - -#[pymethods] -impl ControllerServerRequest { - fn to_json(&self) -> PyResult { - Ok(serde_json::to_string(self).map_err(|e| anyhow!(e))?) - } - - fn __str__(&self) -> String { - format!("{:?}", self) - } -} - -#[pyclass(frozen, module = "monarch._rust_bindings.controller.bootstrap")] -#[derive(Debug, Serialize, Deserialize)] -pub enum ControllerServerResponse { - Finished { error: Option }, -} - -#[pymethods] -impl ControllerServerResponse { - #[classmethod] - fn from_json(_: &Bound<'_, PyType>, json: &str) -> PyResult { - Ok(serde_json::from_str(json).map_err(|e| anyhow!(e))?) - } - - fn __str__(&self) -> String { - format!("{:?}", self) - } -} - -/// A helper function to launch the system, host, or controller actors. -/// Returns the handle to be waited on. 
-pub fn run(command: RunCommand) -> Result>> { - Ok(match command { - RunCommand::System { - system_addr, - supervision_update_timeout_in_sec, - world_eviction_timeout_in_sec, - } => tokio::spawn(spawn_system( - system_addr.parse()?, - Duration::from_secs(supervision_update_timeout_in_sec), - Duration::from_secs(world_eviction_timeout_in_sec), - )), - RunCommand::Host { - system_addr, - host_world, - host_rank, - supervision_update_interval_in_sec, - } => tokio::spawn(spawn_host( - system_addr.parse()?, - host_world.parse()?, - host_rank, - Duration::from_secs(supervision_update_interval_in_sec), - )), - RunCommand::Controller(ControllerCommand { - worker_world, - system_addr, - controller_actor_id, - world_size, - num_procs_per_host, - worker_name, - program, - supervision_query_interval_in_sec, - supervision_update_interval_in_sec, - worker_progress_check_interval_in_sec, - operation_timeout_in_sec, - operations_per_worker_progress_request, - is_cpu_worker, - extra_proc_labels, - fail_on_worker_timeout, - }) => tokio::spawn(spawn_controller( - system_addr.parse()?, - controller_actor_id.parse()?, - world_size, - num_procs_per_host, - worker_world.parse()?, - worker_name, - program, - Duration::from_secs(supervision_query_interval_in_sec), - Duration::from_secs(supervision_update_interval_in_sec), - Duration::from_secs(worker_progress_check_interval_in_sec), - Duration::from_secs(operation_timeout_in_sec), - operations_per_worker_progress_request, - is_cpu_worker, - extra_proc_labels, - fail_on_worker_timeout, - )), - }) -} - -/// Spawn the system actor -async fn spawn_system( - system_addr: ChannelAddr, - supervision_update_timeout: Duration, - world_eviction_timeout: Duration, -) -> anyhow::Result<()> { - tracing::info!("spawning system"); - - let handle = System::serve( - system_addr.clone(), - supervision_update_timeout, - world_eviction_timeout, - ) - .await?; - tracing::info!("system serve: {}", handle.local_addr()); - - // This will not end until the system actor is stopped. - handle.system_actor_handle().clone().await; - - tracing::info!("system actor exited"); - - Ok(()) -} - -/// Spawn the host actor -#[tracing::instrument(skip_all)] -async fn spawn_host( - system_addr: ChannelAddr, - host_world_id: WorldId, - host_rank: Index, - supervision_update_interval: Duration, -) -> anyhow::Result<()> { - tracing::info!("spawning host actor"); - - let proc_id = host_world_id.proc_id(host_rank); - let host_addr = ChannelAddr::any(system_addr.transport()); - - let bootstrap = ProcActor::bootstrap( - proc_id.clone(), - host_world_id.clone(), - host_addr, - system_addr, - supervision_update_interval, - HashMap::from([( - LABEL_NAME_ROLE.to_string(), - LABEL_VALUE_ROLE_HOST.to_string(), - )]), - ProcLifecycleMode::ManagedBySystem, - ) - .await?; - tracing::info!( - "{}: joined; host actor: {}", - proc_id, - bootstrap.proc_actor.actor_id() - ); - - // This will not end until the proc actor is stopped. - bootstrap.proc_actor.await; - - Ok(()) -} - -/// Spawn the controller actor. The order of bootstrap is: -/// 1. Create the new worker world. -/// 2. Check if the worker world is alive -/// 3. Spawn the controller proc and actor. -/// 4. Spawn all the worker actors and wait for them to be ready. -/// 5. Create the new controller world. The client is able to send traffic -/// only after both the controller and worker worlds are alive. 
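// The five numbered steps above map onto helpers defined later in this file; a rough
// sketch of the orchestration inside spawn_controller (argument lists elided) would be:
//
//     self::create_world(&instance, ...).await?;           // steps 1-2: create the worker world, wait until live
//     self::bootstrap_controller(system_addr, ...).await?; // step 3: controller proc + actor
//     self::spawn_worker_actors(&instance, ...).await?;    // step 4: spawn workers, wait for readiness
//     system_actor::SYSTEM_ACTOR_REF
//         .upsert_world(&instance, ...).await?;             // step 5: announce the controller world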
-#[tracing::instrument(skip_all)] -async fn spawn_controller( - system_addr: ChannelAddr, - controller_actor_id: ActorId, - num_procs: usize, - num_procs_per_host: usize, - worker_world_id: WorldId, - worker_name: String, - program: Option, - supervision_query_interval: Duration, - supervision_update_interval: Duration, - worker_progress_check_interval: Duration, - operation_timeout: Duration, - operations_per_worker_progress_request: u64, - is_cpu_worker: bool, - extra_proc_labels: Option>, - fail_on_worker_timeout: bool, -) -> anyhow::Result<()> { - tracing::info!("spawning controller"); - - let mut system = hyperactor_multiprocess::System::new(system_addr.clone()); - let instance = system.attach().await.unwrap(); - - self::create_world( - &instance, - controller_actor_id.clone(), - num_procs, - num_procs_per_host, - worker_world_id.clone(), - program, - ) - .await?; - let handle = self::bootstrap_controller( - system_addr, - None, // listen_addr - controller_actor_id.clone(), - num_procs, - worker_world_id.clone(), - worker_name.clone(), - supervision_query_interval, - supervision_update_interval, - worker_progress_check_interval, - operation_timeout, - operations_per_worker_progress_request, - extra_proc_labels, - fail_on_worker_timeout, - ) - .await?; - - self::spawn_worker_actors( - &instance, - controller_actor_id.clone(), - num_procs, - worker_world_id, - worker_name, - is_cpu_worker, - ) - .await?; - - // Controller will join its own world. - // This will announce itself as live so the client can observe it. - system_actor::SYSTEM_ACTOR_REF - .upsert_world( - &instance, - WorldId(controller_actor_id.world_name().to_string()), - Shape::Definite(vec![1]), - 1, - Environment::Local, - HashMap::new(), - ) - .await?; - tracing::info!( - "created new controller world {}", - controller_actor_id.world_name() - ); - - // This will not end until the system actor is stopped. - handle.await; - - tracing::info!("controller actor exited"); - - Ok(()) -} - -/// Bootstraps the controller actor. -/// Listen address is optional. If not provided, it will be assigned with a random available -/// address that has the same transport as the system address. 
-pub async fn bootstrap_controller( - system_addr: ChannelAddr, - listen_addr: Option, - controller_actor_id: ActorId, - num_procs: usize, - worker_world_id: WorldId, - worker_name: String, - supervision_query_interval: Duration, - supervision_update_interval: Duration, - worker_progress_check_interval: Duration, - operation_timeout: Duration, - operations_per_worker_progress_request: u64, - extra_controller_labels: Option>, - fail_on_worker_timeout: bool, -) -> anyhow::Result> { - let listen_addr = listen_addr.unwrap_or(ChannelAddr::any(system_addr.transport())); - let mut controller_labels = HashMap::from([ - ( - LABEL_NAME_ROLE.to_string(), - LABEL_VALUE_ROLE_CONTROLLER.to_string(), - ), - ( - LABEL_NAME_WORKER_WORLD.to_string(), - worker_world_id.to_string(), - ), - ]); - tracing::info!("controller labels: {:?}", extra_controller_labels); - if let Some(extra_controller_labels) = extra_controller_labels { - controller_labels.extend(extra_controller_labels); - } - let (handle, actor_ref) = ControllerActor::bootstrap( - controller_actor_id.clone(), - listen_addr, - system_addr, - ControllerParams { - world_size: num_procs, - comm_actor_ref: ActorRef::::attest( - controller_actor_id.proc_id().actor_id(COMM_ACTOR_NAME, 0), - ), - worker_gang_ref: GangRef::attest(GangId(worker_world_id.clone(), worker_name.clone())), - supervision_query_interval, - worker_progress_check_interval, - operation_timeout, - operations_per_worker_progress_request, - fail_on_worker_timeout, - }, - supervision_update_interval, - controller_labels, - ) - .await?; - tracing::info!("controller starts with id: {}", actor_ref.actor_id()); - - Ok(handle) -} - -async fn create_world( - cx: &impl context::Actor, - controller_actor_id: ActorId, - num_procs: usize, - num_procs_per_host: usize, - worker_world_id: WorldId, - program: Option, -) -> anyhow::Result<()> { - system_actor::SYSTEM_ACTOR_REF - .upsert_world( - cx, - worker_world_id.clone(), - Shape::Definite(vec![num_procs]), - num_procs_per_host, - match program { - Some(program) => Environment::Exec { program }, - None => Environment::Local, - }, - HashMap::from([ - (LABEL_NAME_WORKER.to_string(), "1".to_string()), - ( - LABEL_NAME_CONTROLLER_ACTOR_ID.to_string(), - controller_actor_id.to_string(), - ), - ]), - ) - .await?; - tracing::info!("created new worker world {}", worker_world_id); - - // Wait for all the worker procs to join the worker world. - let timeout = hyperactor::config::global::get(hyperactor::config::MESSAGE_DELIVERY_TIMEOUT); - tracing::info!("waiting for worker world {} to be alive", worker_world_id); - loop { - let snapshot = RealClock - .timeout(timeout, async { - system_actor::SYSTEM_ACTOR_REF - .snapshot( - cx, - system_actor::SystemSnapshotFilter { - worlds: vec![worker_world_id.clone()], - world_labels: HashMap::new(), - proc_labels: HashMap::new(), - }, - ) - .await - }) - .await?; - let snapshot = snapshot?; - if let Some(world) = snapshot.worlds.get(&worker_world_id) { - if world.status.is_live() { - break; - } - } - RealClock.sleep(Duration::from_millis(10)).await; - } - tracing::info!( - "worker world {} is alive; spawning {} worker actors", - worker_world_id, - num_procs - ); - Ok(()) -} - -async fn spawn_worker_actors( - cx: &impl context::Actor, - controller_actor_id: ActorId, - num_procs: usize, - worker_world_id: WorldId, - worker_name: String, - is_cpu_worker: bool, -) -> anyhow::Result<()> { - // Bootstrap worker actors and wait for them to be ready. 
- let (spawned_port, mut spawned_receiver) = open_port(cx); - for rank in 0..num_procs { - let param = WorkerParams { - world_size: num_procs, - // Rank assignment is consistent with proc indices. - rank, - // TODO: We never use device index during Monarch bootstrap. - // Instead, CUDA_VISIBLE_DEVICES is used for workers to access CUDA devices. - device_index: if is_cpu_worker { None } else { Some(0) }, - controller_actor: ActorRef::attest(controller_actor_id.clone()), - }; - let worker_proc = - ActorRef::::attest(worker_world_id.proc_id(rank).actor_id("proc", 0)); - - worker_proc - .spawn( - cx, - // Use explicit actor type to avoid the WorkActor dependency. - "monarch_tensor_worker::WorkerActor".to_owned(), - worker_name.clone(), - bincode::serialize(¶m)?, - spawned_port.bind(), - ) - .await?; - } - let mut spawned = HashSet::new(); - while spawned.len() < num_procs { - spawned.insert(spawned_receiver.recv().await?); - } - tracing::info!("spawned {} worker actors", num_procs); - - Ok(()) -} - -pub fn parse_key_val(s: &str) -> anyhow::Result<(String, String)> { - match s.split_once('=') { - None => Err(anyhow::anyhow!("invalid KEY=value: no `=` found in `{s}`")), - Some((a, b)) => Ok((a.to_owned(), b.to_owned())), - } -} - -pub fn register_python_bindings(controller_mod: &Bound<'_, PyModule>) -> PyResult<()> { - controller_mod.add_class::()?; - controller_mod.add_class::()?; - controller_mod.add_class::()?; - controller_mod.add_class::()?; - Ok(()) -} diff --git a/controller/src/history.rs b/controller/src/history.rs deleted file mode 100644 index 4405869ac..000000000 --- a/controller/src/history.rs +++ /dev/null @@ -1,842 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use std::cmp::Ordering; -use std::collections::BTreeMap; -use std::collections::HashMap; -use std::collections::HashSet; - -use hyperactor::clock::Clock; -use hyperactor::data::Serialized; -use monarch_messages::client::Exception; -use monarch_messages::controller::Seq; -use monarch_messages::controller::WorkerError; -use monarch_messages::worker::Ref; - -/// An invocation tracks a discrete node in the graph of operations executed by -/// the worker based on instructions from the client. -/// It is useful for tracking the dependencies of an operation and propagating -/// failures. In the future this will be used with more data dependency tracking -/// to support better failure handling. -// Allowing dead code until we do something smarter with defs, uses etc. -#[allow(dead_code)] -#[derive(Debug)] -struct Invocation { - /// The sequence number of the invocation. This should be unique and increasing across all - /// invocations. - seq: Seq, - /// The references that this invocation defines or redefines. Effectively the - /// output of the invocation. - defs: Vec, - /// The references that this invocation uses. Effectively the input of the invocation. - uses: Vec, - /// The result of the invocation. This is set when the invocation is completed or - /// when a failure is inferred. A successful result will always supersede any failure. - result: Option>, - /// The seqs for the invocations that depend on this invocation. Useful for propagating failures. - users: HashSet, - /// If we have reported the result eagerly, we want to make sure to not double report. 
This also - /// lets us know when we can stop traversing when finding unreported results - reported: bool, -} - -impl Invocation { - fn new(seq: Seq, uses: Vec, defs: Vec) -> Self { - Self { - seq, - uses, - defs, - result: None, - users: HashSet::new(), - reported: false, - } - } - - fn add_user(&mut self, user: Seq) { - self.users.insert(user); - } - - /// Invocation results can only go from valid to failed, or be - /// set if the invocation result is empty. - fn set_result(&mut self, result: Result) { - if self.result.is_none() || matches!((&self.result, &result), (Some(Ok(_)), Err(_))) { - self.result = Some(result); - } - } - - fn set_exception(&mut self, exception: Exception) { - match exception { - Exception::Error(_, caused_by, error) => { - self.set_result(Err(Exception::Error(self.seq, caused_by, error))); - } - Exception::Failure(_) => { - tracing::error!( - "system failures {:?} can never be assigned for an invocation", - exception - ); - } - } - } - - fn exception(&self) -> Option<&Exception> { - self.result - .as_ref() - .map(Result::as_ref) - .and_then(Result::err) - } - - #[allow(dead_code)] - fn value(&self) -> Option<&Serialized> { - self.result - .as_ref() - .map(Result::as_ref) - .and_then(Result::ok) - } -} - -#[derive(Debug, PartialEq)] -enum RefStatus { - // The invocation for this ref is still in progress. - Invoked(Seq), - // The invocation for this ref has errored. - Errored(Exception), -} - -/// The history of invocations sent by the client to be executed on the workers. -/// This is used to track dependencies between invocations and to propagate exceptions. -/// It purges history for completed invocations to avoid memory bloat. -/// TODO: Revisit this setup around purging refs automatically once we start doing -/// more complex data dependency tracking. We will want to be more aware of things like -/// borrows, drops etc. directly. -#[derive(Debug)] -#[allow(dead_code)] -pub struct History { - /// The first incomplete Seq for each rank. This is used to determine which - /// Seqs are no longer relevant and can be purged from the history. - first_incomplete_seqs: MinVector, - /// The minimum incomplete Seq across all ranks. - min_incomplete_seq: Seq, - /// A map of seq to the invocation that it represents. - invocations: HashMap, - /// A map of reference to the seq for the invocation that defines it. This is used to - /// compute dependencies between invocations. - invocation_for_ref: HashMap, - // Refs to be deleted in mark_worker_complete_and_propagate_failures - marked_for_deletion: HashSet, - // Last seq to be invoked - max_seq: OptionSeq, - // The first incompleted Seq for each rank derived from both client and controller request_status messages - // This is needed because the client may retain invocations past the time completed such as in call_fetch_shard().result() - first_incomplete_seqs_controller: MinVector, - // Memoized minimum incompleted Seq across all ranks of first_incomplete_seqs_controller - min_incompleted_seq_controller: Seq, - // The deadline for the next expected completed seq. This is updated only when the previous deadline - // has been met. - // - // Tuple fields are: - // - the seq we expect to be completed - // - the deadline - // - if it has already been reported to the client - deadline: Option<(Seq, tokio::time::Instant, bool)>, -} - -/// A vector that keeps track of the minimum value. 
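// A minimal usage sketch (the same behavior the `min_vector` unit test below exercises):
// MinVector pairs the raw data with a BTreeMap of value -> occurrence count, so min()
// reads the first key of that map and set() stays O(log n).
//
//     let mut v = MinVector::new(vec![3, 1, 4]);
//     assert_eq!(v.min(), 1);
//     v.set(1, 7);              // old value 1 drops to a count of zero and is removed
//     assert_eq!(v.min(), 3);   // the minimum is recomputed from the BTreeMap keys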
-#[derive(Debug)] -struct MinVector { - data: Vec, - value_counts: BTreeMap, -} - -impl MinVector -where - T: Ord + Copy, -{ - fn new(data: Vec) -> Self { - let mut value_counts = BTreeMap::new(); - for &value in &data { - *value_counts.entry(value).or_insert(0) += 1; - } - MinVector { data, value_counts } - } - - fn set(&mut self, index: usize, value: T) { - // Decrease the count of the old value - let old_value = self.data[index]; - if let Some(count) = self.value_counts.get_mut(&old_value) { - *count -= 1; - if *count == 0 { - self.value_counts.remove(&old_value); - } - } - // Update the value in the vector - self.data[index] = value; - - // Increase the count of the new value - *self.value_counts.entry(value).or_insert(0) += 1; - } - - fn get(&self, index: usize) -> T { - self.data[index] - } - - fn min(&self) -> T { - *self.value_counts.keys().next().unwrap() - } - - fn len(&self) -> usize { - self.data.len() - } - - fn vec(&self) -> &Vec { - &self.data - } -} - -impl History { - pub fn new(world_size: usize) -> Self { - Self { - first_incomplete_seqs: MinVector::new(vec![Seq::default(); world_size]), - min_incomplete_seq: Seq::default(), - invocation_for_ref: HashMap::new(), - invocations: HashMap::new(), - marked_for_deletion: HashSet::new(), - max_seq: OptionSeq::from(None), - first_incomplete_seqs_controller: MinVector::new(vec![Seq::default(); world_size]), - min_incompleted_seq_controller: Seq::default(), - deadline: None, - } - } - - #[cfg(test)] - pub fn first_incomplete_seqs(&self) -> &[Seq] { - self.first_incomplete_seqs.vec() - } - - pub fn first_incomplete_seqs_controller(&self) -> &[Seq] { - self.first_incomplete_seqs_controller.vec() - } - - pub fn min_incomplete_seq_reported(&self) -> Seq { - self.min_incompleted_seq_controller - } - - pub fn world_size(&self) -> usize { - self.first_incomplete_seqs.len() - } - - pub fn delete_invocations_for_refs(&mut self, refs: Vec) { - self.marked_for_deletion.extend(refs); - - self.marked_for_deletion - .retain(|ref_| match self.invocation_for_ref.get(ref_) { - Some(RefStatus::Invoked(seq)) => { - if seq < &self.min_incomplete_seq { - self.invocation_for_ref.remove(ref_); - false - } else { - true - } - } - Some(RefStatus::Errored(_)) => { - self.invocation_for_ref.remove(ref_); - false - } - None => true, - }); - } - - /// Add an invocation to the history. - pub fn add_invocation( - &mut self, - seq: Seq, - uses: Vec, - defs: Vec, - ) -> Vec<(Seq, Option>)> { - let mut results = Vec::new(); - let input_seq = OptionSeq::from(seq); - assert!( - input_seq >= self.max_seq, - "nonmonotonic seq: {:?}; current max: {:?}", - seq, - self.max_seq, - ); - self.max_seq = input_seq; - let mut invocation = Invocation::new(seq, uses.clone(), defs.clone()); - - for use_ in uses { - // The invocation for every use_ should add this seq as a user. - match self.invocation_for_ref.get(&use_) { - Some(RefStatus::Errored(exception)) => { - // We know that this invocation hasn't been completed yet, so we can - // directly call set_exception on it. 
- if !invocation.reported { - invocation.set_exception(exception.clone()); - results.push((seq, Some(Err(exception.clone())))); - invocation.reported = true; - } - } - Some(RefStatus::Invoked(invoked_seq)) => { - if let Some(invocation) = self.invocations.get_mut(invoked_seq) { - invocation.add_user(seq) - } - } - None => tracing::debug!( - "ignoring dependency on potentially complete invocation for ref: {:?}", - use_ - ), - } - } - for def in defs { - self.invocation_for_ref.insert( - def, - match invocation.exception() { - Some(err) => RefStatus::Errored(err.clone()), - None => RefStatus::Invoked(seq.clone()), - }, - ); - } - - self.invocations.insert(seq, invocation); - - results - } - - /// Propagate worker error to the invocation with the given Seq. This will also propagate - /// to all seqs that depend on this seq directly or indirectly. - pub fn propagate_exception(&mut self, seq: Seq, exception: Exception) { - let mut queue = vec![seq]; - let mut visited = HashSet::new(); - - while let Some(seq) = queue.pop() { - if !visited.insert(seq) { - continue; - } - - let Some(invocation) = self.invocations.get_mut(&seq) else { - continue; - }; - - // Overwrite the error, so we are using the last error for this invocation to send - // to the client. - for def in invocation.defs.iter() { - match self.invocation_for_ref.get(def) { - Some(RefStatus::Invoked(invoked_seq)) if *invoked_seq == seq => self - .invocation_for_ref - .insert(*def, RefStatus::Errored(exception.clone())), - _ => None, - }; - } - invocation.set_exception(exception.clone()); - queue.extend(invocation.users.iter()); - } - } - - fn find_unreported_dependent_exceptions( - &mut self, - seq: Seq, - ) -> Vec<(Seq, Option>)> { - let mut queue = vec![seq]; - let mut visited = HashSet::new(); - let mut results = Vec::new(); - - while let Some(seq) = queue.pop() { - if !visited.insert(seq) { - continue; - } - - let Some(invocation) = self.invocations.get_mut(&seq) else { - continue; - }; - - if !matches!(invocation.result, Some(Err(_))) || invocation.reported { - continue; - } - - invocation.reported = true; - - results.push((seq, invocation.result.clone())); - - queue.extend(invocation.users.iter()); - } - results - } - - pub fn report_deadline_missed(&mut self) { - if let Some((seq, time, _)) = self.deadline { - self.deadline = Some((seq, time, true)); - } - } - - pub fn deadline( - &mut self, - expected_progress: u64, - timeout: tokio::time::Duration, - clock: &impl Clock, - ) -> Option<(Seq, tokio::time::Instant, bool)> { - let previous_deadline_completed = match self.deadline { - Some((expected_seq, ..)) => self.min_incompleted_seq_controller > expected_seq, - None => self.max_seq.inner().is_some(), - }; - - if previous_deadline_completed { - let next_expected_completed_seq = std::cmp::min( - OptionSeq::from(u64::from(self.min_incompleted_seq_controller) + expected_progress), - self.max_seq.clone(), - ); - - self.deadline = - next_expected_completed_seq - .into_inner() - .map(|next_expected_completed_seq| { - (next_expected_completed_seq, clock.now() + timeout, false) - }); - } - self.deadline - } - - pub fn update_deadline_tracking(&mut self, rank: usize, seq: Seq) { - // rank_completed also calls this so that we stay up to date with client request_status messages. 
- // However, controller request_status messages may be ahead of the client as the client may retain invocations - // past the time completed so we should take the max - self.first_incomplete_seqs_controller.set( - rank, - std::cmp::max(seq, self.first_incomplete_seqs.get(rank)), - ); - - self.min_incompleted_seq_controller = self.first_incomplete_seqs_controller.min(); - } - - /// Mark the given rank as completed up to but excluding the given Seq. This will also purge history for - /// any Seqs that are no longer relevant (completed on all ranks). - pub fn rank_completed( - &mut self, - rank: usize, - seq: Seq, - ) -> Vec<(Seq, Option>)> { - self.first_incomplete_seqs.set(rank, seq); - let prev = self.min_incomplete_seq; - self.min_incomplete_seq = self.first_incomplete_seqs.min(); - self.update_deadline_tracking(rank, seq); - - let mut results: Vec<(Seq, Option>)> = Vec::new(); - for i in Seq::iter_between(prev, self.min_incomplete_seq) { - if let Some(invocation) = self.invocations.remove(&i) { - let retain = if let Some(result) = invocation.result { - let is_err = result.is_err(); - if !invocation.reported { - results.push((i, Some(result))); - } - is_err - } else { - // Do not retain successful invocations. - results.push((i, None)); - false - }; - - if retain { - // Retain the def history because we may need it to propagate - // errors in the future. We rely here on the fact that the invocation - // above has been marked as failed by way of failure propagation. - for def in &invocation.defs { - match self.invocation_for_ref.get(def) { - Some(RefStatus::Invoked(seq)) if *seq == i => { - self.invocation_for_ref.remove(def) - } - _ => None, - }; - } - } - } - } - - // Propagate results to the client even if it is behind the completion frontier - // if we can determine for sure that it is completed - results.extend(self.find_unreported_dependent_exceptions(seq)); - - results - } - - #[cfg(test)] - fn get_invocation(&self, seq: Seq) -> Option<&Invocation> { - self.invocations.get(&seq) - } - - pub fn set_result(&mut self, seq: Seq, result: Result) { - if let Some(invocation) = self.invocations.get_mut(&seq) { - invocation.set_result(result.map_err(|e| Exception::Error(seq, seq, e))); - } - } -} - -/// Struct representing an optional `Seq`, where `None` is always considered the -/// smallest. This type is to make it easier to compare `Option` with `Seq` -/// or `Option`. -#[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) struct OptionSeq(Option); - -impl OptionSeq { - /// Return inner ref. - pub fn inner(&self) -> &Option { - &self.0 - } - - /// Return inner. 
- pub fn into_inner(self) -> Option { - self.0 - } -} - -impl PartialOrd for OptionSeq { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for OptionSeq { - fn cmp(&self, other: &Self) -> Ordering { - match (self.0, other.0) { - (Some(a), Some(b)) => a.cmp(&b), - (Some(_), None) => Ordering::Greater, - (None, Some(_)) => Ordering::Less, - (None, None) => Ordering::Equal, - } - } -} - -impl From for OptionSeq { - fn from(value: u64) -> Self { - OptionSeq(Some(Seq::from(value))) - } -} - -impl From for OptionSeq { - fn from(seq: Seq) -> Self { - OptionSeq(Some(seq)) - } -} - -impl From> for OptionSeq { - fn from(value: Option) -> Self { - OptionSeq(value) - } -} - -#[cfg(test)] -mod tests { - use std::assert_matches::assert_matches; - - use hyperactor::id; - - use super::*; - - struct InvocationUsersIterator<'a> { - history: &'a History, - stack: Vec<&'a Invocation>, - visited: HashSet, - } - - impl<'a> Iterator for InvocationUsersIterator<'a> { - type Item = &'a Invocation; - - fn next(&mut self) -> Option { - while let Some(invocation) = self.stack.pop() { - if !self.visited.insert(invocation.seq) { - continue; - } - self.stack.extend( - invocation - .users - .iter() - .filter_map(|seq| self.history.invocations.get(seq)), - ); - return Some(invocation); - } - None - } - } - - impl History { - /// Get an iterator of Seqs that are users of (or dependent on the completion of) the given Seq. - /// This is useful for propagating failures. This will return an empty iterator if the given Seq is - /// not in the history. So this should be called before the invocation is marked as completed for the - /// given rank. - /// The Seq passed to this function will also be included in the iterator. - pub(crate) fn iter_users_transitive(&self, seq: Seq) -> impl Iterator + '_ { - let invocations = self - .invocations - .get(&seq) - .map_or(Vec::default(), |invocation| vec![invocation]); - - InvocationUsersIterator { - history: self, - stack: invocations, - visited: HashSet::new(), - } - .map(|invocation| invocation.seq) - } - } - - #[test] - fn simple_history() { - let mut history = History::new(2); - history.add_invocation(0.into(), vec![], vec![Ref { id: 1 }, Ref { id: 2 }]); - history.add_invocation(1.into(), vec![Ref { id: 1 }], vec![Ref { id: 3 }]); - history.add_invocation(2.into(), vec![Ref { id: 3 }], vec![Ref { id: 4 }]); - history.add_invocation(3.into(), vec![Ref { id: 3 }], vec![Ref { id: 5 }]); - history.add_invocation(4.into(), vec![Ref { id: 3 }], vec![Ref { id: 6 }]); - history.add_invocation(5.into(), vec![Ref { id: 4 }], vec![Ref { id: 7 }]); - history.add_invocation(6.into(), vec![Ref { id: 4 }], vec![Ref { id: 8 }]); - - let mut res = history - .iter_users_transitive(1.into()) - .collect::>(); - res.sort(); - assert_eq!( - res, - vec![1.into(), 2.into(), 3.into(), 4.into(), 5.into(), 6.into()] - ); - - history.rank_completed(0, 2.into()); - let mut res = history - .iter_users_transitive(1.into()) - .collect::>(); - res.sort(); - assert_eq!( - res, - vec![1.into(), 2.into(), 3.into(), 4.into(), 5.into(), 6.into()] - ); - - history.rank_completed(1, 2.into()); - let res = history - .iter_users_transitive(1.into()) - .collect::>(); - assert_eq!(res, vec![]); - - // Test that we can still add invocations after all ranks have completed that seq. 
- history.add_invocation(7.into(), vec![Ref { id: 1 }], vec![]); - } - - #[test] - fn delete_errored_invocations() { - let mut history = History::new(1); - history.add_invocation(0.into(), vec![], vec![Ref { id: 1 }, Ref { id: 2 }]); - history.add_invocation(1.into(), vec![Ref { id: 1 }], vec![Ref { id: 3 }]); - history.propagate_exception( - 0.into(), - Exception::Error( - 0.into(), - 0.into(), - WorkerError { - backtrace: "worker error happened".to_string(), - worker_actor_id: id!(test[234].testactor[6]), - }, - ), - ); - history.delete_invocations_for_refs(vec![Ref { id: 1 }, Ref { id: 2 }]); - history.rank_completed(0, 1.into()); - assert_eq!(history.invocation_for_ref.len(), 1); - history.delete_invocations_for_refs(vec![Ref { id: 3 }]); - history.rank_completed(0, 2.into()); - assert!(history.invocation_for_ref.is_empty()); - } - - #[test] - fn redefinitions() { - let mut history = History::new(2); - history.add_invocation(0.into(), vec![], vec![Ref { id: 1 }, Ref { id: 2 }]); - history.add_invocation(1.into(), vec![Ref { id: 1 }], vec![Ref { id: 3 }]); - history.add_invocation(2.into(), vec![Ref { id: 3 }], vec![Ref { id: 4 }]); - - let mut res = history - .iter_users_transitive(1.into()) - .collect::>(); - res.sort(); - assert_eq!(res, vec![1.into(), 2.into()]); - - history.add_invocation(3.into(), vec![Ref { id: 3 }], vec![Ref { id: 3 }]); - history.add_invocation(4.into(), vec![Ref { id: 3 }], vec![Ref { id: 6 }]); - history.add_invocation(5.into(), vec![Ref { id: 4 }], vec![Ref { id: 7 }]); - history.add_invocation(6.into(), vec![Ref { id: 4 }], vec![Ref { id: 8 }]); - - history.rank_completed(0, 2.into()); - history.rank_completed(1, 2.into()); - - let res = history - .iter_users_transitive(3.into()) - .collect::>(); - assert_eq!(res, vec![3.into(), 4.into()]); - } - - #[test] - fn min_vector() { - // Test initialization - let data = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]; - let mut min_vector = MinVector::new(data.clone()); - - // Test length - assert_eq!(min_vector.len(), data.len()); - - // Test initial vector - assert_eq!(min_vector.vec(), &data); - - // Test initial minimum - assert_eq!(min_vector.min(), 1); - - // Test get method - for (i, &value) in data.iter().enumerate() { - assert_eq!(min_vector.get(i), value); - } - - // Test set method and min update - min_vector.set(0, 0); // Change first element to 0 - assert_eq!(min_vector.get(0), 0); - assert_eq!(min_vector.min(), 0); - min_vector.set(1, 7); // Change second element to 7 - assert_eq!(min_vector.get(1), 7); - assert_eq!(min_vector.min(), 0); - min_vector.set(0, 8); // Change first element to 8 - assert_eq!(min_vector.get(0), 8); - assert_eq!(min_vector.min(), 1); // Minimum should now be 1 - - // Test setting a value that already exists - min_vector.set(2, 5); // Change third element to 5 - assert_eq!(min_vector.get(2), 5); - assert_eq!(min_vector.min(), 1); - - // Test setting a value to the current minimum - min_vector.set(3, 0); // Change fourth element to 0 - assert_eq!(min_vector.get(3), 0); - assert_eq!(min_vector.min(), 0); - - // Test setting all elements to the same value - for i in 0..min_vector.len() { - min_vector.set(i, 10); - } - assert_eq!(min_vector.min(), 10); - assert_eq!(min_vector.vec(), &vec![10; min_vector.len()]); - } - - #[test] - fn failure_propagation() { - let mut history = History::new(2); - - history.add_invocation(0.into(), vec![], vec![Ref { id: 1 }, Ref { id: 2 }]); - history.add_invocation(1.into(), vec![Ref { id: 1 }], vec![Ref { id: 3 }]); - history.add_invocation( - 2.into(), - 
vec![Ref { id: 3 }], - vec![Ref { id: 4 }, Ref { id: 5 }], - ); - history.add_invocation(3.into(), vec![Ref { id: 2 }], vec![Ref { id: 6 }]); - history.add_invocation(4.into(), vec![Ref { id: 5 }], vec![Ref { id: 6 }]); - - // No error before propagation - for i in 1..=3 { - assert!( - history - .get_invocation(i.into()) - .unwrap() - .exception() - .is_none() - ); - } - - // Failure happened to invocation 1, invocations 2, 4 should be marked as failed because they - // depend on 1 directly or indirectly. - history.propagate_exception( - 1.into(), - Exception::Error( - 1.into(), - 1.into(), - WorkerError { - backtrace: "worker error happened".to_string(), - worker_actor_id: "test[234].testactor[6]".parse().unwrap(), - }, - ), - ); - - // Error should be set for all invocations that depend on the failed invocation - for i in [1, 2, 4] { - assert!( - history - .get_invocation(i.into()) - .unwrap() - .exception() - .is_some() - ); - } - - // Error should not be set for invocations that do not depend on the failed invocation - for i in [0, 3] { - assert!( - history - .get_invocation(i.into()) - .unwrap() - .exception() - .is_none() - ); - } - - // A failed but completed invocation should still lead to all its - // invocations being marked as failed even if they appear in the future. - - // Delete until 2. - history.rank_completed(0, 2.into()); - history.rank_completed(1, 2.into()); - - for i in [3, 4, 5, 6] { - assert_matches!( - history.invocation_for_ref.get(&i.into()), - Some(RefStatus::Errored(_)), - ); - // Invocation should start from 5, so i+2 - history.add_invocation((i + 2).into(), vec![Ref { id: i }], vec![Ref { id: 7 }]); - assert!( - history - .get_invocation((i + 2).into()) - .unwrap() - .exception() - .is_some() - ); - } - - // Test if you can fill a valid result on an errored invocation 2. - history.set_result( - 2.into(), - Ok(Serialized::serialize(&"2".to_string()).unwrap()), - ); - // check that seq 2 is still errored - assert!( - history - .get_invocation((2).into()) - .unwrap() - .exception() - .is_some() - ); - assert!( - history - .get_invocation((2).into()) - .unwrap() - .value() - .is_none() - ); - } - - #[test] - fn test_option_seq_comparision() { - assert_eq!(OptionSeq::from(None), OptionSeq::from(None)); - assert_eq!(OptionSeq::from(1), OptionSeq::from(Seq::from(1))); - assert_eq!(OptionSeq::from(1), OptionSeq::from(Some(Seq::from(1)))); - - assert!(OptionSeq::from(None) < OptionSeq::from(0)); - assert!(OptionSeq::from(0) < OptionSeq::from(1)); - - assert!(OptionSeq::from(0) > OptionSeq::from(None)); - assert!(OptionSeq::from(1) > OptionSeq::from(0)); - } -} diff --git a/controller/src/lib.rs b/controller/src/lib.rs deleted file mode 100644 index d74bae600..000000000 --- a/controller/src/lib.rs +++ /dev/null @@ -1,1944 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#![feature(assert_matches)] -// NOTE: Until https://github.com/PyO3/pyo3/pull/4674, `pyo3::pymethods` trigger -// and unsafe-op-in-unsafe-fn warnings. 
-#![allow(unsafe_op_in_unsafe_fn)] - -pub mod bootstrap; -pub mod history; - -use std::collections::HashMap; -use std::collections::HashSet; -use std::time::Duration; - -use async_trait::async_trait; -use hyperactor::Actor; -use hyperactor::ActorId; -use hyperactor::ActorRef; -use hyperactor::Context; -use hyperactor::GangId; -use hyperactor::GangRef; -use hyperactor::Handler; -use hyperactor::Named; -use hyperactor::actor::ActorHandle; -use hyperactor::actor::ActorStatus; -use hyperactor::channel::ChannelAddr; -use hyperactor::clock::Clock; -use hyperactor::context; -use hyperactor::data::Serialized; -use hyperactor_mesh::comm::CommActor; -use hyperactor_mesh::comm::CommActorMode; -use hyperactor_mesh::comm::multicast::CastMessage; -use hyperactor_mesh::comm::multicast::CastMessageEnvelope; -use hyperactor_mesh::comm::multicast::DestinationPort; -use hyperactor_mesh::comm::multicast::Uslice; -use hyperactor_mesh::reference::ActorMeshId; -use hyperactor_mesh::reference::ProcMeshId; -use hyperactor_multiprocess::proc_actor::ProcActor; -use hyperactor_multiprocess::proc_actor::spawn; -use hyperactor_multiprocess::supervision::WorldSupervisionMessageClient; -use hyperactor_multiprocess::supervision::WorldSupervisor; -use hyperactor_multiprocess::system_actor::ProcLifecycleMode; -use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF; -use monarch_messages::client::ClientActor; -use monarch_messages::client::ClientMessageClient; -use monarch_messages::client::Exception; -use monarch_messages::client::LogLevel; -use monarch_messages::controller::ControllerMessage; -use monarch_messages::controller::ControllerMessageHandler; -use monarch_messages::controller::DeviceFailure; -use monarch_messages::controller::Ranks; -use monarch_messages::controller::Seq; -use monarch_messages::controller::WorkerError; -use monarch_messages::debugger::DebuggerAction; -use monarch_messages::worker::Ref; -use monarch_messages::worker::WorkerActor; -use monarch_messages::worker::WorkerMessage; -use ndslice::Selection; -use ndslice::Shape; -use ndslice::Slice; -use ndslice::reshape::Limit; -use ndslice::reshape::ReshapeShapeExt; -use ndslice::selection::dsl; -use ndslice::shape::Range; -use serde::Deserialize; -use serde::Serialize; -use tokio::sync::OnceCell; - -const CASTING_FANOUT_SIZE: usize = 8; - -/// A controller for the workers that will be leveraged by the client to do the actual -/// compute tasks. This acts a proxy managing comms with the workers and handling things like history, -/// data dependency, worker lifecycles etc for the client abstracting it away. -#[derive(Debug)] -#[hyperactor::export( - spawn = true, - handlers = [ - ControllerMessage, - ], -)] -pub(crate) struct ControllerActor { - client_actor_ref: OnceCell>, - comm_actor_ref: ActorRef, - worker_gang_ref: GangRef, - history: history::History, - supervision_query_interval: Duration, - system_supervision_actor_ref: ActorRef, - worker_progress_check_interval: Duration, - operation_timeout: Duration, - operations_per_worker_progress_request: u64, - // The Seq and time we last sent out a WorkerMessage::RequestStatus. - last_controller_request_status: Option<(Seq, tokio::time::Instant)>, - fail_on_worker_timeout: bool, - world_size: usize, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Named)] -pub(crate) struct ControllerParams { - /// The world size to track the size of all the workers. - pub(crate) world_size: usize, - - /// Reference to the comm actor. It must be configured to target - /// the worker gang. 
The controller takes "ownership" of this actor: - /// it is immediately configured to target the worker gang. - /// This is a temporary workaround until we are fully on meshes. - pub(crate) comm_actor_ref: ActorRef, - - /// Reference to the workers to send commands to. - pub(crate) worker_gang_ref: GangRef, - - // How often to query world supervision status against system actor. - pub(crate) supervision_query_interval: Duration, - - // How often to query for if workers are making progress. - pub(crate) worker_progress_check_interval: Duration, - - // How long to wait for an operation to complete before considering it timed out. - pub(crate) operation_timeout: Duration, - - // How many operations are enqueued before we request a progress update on workers. - pub(crate) operations_per_worker_progress_request: u64, - - // If a failure should be propagated back to the client if workers are detected to be stuck. - pub(crate) fail_on_worker_timeout: bool, -} - -#[async_trait] -impl Actor for ControllerActor { - type Params = ControllerParams; - - async fn new(params: ControllerParams) -> Result { - Ok(Self { - client_actor_ref: OnceCell::new(), - comm_actor_ref: params.comm_actor_ref, - worker_gang_ref: params.worker_gang_ref, - history: history::History::new(params.world_size), - supervision_query_interval: params.supervision_query_interval, - system_supervision_actor_ref: ActorRef::attest(SYSTEM_ACTOR_REF.actor_id().clone()), - worker_progress_check_interval: params.worker_progress_check_interval, - operation_timeout: params.operation_timeout, - operations_per_worker_progress_request: params.operations_per_worker_progress_request, - last_controller_request_status: None, - fail_on_worker_timeout: params.fail_on_worker_timeout, - world_size: params.world_size, - }) - } - - async fn init(&mut self, cx: &hyperactor::Instance) -> Result<(), anyhow::Error> { - self.comm_actor_ref.send( - cx, - CommActorMode::ImplicitWithWorldId(self.worker_gang_ref.gang_id().world_id().clone()), - )?; - Ok(()) - } -} - -impl ControllerActor { - /// Bootstrap the controller actor. This will create a new proc, join the system at `bootstrap_addr` - /// and spawn the controller actor into the proc. `labels` is an arbitrary set of name/value pairs - /// to be attached to the proc in system registry which can be used later to query and find the proc(s) - /// using system's snapshot api. 
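// For example (a sketch only; it reuses the private label constants from bootstrap.rs and
// assumes an empty `worlds` filter matches all worlds), a caller could later locate
// controller procs by filtering a system snapshot on the role label:
//
//     let snapshot = system_actor::SYSTEM_ACTOR_REF
//         .snapshot(cx, system_actor::SystemSnapshotFilter {
//             worlds: vec![],
//             world_labels: HashMap::new(),
//             proc_labels: HashMap::from([(
//                 LABEL_NAME_ROLE.to_string(),
//                 LABEL_VALUE_ROLE_CONTROLLER.to_string(),
//             )]),
//         })
//         .await?;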
- pub async fn bootstrap( - controller_id: ActorId, - listen_addr: ChannelAddr, - bootstrap_addr: ChannelAddr, - params: ControllerParams, - supervision_update_interval: Duration, - labels: HashMap, - ) -> Result<(ActorHandle, ActorRef), anyhow::Error> { - let bootstrap = ProcActor::bootstrap( - controller_id.proc_id().clone(), - controller_id - .proc_id() - .world_id() - .expect("multiprocess supports only ranked procs") - .clone(), // REFACTOR(marius): make world_id a parameter of ControllerActor::bootstrap - listen_addr, - bootstrap_addr.clone(), - supervision_update_interval, - labels, - ProcLifecycleMode::ManagedBySystem, - ) - .await?; - - let mut system = hyperactor_multiprocess::System::new(bootstrap_addr); - let client = system.attach().await?; - - let controller_actor_ref = spawn::( - &client, - &bootstrap.proc_actor.bind(), - controller_id.clone().name(), - &ControllerParams { - comm_actor_ref: bootstrap.comm_actor.bind(), - ..params - }, - ) - .await?; - - Ok((bootstrap.proc_actor, controller_actor_ref)) - } - - fn client(&self) -> Result, anyhow::Error> { - self.client_actor_ref - .get() - .ok_or_else(|| anyhow::anyhow!("client actor ref not set")) - .cloned() - } - - // Send a request_status for the seq we expect to complete by our next deadline if it is more than - // N ops ahead of our last request_status, or if M seconds passed where: - // - // N = self.operations_per_worker_progress_request - // M = self.worker_progress_check_interval - async fn request_status_if_needed( - &mut self, - cx: &Context<'_, Self>, - ) -> Result<(), anyhow::Error> { - if let Some((expected_seq, ..)) = self.history.deadline( - self.operations_per_worker_progress_request, - self.operation_timeout, - cx.clock(), - ) { - if self.last_controller_request_status.is_none_or( - |(last_requested_seq, last_requested_time)| { - (expected_seq - >= (u64::from(last_requested_seq) - + self.operations_per_worker_progress_request) - .into() - || last_requested_time.elapsed() > self.worker_progress_check_interval) - && last_requested_seq != expected_seq - }, - ) { - // Send to all workers. 
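The guard above is a simple throttle on progress queries. A stand-alone sketch of the same rule, using a hypothetical helper name and plain u64 sequence numbers rather than the Seq newtype (illustration only, not part of the controller):

// Sketch only: ask for worker progress again only if we are `ops_per_request`
// operations past the last request, or `check_interval` has elapsed, and never
// twice for the same expected seq.
fn should_request_status(
    expected_seq: u64,
    last: Option<(u64, std::time::Instant)>,
    ops_per_request: u64,
    check_interval: std::time::Duration,
) -> bool {
    match last {
        None => true,
        Some((last_seq, last_time)) => {
            last_seq != expected_seq
                && (expected_seq >= last_seq + ops_per_request
                    || last_time.elapsed() > check_interval)
        }
    }
}

fn main() {
    let now = std::time::Instant::now();
    // Nothing requested yet: ask.
    assert!(should_request_status(5, None, 100, std::time::Duration::from_secs(3)));
    // Same seq as the last request: don't ask again.
    assert!(!should_request_status(5, Some((5, now)), 100, std::time::Duration::from_secs(3)));
    // Far enough ahead of the last request: ask.
    assert!(should_request_status(205, Some((5, now)), 100, std::time::Duration::from_secs(3)));
}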
- self.send( - cx, - Ranks::Slice( - ndslice::Slice::new(0, vec![self.history.world_size()], vec![1]).unwrap(), - ), - Serialized::serialize(&WorkerMessage::RequestStatus { - seq: expected_seq.clone(), - controller: true, - }) - .unwrap(), - ) - .await?; - - self.last_controller_request_status = - Some((expected_seq.clone(), cx.clock().now())); - } - } - - Ok(()) - } -} - -#[derive(Debug)] -struct CheckWorkerProgress; - -#[async_trait] -impl Handler for ControllerActor { - async fn handle( - &mut self, - cx: &Context, - _check_worker_progress: CheckWorkerProgress, - ) -> Result<(), anyhow::Error> { - let client = self.client()?; - - if let Some((expected_seq, deadline, reported)) = self.history.deadline( - self.operations_per_worker_progress_request, - self.operation_timeout, - cx.clock(), - ) { - if !reported - && cx.clock().now() > deadline - && expected_seq >= self.history.min_incomplete_seq_reported() - { - let timed_out_ranks = self - .history - .first_incomplete_seqs_controller() - .iter() - .enumerate() - .filter(|(_, seq)| seq <= &&expected_seq) - .map(|(rank, _)| rank) - .collect::>(); - - let failed_rank = timed_out_ranks.first().unwrap().clone(); - - let timed_out_ranks_string = timed_out_ranks - .into_iter() - .map(|rank| rank.to_string()) - .collect::>() - .join(", "); - - let message = format!( - "ranks {} have operations that have not completed after {} seconds", - timed_out_ranks_string, - self.operation_timeout.as_secs() - ); - if client - .log(cx, LogLevel::Warn, message.clone()) - .await - .is_ok() - { - self.history.report_deadline_missed(); - } - - if self.fail_on_worker_timeout { - client - .result( - cx, - expected_seq, - Some(Err(Exception::Failure(DeviceFailure { - actor_id: self.worker_gang_ref.rank(failed_rank).actor_id().clone(), - address: "unknown".into(), - backtrace: message, - }))), - ) - .await?; - } - } - self.request_status_if_needed(cx).await?; - } - - cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?; - Ok(()) - } -} - -/// Hacky translation from a sub-`Slice` to a `Selection. -fn slice_to_selection(slice: Slice) -> Selection { - match (slice.sizes(), slice.strides()) { - // Special case exact rank `Selection`. - ([], []) => dsl::range(slice.offset()..=slice.offset(), dsl::true_()), - // Special case trivial range `Selection`. - ([size, rsizes @ ..], [stride, ..]) if rsizes.iter().all(|s| *s == 1) => dsl::range( - Range( - slice.offset(), - Some(slice.offset() + *size * *stride), - *stride, - ), - dsl::true_(), - ), - // Fallback to more heavy-weight translation for everything else. - _ => { - let mut selection = Selection::False; - let mut selected_ranks = HashSet::new(); - for rank in slice.iter() { - if !selected_ranks.insert(rank) { - continue; - } - selection = dsl::union(dsl::range(rank..=rank, dsl::true_()), selection); - } - selection - } - } -} - -#[async_trait] -#[hyperactor::forward(ControllerMessage)] -impl ControllerMessageHandler for ControllerActor { - async fn attach( - &mut self, - cx: &Context, - client_actor: ActorRef, - ) -> Result<(), anyhow::Error> { - tracing::debug!("attaching client actor {}", client_actor); - self.client_actor_ref - .set(client_actor) - .map_err(|actor_ref| anyhow::anyhow!("client actor {} already attached", actor_ref))?; - - // Trigger periodical checking of supervision status and worker progress. 
- cx.self_message_with_delay( - ControllerMessage::CheckSupervision {}, - self.supervision_query_interval, - )?; - cx.self_message_with_delay(CheckWorkerProgress, self.worker_progress_check_interval)?; - Ok(()) - } - - async fn node( - &mut self, - cx: &Context, - seq: Seq, - defs: Vec, - uses: Vec, - ) -> Result<(), anyhow::Error> { - let failures = self.history.add_invocation(seq, uses, defs); - let client = self.client()?; - for (seq, failure) in failures { - let _ = client.result(cx, seq, failure).await; - } - self.request_status_if_needed(cx).await?; - - Ok(()) - } - - async fn drop_refs( - &mut self, - _cx: &Context, - refs: Vec, - ) -> Result<(), anyhow::Error> { - self.history.delete_invocations_for_refs(refs); - Ok(()) - } - - async fn send( - &mut self, - cx: &Context, - ranks: Ranks, - message: Serialized, - ) -> Result<(), anyhow::Error> { - let selection = match ranks { - Ranks::Slice(slice) => { - if slice.len() == self.world_size { - // All ranks are selected. - Selection::True - } else { - slice_to_selection(slice) - } - } - Ranks::SliceList(slices) => slices.into_iter().fold(dsl::false_(), |sel, slice| { - dsl::union(sel, slice_to_selection(slice)) - }), - }; - - let slice = Slice::new(0usize, vec![self.world_size], vec![1])?; - // Use a made-up label to create a fake shape. This shape is used by - // comm actor to determine the cast rank. Cast rank is not used by - // DeviceMesh, but we still need a shape there to make the logic happy. - let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())? - .reshape(Limit::from(CASTING_FANOUT_SIZE)) - .shape; - - let message = CastMessageEnvelope::from_serialized( - ActorMeshId::V0( - ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()), - self.worker_gang_ref.gang_id().name().to_string(), - ), - cx.self_id().clone(), - DestinationPort::new::( - // This is awkward, but goes away entirely with meshes. 
- self.worker_gang_ref - .gang_id() - .actor_id(0) - .name() - .to_string(), - ), - made_up_shape, - message, - ); - - self.comm_actor_ref.send( - cx, - CastMessage { - dest: Uslice { - // TODO: pass both slice and selection from client side - slice, - selection, - }, - message, - }, - )?; - Ok(()) - } - - async fn remote_function_failed( - &mut self, - cx: &Context, - seq: Seq, - error: WorkerError, - ) -> Result<(), anyhow::Error> { - let rank = error.worker_actor_id.rank(); - self.history - .propagate_exception(seq, Exception::Error(seq, seq, error.clone())); - mark_worker_complete_and_propagate_exceptions(cx, self, rank, &seq).await?; - Ok(()) - } - - async fn status( - &mut self, - cx: &Context, - seq: Seq, - worker_actor_id: ActorId, - controller: bool, - ) -> Result<(), anyhow::Error> { - let rank = worker_actor_id.rank(); - - if controller { - self.history.update_deadline_tracking(rank, seq); - } else { - mark_worker_complete_and_propagate_exceptions(cx, self, rank, &seq).await?; - } - Ok(()) - } - - async fn fetch_result( - &mut self, - _cx: &Context, - seq: Seq, - result: Result, - ) -> Result<(), anyhow::Error> { - self.history.set_result(seq, result); - Ok(()) - } - - async fn check_supervision(&mut self, cx: &Context) -> Result<(), anyhow::Error> { - let gang_id: GangId = self.worker_gang_ref.clone().into(); - let world_state = self - .system_supervision_actor_ref - .state(cx, gang_id.world_id().clone()) - .await?; - - if let Some(world_state) = world_state { - if !world_state.procs.is_empty() { - tracing::error!( - "found procs with failures in world {}, state: {:?}", - gang_id.world_id(), - world_state - ); - - // Randomly pick a failed proc as the failed actor. - let (_, failed_state) = world_state.procs.iter().next().unwrap(); - let (failed_actor, failure_reason) = - failed_state.failed_actors.first().map_or_else( - || { - let proc_id = &failed_state.proc_id; - ( - ActorId(proc_id.clone(), "none".into(), 0), - format!( - "proc is dead due to heartbeat timeout; no backtrace is \ - available; check the log of host {} running proc {} to \ - figure out the root cause", - failed_state.proc_addr, proc_id - ), - ) - }, - |(actor, status)| { - ( - actor.clone(), - match status { - ActorStatus::Failed(msg) => msg.to_string(), - _ => format!("unexpected actor status {status}"), - }, - ) - }, - ); - - let exc = Exception::Failure(DeviceFailure { - actor_id: failed_actor, - address: failed_state.proc_addr.to_string(), - backtrace: failure_reason, - }); - tracing::error!("Sending failure to client: {exc:?}"); - // Seq does not matter as the client will raise device error immediately before setting the results. - self.client()? - .result(cx, Seq::default(), Some(Err(exc))) - .await?; - tracing::error!("Failure successfully sent to client"); - - // No need to set history failures as we are directly sending back failure results. - } - } - - // Schedule the next supervision check. - cx.self_message_with_delay( - ControllerMessage::CheckSupervision {}, - self.supervision_query_interval, - )?; - Ok(()) - } - - async fn debugger_message( - &mut self, - cx: &Context, - debugger_actor_id: ActorId, - action: DebuggerAction, - ) -> Result<(), anyhow::Error> { - self.client()? 
- .debugger_message(cx, debugger_actor_id, action) - .await - } - - #[cfg(test)] - async fn get_first_incomplete_seqs_unit_tests_only( - &mut self, - _cx: &Context, - ) -> Result, anyhow::Error> { - Ok(self.history.first_incomplete_seqs().to_vec()) - } - - #[cfg(not(test))] - async fn get_first_incomplete_seqs_unit_tests_only( - &mut self, - _cx: &Context, - ) -> Result, anyhow::Error> { - unimplemented!("get_first_incomplete_seqs_unit_tests_only is only for unit tests") - } -} - -async fn mark_worker_complete_and_propagate_exceptions( - cx: &impl context::Actor, - actor: &mut ControllerActor, - rank: usize, - seq: &Seq, -) -> Result<(), anyhow::Error> { - let results = actor.history.rank_completed(rank, seq.clone()); - let client = actor.client()?; - // Propagate the failures to the clients. - for (seq, result) in results.iter() { - let _ = client.result(cx, seq.clone(), result.clone()).await; - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use core::panic; - use std::assert_matches::assert_matches; - use std::collections::HashMap; - use std::collections::HashSet; - use std::time::Duration; - - use hyperactor::HandleClient; - use hyperactor::Handler; - use hyperactor::RefClient; - use hyperactor::channel; - use hyperactor::channel::ChannelTransport; - use hyperactor::clock::Clock; - use hyperactor::clock::RealClock; - use hyperactor::context::Mailbox as _; - use hyperactor::id; - use hyperactor::mailbox::BoxedMailboxSender; - use hyperactor::mailbox::DialMailboxRouter; - use hyperactor::mailbox::Mailbox; - use hyperactor::mailbox::MailboxClient; - use hyperactor::mailbox::MailboxServer; - use hyperactor::mailbox::PortHandle; - use hyperactor::mailbox::PortReceiver; - use hyperactor::message::IndexedErasedUnbound; - use hyperactor::proc::Proc; - use hyperactor::reference::GangId; - use hyperactor::reference::ProcId; - use hyperactor::reference::WorldId; - use hyperactor::simnet; - use hyperactor_mesh::comm::CommActorParams; - use hyperactor_multiprocess::System; - use hyperactor_multiprocess::proc_actor::ProcMessage; - use hyperactor_multiprocess::supervision::ProcSupervisionMessage; - use hyperactor_multiprocess::supervision::ProcSupervisor; - use hyperactor_multiprocess::system_actor::SystemMessage; - use monarch_messages::client::ClientMessage; - use monarch_messages::controller::ControllerMessageClient; - use monarch_messages::wire_value::WireValue; - use monarch_messages::worker::CallFunctionParams; - use monarch_messages::worker::WorkerMessage; - use monarch_types::PyTree; - use timed_test::async_timed_test; - use torch_sys::RValue; - - use super::*; - - #[tokio::test] - async fn basic_controller() { - // TODO: Add a proper multiworker test - let proc = Proc::local(); - let (client, client_ref, mut client_rx) = proc - .attach_actor::("client") - .unwrap(); - let (worker, worker_ref, mut worker_rx) = proc - .attach_actor::("worker") - .unwrap(); - - IndexedErasedUnbound::::bind_for_test_only( - worker_ref.clone(), - worker.clone_for_py(), - worker.mailbox().clone(), - ) - .unwrap(); - - let comm_handle = proc - .spawn::("comm", CommActorParams {}) - .await - .unwrap(); - - let controller_handle = proc - .spawn::( - "controller", - ControllerParams { - world_size: 1, - comm_actor_ref: comm_handle.bind(), - worker_gang_ref: GangRef::attest(GangId( - WorldId( - proc.proc_id() - .world_name() - .expect("only ranked actors are supported in the controller tests") - .to_string(), - ), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(1), - 
worker_progress_check_interval: Duration::from_secs(3), - operation_timeout: Duration::from_secs(30), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - ) - .await - .unwrap(); - - controller_handle.attach(&client, client_ref).await.unwrap(); - - controller_handle - .node(&client, 0.into(), vec![0.into()], vec![]) - .await - .unwrap(); - controller_handle - .node(&client, 1.into(), vec![1.into(), 2.into()], vec![0.into()]) - .await - .unwrap(); - controller_handle - .node(&client, 20.into(), vec![3.into(), 4.into()], vec![]) - .await - .unwrap(); - - ControllerMessageClient::send( - &controller_handle, - &worker, - Ranks::Slice(ndslice::Slice::new(0, vec![1], vec![1]).unwrap()), - Serialized::serialize(&WorkerMessage::CallFunction(CallFunctionParams { - seq: 1.into(), - results: vec![Some(1.into()), Some(2.into())], - mutates: vec![], - function: "os.path.split".into(), - args: vec![WireValue::String("/fbs/fbc/foo/bar".into())], - kwargs: HashMap::new(), - stream: 1.into(), - remote_process_groups: vec![], - })) - .unwrap(), - ) - .await - .unwrap(); - - ControllerMessageClient::status( - &controller_handle, - &worker, - 0.into(), - worker_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - let incomplete_seqs = controller_handle - .get_first_incomplete_seqs_unit_tests_only(&worker) - .await - .unwrap(); - assert_eq!(incomplete_seqs[0], 0.into()); - - controller_handle - .remote_function_failed( - &worker, - 1.into(), - WorkerError { - backtrace: "some failure happened!".to_string(), - worker_actor_id: worker_ref.actor_id().clone(), - }, - ) - .await - .unwrap(); - ControllerMessageClient::status( - &controller_handle, - &worker, - 2.into(), - worker_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - - let incomplete_seqs = controller_handle - .get_first_incomplete_seqs_unit_tests_only(&worker) - .await - .unwrap(); - assert_eq!(incomplete_seqs[0], 2.into()); - - controller_handle - .fetch_result( - &worker, - 20.into(), - Ok(Serialized::serialize(&PyTree::from(RValue::Int(42))).unwrap()), - ) - .await - .unwrap(); - - // Omly a status message can trigger a fetch result to the client. - ControllerMessageClient::status( - &controller_handle, - &worker, - 21.into(), - worker_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - - let incomplete_seqs = controller_handle - .get_first_incomplete_seqs_unit_tests_only(&worker) - .await - .unwrap(); - assert_eq!(incomplete_seqs[0], 21.into()); - - controller_handle.drain_and_stop().unwrap(); - controller_handle.await; - let worker_messages: Vec = worker_rx.drain(); - assert_eq!( - worker_messages - .iter() - .filter(|msg| !matches!(msg, WorkerMessage::RequestStatus { .. 
})) - .count(), - 1 - ); - let client_messages = client_rx.drain(); - assert_eq!(client_messages.len(), 3); - let client_message = client_messages[1].clone().into_result().unwrap(); - assert_eq!(client_message.0, 1.into()); - assert_eq!( - client_message.1, - Some(Err(Exception::Error( - 1.into(), - 1.into(), - WorkerError { - backtrace: "some failure happened!".to_string(), - worker_actor_id: worker_ref.actor_id().clone(), - } - ))) - ); - - let client_message = client_messages[2].clone().into_result().unwrap(); - assert_eq!(client_message.0, 20.into()); - assert_matches!( - client_message - .1 - .unwrap() - .unwrap() - .deserialized::>() - .unwrap() - .into_leaf() - .unwrap(), - RValue::Int(42), - ); - } - - #[tokio::test] - async fn worker_timeout() { - tokio::time::pause(); - let timeout_secs = 3; - let proc = Proc::local(); - - let (client, client_ref, mut client_rx) = proc - .attach_actor::("client") - .unwrap(); - let (worker, worker_ref, mut worker_rx) = proc - .attach_actor::("worker") - .unwrap(); - IndexedErasedUnbound::::bind_for_test_only( - worker_ref.clone(), - worker.clone_for_py(), - worker.mailbox().clone(), - ) - .unwrap(); - - let comm_handle = proc - .spawn::("comm", CommActorParams {}) - .await - .unwrap(); - - let controller_handle = proc - .spawn::( - "controller", - ControllerParams { - world_size: 1, - comm_actor_ref: comm_handle.bind(), - worker_gang_ref: GangRef::attest(GangId( - WorldId( - proc.proc_id() - .world_name() - .expect("only ranked actors are supported in the controller tests") - .to_string(), - ), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(100000), - worker_progress_check_interval: Duration::from_secs(1), - operation_timeout: Duration::from_secs(timeout_secs), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - ) - .await - .unwrap(); - - controller_handle.attach(&client, client_ref).await.unwrap(); - - controller_handle - .node(&client, 0.into(), vec![0.into()], vec![]) - .await - .unwrap(); - - // Expect that our handler for CheckWorkerProgress will issue RequestWorkerCompletedSeq - match worker_rx.recv().await.unwrap().into_request_status().ok() { - Some((seq, controller)) if seq == 0.into() && controller => { - // Simulate WorkerActor::RequestWorkerCompletedSeq if joining streams takes shorter - // than timeout - for _ in 0..timeout_secs { - tokio::time::advance(Duration::from_secs(1)).await; - } - - ControllerMessageClient::status( - &controller_handle, - &worker, - 1.into(), - worker_ref.actor_id().clone(), - true, - ) - .await - .unwrap(); - } - _ => panic!("Expected request status message for seq 0"), - } - - // Should have no warnings - let client_messages = client_rx.drain(); - assert_eq!(client_messages.len(), 0); - - controller_handle - .node(&client, 1.into(), vec![], vec![]) - .await - .unwrap(); - - // Expect that our handler for CheckWorkerProgress will issue RequestWorkerCompletedSeq - match worker_rx.recv().await.unwrap().into_request_status().ok() { - Some((seq, controller)) if seq == 1.into() && controller => { - // Simulate WorkerActor::RequestWorkerCompletedSeq if joining streams takes longer - // than timeout - for _ in 0..timeout_secs * 2 { - tokio::time::advance(Duration::from_secs(1)).await; - } - - ControllerMessageClient::status( - &controller_handle, - &worker, - 2.into(), - worker_ref.actor_id().clone(), - true, - ) - .await - .unwrap(); - } - _ => panic!("Expected request status message for seq 1"), - } - - let client_messages = client_rx.drain(); 
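These timeout tests lean on Tokio's paused clock to exercise multi-second deadlines instantly. A minimal, self-contained illustration of that pattern (the test name is made up; it is not one of the tests in this file):

#[tokio::test]
async fn paused_clock_sketch() {
    // Illustration only: pause() freezes the runtime clock and advance() moves it
    // forward deterministically, so timers fire without real waiting.
    tokio::time::pause();
    let start = tokio::time::Instant::now();
    tokio::time::advance(std::time::Duration::from_secs(3)).await;
    assert!(start.elapsed() >= std::time::Duration::from_secs(3));
}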
- assert_eq!(client_messages.len(), 1); - - let (level, message) = client_messages[0].clone().into_log().unwrap(); - assert_matches!(level, LogLevel::Warn); - assert_eq!( - message, - "ranks 0 have operations that have not completed after 3 seconds" - ); - } - - #[tokio::test] - async fn test_failure_on_worker_timeout() { - tokio::time::pause(); - let timeout_secs = 3; - let proc = Proc::local(); - - let (client, client_ref, mut client_rx) = proc - .attach_actor::("client") - .unwrap(); - - let (worker, worker_ref, mut worker_rx) = proc - .attach_actor::("worker") - .unwrap(); - IndexedErasedUnbound::::bind_for_test_only( - worker_ref.clone(), - worker.clone_for_py(), - worker.mailbox().clone(), - ) - .unwrap(); - - let comm_handle = proc - .spawn::("comm", CommActorParams {}) - .await - .unwrap(); - - let world_id = WorldId( - proc.proc_id() - .world_name() - .expect("only ranked actors are supported in the controller tests") - .to_string(), - ); - let controller_handle = proc - .spawn::( - "controller", - ControllerParams { - world_size: 1, - comm_actor_ref: comm_handle.bind(), - worker_gang_ref: GangRef::attest(GangId(world_id, "worker".to_string())), - supervision_query_interval: Duration::from_secs(100000), - worker_progress_check_interval: Duration::from_secs(1), - operation_timeout: Duration::from_secs(timeout_secs), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: true, - }, - ) - .await - .unwrap(); - - controller_handle.attach(&client, client_ref).await.unwrap(); - - controller_handle - .node(&client, 0.into(), vec![0.into()], vec![]) - .await - .unwrap(); - - // Expect that our handler for CheckWorkerProgress will issue RequestWorkerCompletedSeq - match worker_rx.recv().await.unwrap().into_request_status().ok() { - Some((seq, controller)) if seq == 0.into() && controller => { - // Simulate WorkerActor::RequestWorkerCompletedSeq if joining streams takes shorter - // than timeout - for _ in 0..timeout_secs { - tokio::time::advance(Duration::from_secs(1)).await; - } - - ControllerMessageClient::status( - &controller_handle, - &worker, - 1.into(), - worker_ref.actor_id().clone(), - true, - ) - .await - .unwrap(); - } - _ => panic!("Expected request status message for seq 0"), - } - - // Should have no warnings - let client_messages = client_rx.drain(); - assert_eq!(client_messages.len(), 0); - - controller_handle - .node(&client, 1.into(), vec![], vec![]) - .await - .unwrap(); - - // Expect that our handler for CheckWorkerProgress will issue RequestWorkerCompletedSeq - match worker_rx.recv().await.unwrap().into_request_status().ok() { - Some((seq, controller)) if seq == 1.into() && controller => { - // Simulate WorkerActor::RequestWorkerCompletedSeq if joining streams takes longer - // than timeout - for _ in 0..timeout_secs * 2 { - tokio::time::advance(Duration::from_secs(1)).await; - } - - ControllerMessageClient::status( - &controller_handle, - &worker, - 2.into(), - worker_ref.actor_id().clone(), - true, - ) - .await - .unwrap(); - } - _ => panic!("Expected request status message for seq 1"), - } - - let client_messages = client_rx.drain(); - assert_eq!(client_messages.len(), 2); - - let (level, message) = client_messages[0].clone().into_log().unwrap(); - assert_matches!(level, LogLevel::Warn); - assert_eq!( - message, - "ranks 0 have operations that have not completed after 3 seconds" - ); - - let (seq, failure) = client_messages[1].clone().into_result().unwrap(); - assert_eq!(seq, 1.into()); - let DeviceFailure { - backtrace, - actor_id, - .. 
- } = failure - .unwrap() - .err() - .unwrap() - .as_failure() - .unwrap() - .clone(); - assert_eq!(actor_id, proc.proc_id().actor_id("worker", 0)); - assert!( - backtrace.contains("ranks 0 have operations that have not completed after 3 seconds") - ); - } - - #[tokio::test] - async fn failure_propagation() { - // Serve a system. - let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Local), - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - .unwrap(); - let mut system = System::new(server_handle.local_addr().clone()); - - // Build a supervisor. - let sup_mail = system.attach().await.unwrap(); - let (_sup_tx, _sup_rx) = sup_mail.bind_actor_port::(); - let sup_ref = ActorRef::::attest(sup_mail.self_id().clone()); - - // Construct a system sender. - let system_sender = BoxedMailboxSender::new(MailboxClient::new( - channel::dial(server_handle.local_addr().clone()).unwrap(), - )); - - // Construct a proc forwarder in terms of the system sender. - let listen_addr = ChannelAddr::any(ChannelTransport::Local); - let proc_forwarder = - BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender)); - - // Bootstrap proc 'local[0]', join the system. - let world_id = id!(local); - let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone()); - let proc_actor_0 = ProcActor::bootstrap_for_proc( - proc.clone(), - world_id.clone(), - listen_addr, - server_handle.local_addr().clone(), - sup_ref.clone(), - Duration::from_secs(2), - HashMap::new(), - ProcLifecycleMode::ManagedBySystem, - ) - .await - .unwrap(); - - // Bootstrap proc 'local[1]', join the system. - let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone()); - let _proc_actor_1 = ProcActor::bootstrap_for_proc( - proc2.clone(), - world_id.clone(), - ChannelAddr::any(ChannelTransport::Local), - server_handle.local_addr().clone(), - sup_ref.clone(), - Duration::from_secs(2), - HashMap::new(), - ProcLifecycleMode::ManagedBySystem, - ) - .await - .unwrap(); - - // Test - let (client, client_ref, mut client_rx) = proc - .attach_actor::("client") - .unwrap(); - let (worker1, worker1_ref, _) = proc - .attach_actor::("worker") - .unwrap(); - IndexedErasedUnbound::::bind_for_test_only( - worker1_ref.clone(), - worker1.clone_for_py(), - worker1.mailbox().clone(), - ) - .unwrap(); - let (worker2, worker2_ref, _) = proc2 - .attach_actor::("worker") - .unwrap(); - IndexedErasedUnbound::::bind_for_test_only( - worker2_ref.clone(), - worker2.clone_for_py(), - worker2.mailbox().clone(), - ) - .unwrap(); - - let controller_handle = proc - .spawn::( - "controller", - ControllerParams { - world_size: 2, - comm_actor_ref: proc_actor_0.comm_actor.bind(), - worker_gang_ref: GangRef::attest(GangId( - WorldId(world_id.name().to_string()), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(1), - worker_progress_check_interval: Duration::from_secs(3), - operation_timeout: Duration::from_secs(30), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - ) - .await - .unwrap(); - - controller_handle.attach(&client, client_ref).await.unwrap(); - - controller_handle - .node(&client, 0.into(), vec![1.into(), 2.into()], vec![]) - .await - .unwrap(); - controller_handle - .node(&client, 1.into(), vec![3.into()], vec![1.into()]) - .await - .unwrap(); - controller_handle - .node(&client, 2.into(), vec![4.into()], vec![3.into()]) - .await - .unwrap(); - controller_handle - .node(&client, 3.into(), vec![5.into()], vec![3.into()]) - .await - .unwrap(); - 
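The node() calls above build a small def/use graph so that a single worker failure can fan out to every dependent invocation. A stand-alone sketch of that propagation idea, with hypothetical types and plain u64 refs (not the History implementation):

use std::collections::{HashMap, HashSet};

/// Sketch only: given per-seq defs/uses, return every seq transitively poisoned
/// by `failed_seq`.
fn propagate_failure(
    defs: &HashMap<u64, Vec<u64>>,
    uses: &HashMap<u64, Vec<u64>>,
    failed_seq: u64,
) -> HashSet<u64> {
    let mut poisoned: HashSet<u64> =
        defs.get(&failed_seq).into_iter().flatten().copied().collect();
    let mut failed = HashSet::from([failed_seq]);
    let mut seqs: Vec<u64> = uses.keys().copied().collect();
    seqs.sort(); // walk invocations in issue order
    for seq in seqs {
        if uses[&seq].iter().any(|r| poisoned.contains(r)) {
            failed.insert(seq);
            poisoned.extend(defs.get(&seq).into_iter().flatten().copied());
        }
    }
    failed
}

fn main() {
    // Mirrors the shape of the test above: seq 2 defines ref 4, and seqs 5 and 6 use ref 4.
    let defs = HashMap::from([(2, vec![4]), (5, vec![7]), (6, vec![8])]);
    let uses = HashMap::from([(2, vec![3]), (5, vec![4]), (6, vec![4])]);
    assert_eq!(propagate_failure(&defs, &uses, 2), HashSet::from([2, 5, 6]));
}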
controller_handle - .node(&client, 4.into(), vec![6.into()], vec![3.into()]) - .await - .unwrap(); - controller_handle - .node(&client, 5.into(), vec![7.into()], vec![4.into()]) - .await - .unwrap(); - controller_handle - .node(&client, 6.into(), vec![8.into()], vec![4.into()]) - .await - .unwrap(); - - ControllerMessageClient::status( - &controller_handle, - &worker1, - 1.into(), - worker1_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - ControllerMessageClient::status( - &controller_handle, - &worker2, - 1.into(), - worker2_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - controller_handle - .remote_function_failed( - &worker1, - 2.into(), - WorkerError { - backtrace: "some failure happened!".to_string(), - worker_actor_id: worker1_ref.actor_id().clone(), - }, - ) - .await - .unwrap(); - controller_handle - .remote_function_failed( - &worker2, - 2.into(), - WorkerError { - backtrace: "some failure happened!".to_string(), - worker_actor_id: worker2_ref.actor_id().clone(), - }, - ) - .await - .unwrap(); - for s in 3..=7 { - ControllerMessageClient::status( - &controller_handle, - &worker1, - s.into(), - worker1_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - ControllerMessageClient::status( - &controller_handle, - &worker2, - s.into(), - worker2_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - } - - controller_handle.drain_and_stop().unwrap(); - controller_handle.await; - let mut client_messages = client_rx.drain(); - client_messages.sort_by_key(|msg| msg.clone().into_result().unwrap().0); - assert_eq!(client_messages.len(), 7); - let client_message = client_messages[2].clone().into_result().unwrap(); - assert_eq!(client_message.0, 2.into()); - assert_eq!( - client_message.1, - Some(Err(Exception::Error( - 2.into(), - 2.into(), - WorkerError { - backtrace: "some failure happened!".to_string(), - worker_actor_id: worker1_ref.actor_id().clone(), - } - ))) - ); - - assert_eq!( - client_messages - .into_iter() - .map(|msg| msg.into_result().unwrap().0) - .collect::>(), - HashSet::from([ - 0.into(), - 3.into(), - 1.into(), - 4.into(), - 2.into(), - 5.into(), - 6.into() - ]) - ) - } - - #[tokio::test] - async fn test_eager_failure_reporting() { - // Serve a system. - let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Local), - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - .unwrap(); - let mut system = System::new(server_handle.local_addr().clone()); - - // Build a supervisor. - let sup_mail = system.attach().await.unwrap(); - let (_sup_tx, _sup_rx) = sup_mail.bind_actor_port::(); - let sup_ref = ActorRef::::attest(sup_mail.self_id().clone()); - - // Construct a system sender. - let system_sender = BoxedMailboxSender::new(MailboxClient::new( - channel::dial(server_handle.local_addr().clone()).unwrap(), - )); - - // Construct a proc forwarder in terms of the system sender. - let listen_addr = ChannelAddr::any(ChannelTransport::Local); - let proc_forwarder = - BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender)); - - // Bootstrap proc 'local[0]', join the system. - let world_id = id!(local); - let proc = Proc::new(world_id.proc_id(0), proc_forwarder.clone()); - let proc_actor_0 = ProcActor::bootstrap_for_proc( - proc.clone(), - world_id.clone(), - listen_addr, - server_handle.local_addr().clone(), - sup_ref.clone(), - Duration::from_secs(2), - HashMap::new(), - ProcLifecycleMode::ManagedBySystem, - ) - .await - .unwrap(); - - // Bootstrap proc 'local[1]', join the system. 
- let proc2 = Proc::new(world_id.proc_id(1), proc_forwarder.clone()); - let _proc_actor_1 = ProcActor::bootstrap_for_proc( - proc2.clone(), - world_id.clone(), - ChannelAddr::any(ChannelTransport::Local), - server_handle.local_addr().clone(), - sup_ref.clone(), - Duration::from_secs(2), - HashMap::new(), - ProcLifecycleMode::ManagedBySystem, - ) - .await - .unwrap(); - - // Test - let (client, client_ref, mut client_rx) = proc - .attach_actor::("client") - .unwrap(); - let (worker1, worker1_ref, _) = proc - .attach_actor::("worker") - .unwrap(); - - let controller_handle = proc - .spawn::( - "controller", - ControllerParams { - world_size: 1, - comm_actor_ref: proc_actor_0.comm_actor.bind(), - worker_gang_ref: GangRef::attest(GangId( - WorldId(world_id.name().to_string()), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(1), - worker_progress_check_interval: Duration::from_secs(3), - operation_timeout: Duration::from_secs(30), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - ) - .await - .unwrap(); - - controller_handle.attach(&client, client_ref).await.unwrap(); - - controller_handle - .node(&client, 0.into(), vec![1.into()], vec![]) - .await - .unwrap(); - - controller_handle - .node(&client, 1.into(), vec![2.into()], vec![1.into()]) - .await - .unwrap(); - - controller_handle - .node(&client, 2.into(), vec![3.into()], vec![2.into()]) - .await - .unwrap(); - - controller_handle - .node(&client, 3.into(), vec![], vec![3.into()]) - .await - .unwrap(); - - controller_handle - .node(&client, 4.into(), vec![], vec![]) - .await - .unwrap(); - - controller_handle - .remote_function_failed( - &worker1, - 0.into(), - WorkerError { - backtrace: "some failure happened!".to_string(), - worker_actor_id: worker1_ref.actor_id().clone(), - }, - ) - .await - .unwrap(); - - controller_handle - .remote_function_failed( - &worker1, - 3.into(), - WorkerError { - backtrace: "some failure happened!".to_string(), - worker_actor_id: worker1_ref.actor_id().clone(), - }, - ) - .await - .unwrap(); - - ControllerMessageClient::status( - &controller_handle, - &worker1, - 5.into(), - worker1_ref.actor_id().clone(), - false, - ) - .await - .unwrap(); - - controller_handle.drain_and_stop().unwrap(); - controller_handle.await; - - let client_messages = client_rx.drain(); - // no double reported messages - assert_eq!(client_messages.len(), 5); - - let (errors, successes) = - client_messages - .into_iter() - .fold((0, 0), |(errors, successes), client_message| { - let (_, result) = client_message.clone().into_result().unwrap(); - match result { - Some(Err(Exception::Error(_, _, _))) => (errors + 1, successes), - None => (errors, successes + 1), - _ => { - panic!("should only be exceptions or no result"); - } - } - }); - - // Assert that we have 4 error messages and 1 non-error message - assert_eq!(errors, 4); - assert_eq!(successes, 1); - } - - #[tokio::test] - async fn test_bootstrap() { - let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Local), - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - .unwrap(); - - let controller_id = id!(controller[0].root); - let proc_id = id!(world[0]); - let (proc_handle, actor_ref) = ControllerActor::bootstrap( - controller_id.clone(), - ChannelAddr::any(ChannelTransport::Local), - server_handle.local_addr().clone(), - ControllerParams { - world_size: 1, - comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), - worker_gang_ref: GangRef::attest(GangId( - 
WorldId( - proc_id - .world_name() - .expect("only ranked actors are supported in the controller tests") - .to_string(), - ), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(1), - worker_progress_check_interval: Duration::from_secs(3), - operation_timeout: Duration::from_secs(30), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - Duration::from_secs(1), - HashMap::new(), - ) - .await - .unwrap(); - assert_eq!(*actor_ref.actor_id(), controller_id); - - proc_handle.drain_and_stop().unwrap(); - } - - async fn mock_proc_actor( - idx: usize, - rank: usize, - ) -> ( - WorldId, - ProcId, - ChannelAddr, - Mailbox, - PortHandle, - PortReceiver, - ) { - let world_id = id!(world); - // Set up a local actor. - let local_proc_id = world_id.proc_id(rank); - let (local_proc_addr, local_proc_rx) = - channel::serve(ChannelAddr::any(ChannelTransport::Local)).unwrap(); - let local_proc_mbox = Mailbox::new_detached( - local_proc_id.actor_id(format!("test_dummy_proc{}", idx).to_string(), 0), - ); - let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port(); - local_proc_message_port.bind(); - - let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx); - ( - world_id, - local_proc_id, - local_proc_addr, - local_proc_mbox, - local_proc_message_port, - local_proc_message_receiver, - ) - } - - #[tokio::test] - async fn test_sim_supervision_failure() { - // Start system actor. - simnet::start(); - simnet::simnet_handle() - .unwrap() - .set_training_script_state(simnet::TrainingScriptState::Waiting); - - let system_sim_addr = - ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); - // Set very long supervision_update_timeout - let server_handle = System::serve( - system_sim_addr.clone(), - Duration::from_secs(1000), - Duration::from_secs(1000), - ) - .await - .unwrap(); - - let mut system = System::new(server_handle.local_addr().clone()); - let client_mailbox = system.attach().await.unwrap(); - - // Bootstrap the controller - let controller_id = id!(controller[0].root); - let proc_id = id!(world[0]); - let controller_proc_listen_addr = - ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); - - let (_, actor_ref) = ControllerActor::bootstrap( - controller_id.clone(), - controller_proc_listen_addr, - system_sim_addr, - ControllerParams { - world_size: 1, - comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), - worker_gang_ref: GangRef::attest(GangId( - WorldId( - proc_id - .world_name() - .expect("only ranked actors are supported in the controller tests") - .to_string(), - ), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(100), - worker_progress_check_interval: Duration::from_secs(100), - operation_timeout: Duration::from_secs(1000), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - Duration::from_secs(100), - HashMap::new(), - ) - .await - .unwrap(); - assert_eq!(*actor_ref.actor_id(), controller_id); - - actor_ref - .attach( - &client_mailbox, - ActorRef::attest(client_mailbox.self_id().clone()), - ) - .await - .unwrap(); - - let (_client_supervision_tx, mut client_supervision_rx) = - client_mailbox.bind_actor_port::(); - - // mock a proc actor that doesn't update supervision state - let ( - world_id, - local_proc_id, - local_proc_addr, - _, - local_proc_message_port, - mut local_proc_message_receiver, - ) = mock_proc_actor(0, 1).await; - - // Join the world. 
- server_handle - .system_actor_handle() - .send(SystemMessage::Join { - proc_id: local_proc_id.clone(), - world_id, - proc_message_port: local_proc_message_port.bind(), - proc_addr: local_proc_addr, - labels: HashMap::new(), - lifecycle_mode: ProcLifecycleMode::ManagedBySystem, - }) - .unwrap(); - - assert_matches!( - local_proc_message_receiver.recv().await.unwrap(), - ProcMessage::Joined() - ); - - // expect that supervision timeout which takes 1000 real seconds is hit super quickly - // due to simulated time - let result = client_supervision_rx - .recv() - .await - .unwrap() - .into_result() - .unwrap(); - assert_eq!(result.0, Seq::default()); - assert!(result.1.expect("result").is_err()); - - let records = simnet::simnet_handle().unwrap().close().await.unwrap(); - eprintln!("{}", serde_json::to_string_pretty(&records).unwrap()); - } - #[tokio::test] - async fn test_supervision_failure() { - // Start system actor. - let timeout: Duration = Duration::from_secs(6); - let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Local), - timeout.clone(), - timeout.clone(), - ) - .await - .unwrap(); - - // Client actor. - let mut system = System::new(server_handle.local_addr().clone()); - let client_mailbox = system.attach().await.unwrap(); - let (_client_supervision_tx, mut client_supervision_rx) = - client_mailbox.bind_actor_port::(); - - // Bootstrap the controller - let controller_id = id!(controller[0].root); - let proc_id = id!(world[0]); - let (_, actor_ref) = ControllerActor::bootstrap( - controller_id.clone(), - ChannelAddr::any(ChannelTransport::Local), - server_handle.local_addr().clone(), - ControllerParams { - world_size: 1, - comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), - worker_gang_ref: GangRef::attest(GangId( - WorldId( - proc_id - .world_name() - .expect("only ranked actors are supported in the controller tests") - .to_string(), - ), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(1), - worker_progress_check_interval: Duration::from_secs(3), - operation_timeout: Duration::from_secs(30), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - Duration::from_secs(1), - HashMap::new(), - ) - .await - .unwrap(); - assert_eq!(*actor_ref.actor_id(), controller_id); - - actor_ref - .attach( - &client_mailbox, - ActorRef::attest(client_mailbox.self_id().clone()), - ) - .await - .unwrap(); - - // mock a proc actor that doesn't update supervision state - let ( - world_id, - local_proc_id, - local_proc_addr, - _, - local_proc_message_port, - mut local_proc_message_receiver, - ) = mock_proc_actor(0, 1).await; - - // Join the world. - server_handle - .system_actor_handle() - .send(SystemMessage::Join { - proc_id: local_proc_id.clone(), - world_id, - proc_message_port: local_proc_message_port.bind(), - proc_addr: local_proc_addr, - labels: HashMap::new(), - lifecycle_mode: ProcLifecycleMode::ManagedBySystem, - }) - .unwrap(); - - assert_matches!( - local_proc_message_receiver.recv().await.unwrap(), - ProcMessage::Joined() - ); - - // Wait a bit; supervision update should time out. 
- RealClock.sleep(2 * timeout.clone()).await; - - // Should've gotten the supervision message indicating supervision failure - let result = client_supervision_rx - .recv() - .await - .unwrap() - .into_result() - .unwrap(); - assert_eq!(result.0, Seq::default()); - assert!(result.1.expect("result").is_err()); - } - - #[derive( - Handler, - HandleClient, - RefClient, - Named, - Debug, - Clone, - Serialize, - Deserialize, - PartialEq - )] - enum PanickingMessage { - Panic(String), - } - - #[derive(Debug, Default, Actor)] - #[hyperactor::export( - handlers = [ - PanickingMessage, - ], - )] - struct PanickingActor; - - #[async_trait] - #[hyperactor::forward(PanickingMessage)] - impl PanickingMessageHandler for PanickingActor { - async fn panic( - &mut self, - _cx: &Context, - err_msg: String, - ) -> Result<(), anyhow::Error> { - panic!("{}", err_msg); - } - } - - hyperactor::remote!(PanickingActor); - - #[async_timed_test(timeout_secs = 30)] - // times out (both internal and external). - #[cfg_attr(not(fbcode_build), ignore)] - async fn test_supervision_fault() { - // Start system actor. - let timeout: Duration = Duration::from_secs(6); - let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Local), - timeout.clone(), - timeout.clone(), - ) - .await - .unwrap(); - - // Client actor. - let mut system = System::new(server_handle.local_addr().clone()); - let client_mailbox = system.attach().await.unwrap(); - let (_client_supervision_tx, mut client_supervision_rx) = - client_mailbox.bind_actor_port::(); - - // Bootstrap the controller - let controller_id = id!(controller[0].root); - let proc_id = id!(world[0]); - let (_, actor_ref) = ControllerActor::bootstrap( - controller_id.clone(), - ChannelAddr::any(ChannelTransport::Local), - server_handle.local_addr().clone(), - ControllerParams { - world_size: 1, - comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), - worker_gang_ref: GangRef::attest(GangId( - WorldId( - proc_id - .world_name() - .expect("only ranked actors are supported in the controller tests") - .to_string(), - ), - "worker".to_string(), - )), - supervision_query_interval: Duration::from_secs(1), - worker_progress_check_interval: Duration::from_secs(3), - operation_timeout: Duration::from_secs(30), - operations_per_worker_progress_request: 100, - fail_on_worker_timeout: false, - }, - Duration::from_secs(1), - HashMap::new(), - ) - .await - .unwrap(); - assert_eq!(*actor_ref.actor_id(), controller_id); - - actor_ref - .attach( - &client_mailbox, - ActorRef::attest(client_mailbox.self_id().clone()), - ) - .await - .unwrap(); - - // bootstreap an actor that panics - let world_id = id!(world); - let panic_proc_id = world_id.proc_id(1); - let bootstrap = ProcActor::bootstrap( - panic_proc_id, - world_id, - ChannelAddr::any(ChannelTransport::Local), - server_handle.local_addr().clone(), - Duration::from_secs(3), - HashMap::new(), - ProcLifecycleMode::ManagedBySystem, - ) - .await - .unwrap(); - let actor_handle = spawn::( - &client_mailbox, - &bootstrap.proc_actor.bind(), - "panicker", - &(), - ) - .await - .unwrap(); - - actor_handle - .panic(&client_mailbox, "some random failure".to_string()) - .await - .unwrap(); - - // Get the supervision message with the panic - let result = client_supervision_rx - .recv() - .await - .unwrap() - .into_result() - .unwrap(); - assert_eq!(result.0, Seq::default()); - assert!(result.1.is_some() && result.1.as_ref().unwrap().is_err()); - let Exception::Failure(err) = result.1.unwrap().unwrap_err() else { - 
panic!("Expected Failure exception"); - }; - assert!(err.backtrace.contains("some random failure")); - } -} diff --git a/controller/src/main.rs b/controller/src/main.rs deleted file mode 100644 index 818c7d945..000000000 --- a/controller/src/main.rs +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! A binary to launch the system, host, or controller actors. Controller actors should be -//! launched through a separate binary that is defined in crate [`monarch_tensor_worker`] -//! due to Python dependency for workers. - -use std::os::fd::FromRawFd; -use std::os::fd::RawFd; - -use anyhow::Result; -use clap::Parser; -use controller::bootstrap::ControllerServerRequest; -use controller::bootstrap::ControllerServerResponse; -use controller::bootstrap::RunCommand; -use tokio::io::AsyncBufRead; -use tokio::io::AsyncBufReadExt; -use tokio::io::AsyncWrite; -use tokio::io::AsyncWriteExt; -use tokio::io::BufReader; - -/// Bootstrap commands and arguments used for system, proc, and controller actors. -// TODO: The logic to spawn the hyperactor part of this can probably live in hyperactor -// itself, and we can just call that from here. -#[derive(Parser)] -enum BootstrapCommand { - #[command(flatten)] - Run(RunCommand), - Serve { - read: RawFd, - write: RawFd, - }, -} - -async fn serve(inp: impl AsyncBufRead + Unpin, mut outp: impl AsyncWrite + Unpin) -> Result<()> { - tracing::info!("running controller server on {}", std::process::id()); - - let mut lines = inp.lines(); - while let Some(line) = lines.next_line().await? { - let request: ControllerServerRequest = serde_json::from_str(&line)?; - tracing::info!("got controller request: {:?}", request); - let response = match serde_json::from_str(&line)? { - ControllerServerRequest::Run(cmd) => { - let res = controller::bootstrap::run(cmd)?.await?; - ControllerServerResponse::Finished { - error: match res { - Err(err) => Some(format!("{}", err)), - Ok(()) => None, - }, - } - } - ControllerServerRequest::Exit() => break, - }; - tracing::info!("sending controller response: {:?}", response); - outp.write_all(format!("{}\n", serde_json::to_string(&response)?).as_bytes()) - .await?; - outp.flush().await?; - } - - tracing::info!("finished running controller server"); - - Ok(()) -} - -#[tokio::main] -async fn main() -> Result<()> { - hyperactor::initialize_with_current_runtime(); - - match BootstrapCommand::try_parse()? { - BootstrapCommand::Run(cmd) => controller::bootstrap::run(cmd)?.await??, - BootstrapCommand::Serve { read, write } => { - serve( - // SAFETY: Raw FD passed in from parent. - BufReader::new(unsafe { tokio::fs::File::from_raw_fd(read) }), - // SAFETY: Raw FD passed in from parent. - unsafe { tokio::fs::File::from_raw_fd(write) }, - ) - .await? 
- } - } - - Ok(()) -} diff --git a/monarch_extension/Cargo.toml b/monarch_extension/Cargo.toml index 7f7740d93..c57f6aecf 100644 --- a/monarch_extension/Cargo.toml +++ b/monarch_extension/Cargo.toml @@ -18,7 +18,6 @@ anyhow = "1.0.98" async-trait = "0.1.86" bincode = "1.3.3" clap = { version = "4.5.42", features = ["derive", "env", "string", "unicode", "wrap_help"] } -controller = { version = "0.0.0", path = "../controller", optional = true } futures = { version = "0.3.31", features = ["async-await", "compat"] } hyperactor = { version = "0.0.0", path = "../hyperactor" } hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" } @@ -29,7 +28,6 @@ libc = "0.2.139" monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" } monarch_messages = { version = "0.0.0", path = "../monarch_messages", optional = true } monarch_rdma_extension = { version = "0.0.0", path = "../monarch_rdma/extension", optional = true } -monarch_simulator_lib = { version = "0.0.0", path = "../monarch_simulator", optional = true } monarch_tensor_worker = { version = "0.0.0", path = "../monarch_tensor_worker", optional = true } monarch_types = { version = "0.0.0", path = "../monarch_types" } nccl-sys = { path = "../nccl-sys", optional = true } @@ -44,4 +42,4 @@ tracing = { version = "0.1.41", features = ["attributes", "valuable"] } [features] default = ["tensor_engine"] -tensor_engine = ["dep:controller", "dep:monarch_messages", "dep:monarch_rdma_extension", "dep:monarch_simulator_lib", "dep:monarch_tensor_worker", "dep:nccl-sys", "dep:rdmaxcel-sys", "dep:torch-sys", "dep:torch-sys-cuda"] +tensor_engine = ["dep:monarch_messages", "dep:monarch_rdma_extension", "dep:monarch_tensor_worker", "dep:nccl-sys", "dep:rdmaxcel-sys", "dep:torch-sys", "dep:torch-sys-cuda"] diff --git a/monarch_extension/src/lib.rs b/monarch_extension/src/lib.rs index 2faccdb2b..f149a893e 100644 --- a/monarch_extension/src/lib.rs +++ b/monarch_extension/src/lib.rs @@ -21,7 +21,6 @@ mod logging; #[cfg(feature = "tensor_engine")] mod mesh_controller; mod simulation_tools; -mod simulator_client; #[cfg(feature = "tensor_engine")] mod tensor_worker; @@ -121,14 +120,6 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> { module, "monarch_messages.debugger", )?)?; - simulator_client::register_python_bindings(&get_or_add_new_module( - module, - "monarch_extension.simulator_client", - )?)?; - ::controller::bootstrap::register_python_bindings(&get_or_add_new_module( - module, - "controller.bootstrap", - )?)?; ::monarch_tensor_worker::bootstrap::register_python_bindings(&get_or_add_new_module( module, "monarch_tensor_worker.bootstrap", diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs deleted file mode 100644 index 64834e060..000000000 --- a/monarch_extension/src/simulator_client.rs +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#![cfg(feature = "tensor_engine")] - -use std::sync::Arc; - -use anyhow::anyhow; -use hyperactor::WorldId; -use hyperactor::channel::ChannelAddr; -use hyperactor::simnet; -use hyperactor::simnet::TrainingScriptState; -use hyperactor::simnet::simnet_handle; -use monarch_hyperactor::runtime::signal_safe_block_on; -use monarch_simulator_lib::simulator::TensorEngineSimulator; -use pyo3::exceptions::PyRuntimeError; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; -use tokio::sync::Mutex; - -/// A wrapper around [ndslice::Slice] to expose it to python. -/// It is a compact representation of indices into the flat -/// representation of an n-dimensional array. Given an offset, sizes of -/// each dimension, and strides for each dimension, Slice can compute -/// indices into the flat array. -#[pyclass( - name = "SimulatorClient", - frozen, - module = "monarch._rust_bindings.monarch_extension.simulator_client" -)] -#[derive(Clone)] -pub(crate) struct SimulatorClient { - inner: Arc>, - world_size: usize, -} - -fn set_training_script_state(state: TrainingScriptState) -> PyResult<()> { - simnet_handle() - .map_err(|e| anyhow!(e))? - .set_training_script_state(state); - Ok(()) -} - -#[pymethods] -impl SimulatorClient { - #[new] - fn new(py: Python, system_addr: String, world_size: i32) -> PyResult { - signal_safe_block_on(py, async move { - simnet::start(); - - Ok(Self { - inner: Arc::new(Mutex::new( - TensorEngineSimulator::new( - system_addr - .parse::() - .map_err(|err| PyValueError::new_err(err.to_string()))?, - ) - .await - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, - )), - world_size: world_size as usize, - }) - })? - } - - fn kill_world(&self, py: Python, world_name: &str) -> PyResult<()> { - let simulator = self.inner.clone(); - let world_name = world_name.to_string(); - - signal_safe_block_on(py, async move { - simulator - .lock() - .await - .kill_world(&world_name) - .map_err(|err| anyhow!(err))?; - Ok(()) - })? - } - - fn spawn_mesh( - &self, - py: Python, - system_addr: &str, - controller_actor_id: &str, - worker_world: &str, - ) -> PyResult<()> { - let simulator = self.inner.clone(); - let world_size = self.world_size; - let system_addr = system_addr.parse::().unwrap(); - let worker_world = worker_world.parse::().unwrap(); - let controller_actor_id = controller_actor_id.parse().unwrap(); - - signal_safe_block_on(py, async move { - simulator - .lock() - .await - .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) - .await - .map_err(|err| anyhow!(err))?; - Ok(()) - })? 
- } - - fn set_training_script_state_running(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Running) - } - - fn set_training_script_state_waiting(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Waiting) - } -} - -pub(crate) fn register_python_bindings(simulator_client_mod: &Bound<'_, PyModule>) -> PyResult<()> { - simulator_client_mod.add_class::()?; - Ok(()) -} diff --git a/monarch_simulator/Cargo.toml b/monarch_simulator/Cargo.toml deleted file mode 100644 index f3f3732cd..000000000 --- a/monarch_simulator/Cargo.toml +++ /dev/null @@ -1,35 +0,0 @@ -# @generated by autocargo from //monarch/monarch_simulator:monarch_simulator_lib - -[package] -name = "monarch_simulator_lib" -version = "0.0.0" -authors = ["Meta"] -edition = "2021" -license = "BSD-3-Clause" - -[lib] -edition = "2024" - -[dependencies] -anyhow = "1.0.98" -async-trait = "0.1.86" -controller = { version = "0.0.0", path = "../controller" } -dashmap = { version = "5.5.3", features = ["rayon", "serde"] } -futures = { version = "0.3.31", features = ["async-await", "compat"] } -hyperactor = { version = "0.0.0", path = "../hyperactor" } -hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" } -lazy_static = "1.5" -monarch_messages = { version = "0.0.0", path = "../monarch_messages" } -monarch_tensor_worker = { version = "0.0.0", path = "../monarch_tensor_worker" } -monarch_types = { version = "0.0.0", path = "../monarch_types" } -ndslice = { version = "0.0.0", path = "../ndslice" } -serde = { version = "1.0.219", features = ["derive", "rc"] } -serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "raw_value", "unbounded_depth"] } -thiserror = "2.0.12" -tokio = { version = "1.47.1", features = ["full", "test-util", "tracing"] } -torch-sys = { version = "0.0.0", path = "../torch-sys" } -torch-sys-cuda = { version = "0.0.0", path = "../torch-sys-cuda" } -tracing = { version = "0.1.41", features = ["attributes", "valuable"] } - -[dev-dependencies] -tracing-test = { version = "0.2.3", features = ["no-env-filter"] } diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs deleted file mode 100644 index 58fc31c36..000000000 --- a/monarch_simulator/src/bootstrap.rs +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use anyhow::Result; -use controller::bootstrap::bootstrap_controller; -use hyperactor::ActorHandle; -use hyperactor::ActorId; -use hyperactor::ActorRef; -use hyperactor::ProcId; -use hyperactor::WorldId; -use hyperactor::channel::ChannelAddr; -use hyperactor::channel::sim::SimAddr; -use hyperactor_multiprocess::System; -use hyperactor_multiprocess::proc_actor::ProcActor; -use hyperactor_multiprocess::proc_actor::spawn; -use hyperactor_multiprocess::system::ServerHandle; -use hyperactor_multiprocess::system_actor::ProcLifecycleMode; -use monarch_messages::worker::Factory; -use torch_sys::Layout; -use torch_sys::ScalarType; - -use crate::worker::Fabric; -use crate::worker::MockWorkerParams; -use crate::worker::WorkerActor; - -/// spawns the system. 
-#[tracing::instrument("spawn_system")] -pub async fn spawn_system(system_addr: ChannelAddr) -> Result { - // TODO: pass in as args - let supervision_update_timeout = Duration::from_secs(120); - let world_eviction_timeout = Duration::from_secs(120); - - let handle = System::serve( - system_addr.clone(), - supervision_update_timeout, - world_eviction_timeout, - ) - .await?; - Ok(handle) -} - -/// Spawns the controller proc and actor. -#[tracing::instrument("spawn_controller")] -pub async fn spawn_controller( - bootstrap_addr: ChannelAddr, - controller_actor_id: ActorId, - worker_world_id: WorldId, - world_size: usize, -) -> anyhow::Result> { - let listen_addr = ChannelAddr::any(bootstrap_addr.transport()); - let ChannelAddr::Sim(bootstrap_addr) = bootstrap_addr else { - panic!("bootstrap_addr must be a SimAddr"); - }; - let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), - ); - tracing::info!( - "controller listen addr: {}, bootstrap addr: {}", - &listen_addr, - &bootstrap_addr - ); - - let worker_name = "worker"; - let supervision_query_interval = Duration::from_secs(2); - let supervision_update_interval = Duration::from_secs(2); - let worker_progress_check_interval = Duration::from_secs(10); - let operation_timeout = Duration::from_secs(120); - let operations_per_worker_progress_request = 100; - let proc_actor_handle = bootstrap_controller( - bootstrap_addr, - Some(listen_addr), - controller_actor_id, - world_size, - worker_world_id.clone(), - worker_name.to_string(), - supervision_query_interval, - supervision_update_interval, - worker_progress_check_interval, - operation_timeout, - operations_per_worker_progress_request, - None, /* extra_controller_labels */ - false, /* fail_on_worker_timeout */ - ) - .await?; - - Ok(proc_actor_handle) -} - -/// Spawns workers. Right now, only one mocked worker is spawned. TODO: spawn multiple workers. 
-#[tracing::instrument("spawn_worker")] -pub async fn spawn_sim_worker( - bootstrap_addr: ChannelAddr, - worker_world_id: WorldId, - controller_actor_id: ActorId, - world_size: usize, - rank: usize, -) -> anyhow::Result> { - let listen_addr = ChannelAddr::any(bootstrap_addr.transport()); - let worker_proc_id = ProcId::Ranked(worker_world_id.clone(), rank); - let worker_actor_id = ActorId(worker_proc_id.clone(), "worker".into(), 0); - - let ChannelAddr::Sim(bootstrap_addr) = bootstrap_addr else { - panic!("bootstrap_addr must be a SimAddr"); - }; - let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), - ); - tracing::info!( - "worker {} listen addr: {}, bootstrap addr: {}", - &worker_actor_id, - &listen_addr, - &bootstrap_addr - ); - - let supervision_update_interval = Duration::from_secs(10); - let bootstrap = ProcActor::bootstrap( - worker_proc_id, - worker_world_id, - listen_addr, - bootstrap_addr.clone(), - supervision_update_interval, - HashMap::new(), - ProcLifecycleMode::ManagedBySystem, - ) - .await?; - let mut system = hyperactor_multiprocess::System::new(bootstrap_addr); - let client = system.attach().await?; - let fabric = Arc::new(Fabric::new()); - let factory = Factory { - size: vec![2, 3], - dtype: ScalarType::Float, - layout: Layout::Strided, - device: "cpu".try_into().unwrap(), - }; - let controller_actor_ref = ActorRef::attest(controller_actor_id); - let params = MockWorkerParams::new( - worker_actor_id.rank(), - worker_actor_id.clone(), - fabric.clone(), - factory.clone(), - 2, - controller_actor_ref, - ); - let _worker_actor_ref = spawn::( - &client, - &bootstrap.proc_actor.bind(), - worker_actor_id.name(), - ¶ms, - ) - .await?; - Ok(bootstrap.proc_actor) -} diff --git a/monarch_simulator/src/collective_coordinator.rs b/monarch_simulator/src/collective_coordinator.rs deleted file mode 100644 index 073d83a11..000000000 --- a/monarch_simulator/src/collective_coordinator.rs +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use std::collections::HashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicUsize; -use std::sync::atomic::Ordering; - -use hyperactor::WorldId; -use lazy_static::lazy_static; -use tokio::sync::Mutex; -use tokio::sync::oneshot; -use tokio::sync::oneshot::Receiver; -use tokio::sync::oneshot::Sender; - -use crate::SimulatorError; - -lazy_static! { - /// A handle for SimNet through which you can send and schedule events in the - /// network. - static ref COLLECTIVE_COORDINATOR: CollectiveCoorindator = CollectiveCoorindator::new(); -} - -#[derive(Debug, PartialEq)] -enum MeshState { - Healthy, - Unhealthy, -} - -#[derive(Debug, PartialEq)] -pub enum CollectiveResult { - /// Collective has completed successfully - Done, - /// One or more peers are unavailable - PeerUnavailable, -} - -#[derive(Debug)] -struct CollectiveCoorindator { - // TODO(lky): revisit this in the future to support multiple workers in a mesh - /// A map from worker world id to mesh state - meshes: Arc>>, - collective_counter: AtomicUsize, - /// A flag to synchronize participants to the same phase. 
-    step: Arc<Mutex<String>>,
-    result_senders: Arc<Mutex<Vec<Sender<CollectiveResult>>>>,
-}
-
-impl CollectiveCoorindator {
-    fn new() -> Self {
-        Self {
-            meshes: Arc::new(Mutex::new(HashMap::new())),
-            collective_counter: AtomicUsize::new(0),
-            result_senders: Arc::new(Mutex::new(vec![])),
-            step: Arc::new(Mutex::new("".to_string())),
-        }
-    }
-
-    #[allow(dead_code)]
-    async fn deactivate_mesh(&self, world_id: WorldId) -> Result<(), SimulatorError> {
-        let mut meshes = self.meshes.lock().await;
-        let mesh_mut = meshes
-            .get_mut(&world_id)
-            .ok_or(SimulatorError::MeshNotFound(world_id.to_string()))?;
-        *mesh_mut = MeshState::Unhealthy;
-        Ok(())
-    }
-
-    async fn activate_mesh(&self, world_id: WorldId, step: &str) {
-        let mut current_step = self.step.lock().await;
-        if *current_step != step {
-            *current_step = step.to_string();
-            self.meshes.lock().await.clear();
-            self.collective_counter.store(0, Ordering::SeqCst);
-            self.result_senders.lock().await.clear();
-        }
-        self.meshes
-            .lock()
-            .await
-            .insert(world_id, MeshState::Healthy);
-    }
-
-    /// Run the collective. Once the collective is complete, the result will be sent to the result_tx channel.
-    async fn collect(&self) -> Receiver<CollectiveResult> {
-        let (result_tx, result_rx) = oneshot::channel::<CollectiveResult>();
-        self.collective_counter.fetch_add(1, Ordering::SeqCst);
-        self.result_senders.lock().await.push(result_tx);
-        // If any of the mesh is unhealthy, we should fail the collective.
-        if self
-            .meshes
-            .lock()
-            .await
-            .values()
-            .any(|mesh| mesh == &MeshState::Unhealthy)
-        {
-            for result_tx in self.result_senders.lock().await.drain(..) {
-                // Fail to send result back should not be reported back to the caller.
-                // Since the caller that triggers the send is the last caller.
-                if let Err(e) = result_tx.send(CollectiveResult::PeerUnavailable) {
-                    tracing::error!("failed to send result back to caller: {:?}", e);
-                }
-            }
-        }
-        if self.collective_counter.load(Ordering::SeqCst) == self.meshes.lock().await.len() {
-            self.collective_counter.store(0, Ordering::SeqCst);
-            for result_tx in self.result_senders.lock().await.drain(..) {
-                // Fail to send result back should not be reported back to the caller.
-                // Since the caller that triggers the send is the last caller.
- if let Err(e) = result_tx.send(CollectiveResult::Done) { - tracing::error!("failed to send result back to caller: {:?}", e); - } - } - } - result_rx - } - - async fn is_active(&self, world_id: WorldId) -> bool { - self.meshes - .lock() - .await - .get(&world_id) - .unwrap_or(&MeshState::Unhealthy) - == &MeshState::Healthy - } -} - -pub async fn activate_mesh(world_id: WorldId, step: &str) { - COLLECTIVE_COORDINATOR.activate_mesh(world_id, step).await; -} - -pub async fn collect() -> Receiver { - COLLECTIVE_COORDINATOR.collect().await -} - -#[allow(dead_code)] -pub async fn deactivate_mesh(world_id: WorldId) -> Result<(), SimulatorError> { - COLLECTIVE_COORDINATOR.deactivate_mesh(world_id).await -} - -pub async fn is_active(world_id: WorldId) -> bool { - COLLECTIVE_COORDINATOR.is_active(world_id).await -} - -#[cfg(test)] -mod tests { - - use std::time::Duration; - - use hyperactor::id; - use tokio::time::timeout; - - use super::*; - - #[tokio::test] - async fn test_collective_coordinator_success() { - let world_0 = id!(world_0); - let world_1 = id!(world_1); - - let collective_coordinator = CollectiveCoorindator::new(); - collective_coordinator.activate_mesh(world_0, "1").await; - collective_coordinator.activate_mesh(world_1, "1").await; - - let mut result_rx_0 = collective_coordinator.collect().await; - - // Assert that the collective will timeout after 1 second, since world_1 did not call for collect yet. - assert!( - timeout(Duration::from_secs(1), &mut result_rx_0) - .await - .is_err() - ); - - let result_rx_1 = collective_coordinator.collect().await; - - assert_eq!(result_rx_0.await.unwrap(), CollectiveResult::Done); - assert_eq!(result_rx_1.await.unwrap(), CollectiveResult::Done); - } - - #[tokio::test] - async fn test_collective_coordinator_failure() { - let world_0 = id!(world_0); - let world_1 = id!(world_1); - - let collective_coordinator = CollectiveCoorindator::new(); - collective_coordinator.activate_mesh(world_0, "1").await; - collective_coordinator - .activate_mesh(world_1.clone(), "1") - .await; - collective_coordinator - .deactivate_mesh(world_1) - .await - .unwrap(); - - let result_rx_0 = collective_coordinator.collect().await; - - assert_eq!( - result_rx_0.await.unwrap(), - CollectiveResult::PeerUnavailable - ); - } -} diff --git a/monarch_simulator/src/controller.rs b/monarch_simulator/src/controller.rs deleted file mode 100644 index bc6db4b82..000000000 --- a/monarch_simulator/src/controller.rs +++ /dev/null @@ -1,290 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! An implementation of mocked controller that can be used for simulation. 
- -use std::collections::HashMap; -use std::time::Duration; - -use async_trait::async_trait; -use hyperactor::Actor; -use hyperactor::ActorId; -use hyperactor::ActorRef; -use hyperactor::Context; -use hyperactor::Named; -use hyperactor::actor::ActorHandle; -use hyperactor::attrs::Attrs; -use hyperactor::channel::ChannelAddr; -use hyperactor::data::Serialized; -use hyperactor_multiprocess::proc_actor::ProcActor; -use hyperactor_multiprocess::proc_actor::spawn; -use hyperactor_multiprocess::system_actor::ProcLifecycleMode; -use monarch_messages::client::ClientActor; -use monarch_messages::client::ClientMessageClient; -use monarch_messages::controller::WorkerError; -use monarch_messages::controller::*; -use monarch_messages::debugger::DebuggerAction; -use monarch_messages::worker::Ref; -use monarch_messages::worker::WorkerMessage; -use serde::Deserialize; -use serde::Serialize; -use tokio::sync::OnceCell; - -use crate::worker::WorkerActor; - -#[derive(Debug)] -#[hyperactor::export( - spawn = true, - handlers = [ - ControllerMessage, - ], -)] -pub struct SimControllerActor { - client_actor_ref: OnceCell>, - worker_actor_ref: ActorRef, - /// A light weight map from seq to result. - history: HashMap>, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Named)] -pub struct SimControllerParams { - worker_actor_id: ActorId, -} - -impl SimControllerParams { - pub fn new(worker_actor_id: ActorId) -> Self { - // Only a single worker is created here to simulate a gang of workers. - Self { worker_actor_id } - } -} - -#[async_trait] -impl Actor for SimControllerActor { - type Params = SimControllerParams; - - async fn new(params: SimControllerParams) -> Result { - Ok(Self { - client_actor_ref: OnceCell::new(), - worker_actor_ref: ActorRef::attest(params.worker_actor_id), - history: HashMap::new(), - }) - } - - async fn init(&mut self, _this: &hyperactor::Instance) -> Result<(), anyhow::Error> { - Ok(()) - } -} - -#[async_trait] -#[hyperactor::forward(ControllerMessage)] -impl ControllerMessageHandler for SimControllerActor { - async fn attach( - &mut self, - _cx: &Context, - client_actor: ActorRef, - ) -> Result<(), anyhow::Error> { - self.client_actor_ref - .set(client_actor) - .map_err(|actor_ref| anyhow::anyhow!("client actor {} already attached", actor_ref)) - } - - async fn node( - &mut self, - _cx: &Context, - _seq: Seq, - _uses: Vec, - _defs: Vec, - ) -> Result<(), anyhow::Error> { - tracing::info!("controller node: uses: {:?}, defs: {:?}", _uses, _defs); - Ok(()) - } - - async fn send( - &mut self, - cx: &Context, - ranks: Ranks, - message: Serialized, - ) -> Result<(), anyhow::Error> { - tracing::info!("controller send to ranks {:?}: {}", ranks, message); - self.worker_actor_ref - .port::() - .send_serialized(cx, Attrs::new(), message); - Ok(()) - } - - async fn remote_function_failed( - &mut self, - _cx: &Context, - _seq: Seq, - _error: WorkerError, - ) -> Result<(), anyhow::Error> { - Ok(()) - } - - async fn status( - &mut self, - cx: &Context, - seq: Seq, - _worker_actor_id: ActorId, - controller: bool, - ) -> Result<(), anyhow::Error> { - tracing::info!( - "controller status: seq {}, worker_actor_id {:?}, controller: {:?}", - &seq, - _worker_actor_id, - controller - ); - tracing::info!("controller history in status(): {:?}", &self.history); - let result = self.history.remove(&seq).unwrap(); - let client = self.client_actor_ref.get().unwrap(); - if let Err(e) = client - .result(cx, seq.clone(), Some(Ok(result.unwrap()))) - .await - { - tracing::error!("controller failed to send result: 
{:?}", e); - } - Ok(()) - } - - async fn fetch_result( - &mut self, - _cx: &Context, - seq: Seq, - result: Result, - ) -> Result<(), anyhow::Error> { - tracing::info!( - "controller fetch result: seq {}, result {}", - &seq, - &result.as_ref().unwrap() - ); - self.history.insert(seq, result); - tracing::info!("controller history in fetch_result: {:?}", &self.history); - Ok(()) - } - - async fn check_supervision(&mut self, _cx: &Context) -> Result<(), anyhow::Error> { - Ok(()) - } - - async fn drop_refs( - &mut self, - _cx: &Context, - _refs: Vec, - ) -> Result<(), anyhow::Error> { - Ok(()) - } - - async fn debugger_message( - &mut self, - _cx: &Context, - _debugger_actor_id: ActorId, - _action: DebuggerAction, - ) -> Result<(), anyhow::Error> { - Ok(()) - } - - #[cfg(test)] - async fn get_first_incomplete_seqs_unit_tests_only( - &mut self, - _cx: &Context, - ) -> Result, anyhow::Error> { - Ok(vec![]) - } - - #[cfg(not(test))] - async fn get_first_incomplete_seqs_unit_tests_only( - &mut self, - _cx: &Context, - ) -> Result, anyhow::Error> { - unimplemented!("get_first_incomplete_seqs_unit_tests_only is only for unit tests") - } -} - -impl SimControllerActor { - /// Bootstrap the controller actor. This will create a new proc, join the system at `bootstrap_addr` - /// and spawn the controller actor into the proc. - pub async fn bootstrap( - controller_id: ActorId, - listen_addr: ChannelAddr, - bootstrap_addr: ChannelAddr, - params: SimControllerParams, - supervision_update_interval: Duration, - ) -> Result<(ActorHandle, ActorRef), anyhow::Error> { - let bootstrap = ProcActor::bootstrap( - controller_id.proc_id().clone(), - controller_id - .proc_id() - .world_id() - .expect("sim controller only works on ranked procs") - .clone(), // REFACTOR(marius): plumb world id through SimControllerActor::bootstrap - listen_addr, - bootstrap_addr.clone(), - supervision_update_interval, - HashMap::new(), - ProcLifecycleMode::ManagedBySystem, - ) - .await?; - - let mut system = hyperactor_multiprocess::System::new(bootstrap_addr); - let client = system.attach().await?; - - let controller_actor_ref = spawn::( - &client, - &bootstrap.proc_actor.bind(), - controller_id.clone().name(), - ¶ms, - ) - .await?; - - Ok((bootstrap.proc_actor, controller_actor_ref)) - } - - #[allow(dead_code)] - fn client(&self) -> Result, anyhow::Error> { - self.client_actor_ref - .get() - .ok_or_else(|| anyhow::anyhow!("client actor ref not set")) - .cloned() - } -} - -#[cfg(test)] -mod tests { - use hyperactor::channel::ChannelTransport; - use hyperactor::id; - use hyperactor_multiprocess::System; - - use super::*; - - #[tokio::test] - async fn test_bootstrap() { - let server_handle = System::serve( - ChannelAddr::any(ChannelTransport::Local), - Duration::from_secs(10), - Duration::from_secs(10), - ) - .await - .unwrap(); - - let controller_id = id!(controller[0].root); - let (proc_handle, actor_ref) = SimControllerActor::bootstrap( - controller_id.clone(), - ChannelAddr::any(ChannelTransport::Local), - server_handle.local_addr().clone(), - SimControllerParams { - worker_actor_id: id!(worker[0].root), - }, - Duration::from_secs(1), - ) - .await - .unwrap(); - assert_eq!(*actor_ref.actor_id(), controller_id); - - proc_handle.drain_and_stop().unwrap(); - } -} diff --git a/monarch_simulator/src/lib.rs b/monarch_simulator/src/lib.rs deleted file mode 100644 index 8706bcaf3..000000000 --- a/monarch_simulator/src/lib.rs +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use hyperactor::actor::ActorError; -use hyperactor::simnet::SimNetError; - -pub mod bootstrap; -mod collective_coordinator; -pub mod controller; -pub mod simulator; -pub mod worker; - -/// The type of error that can occur on channel operations. -#[derive(thiserror::Error, Debug)] -pub enum SimulatorError { - /// Error during simnet operation. - #[error(transparent)] - SimNetError(#[from] SimNetError), - - /// Error during actor operations. - #[error(transparent)] - ActorError(#[from] ActorError), - - /// Simulator cannot find the world with given name. - #[error("World {0} not found")] - WorldNotFound(String), - - /// Cannot find the mesh in simulator. - #[error("Mesh not found {0}")] - MeshNotFound(String), -} diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs deleted file mode 100644 index 174e0d133..000000000 --- a/monarch_simulator/src/simulator.rs +++ /dev/null @@ -1,165 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -use std::collections::HashMap; -use std::future::IntoFuture; - -use anyhow::Result; -use futures::FutureExt; -use futures::future::BoxFuture; -use hyperactor::ActorHandle; -use hyperactor::ActorId; -use hyperactor::WorldId; -use hyperactor::channel::ChannelAddr; -use hyperactor_multiprocess::proc_actor::ProcActor; -use hyperactor_multiprocess::system::ServerHandle; - -use crate::SimulatorError; -use crate::bootstrap::spawn_controller; -use crate::bootstrap::spawn_sim_worker; -use crate::bootstrap::spawn_system; - -/// The simulator manages all of the meshes and the system handle. -#[derive(Debug)] -pub struct TensorEngineSimulator { - /// A map from world name to actor handles in that world. - worlds: HashMap>>, - system_handle: ServerHandle, -} - -impl TensorEngineSimulator { - pub async fn new(system_addr: ChannelAddr) -> Result { - Ok(Self { - worlds: HashMap::new(), - system_handle: spawn_system(system_addr).await?, - }) - } - - pub async fn spawn_mesh( - &mut self, - system_addr: ChannelAddr, - controller_actor_id: ActorId, - worker_world_id: WorldId, - world_size: usize, - ) -> Result<()> { - let controller = spawn_controller( - system_addr.clone(), - controller_actor_id.clone(), - worker_world_id.clone(), - world_size, - ) - .await?; - self.worlds.insert( - controller_actor_id.world_name().to_string(), - vec![controller], - ); - - for rank in 0..world_size { - let worker = spawn_sim_worker( - system_addr.clone(), - worker_world_id.clone(), - controller_actor_id.clone(), - world_size, - rank, - ) - .await?; - self.worlds - .entry(worker_world_id.name().to_string()) - .or_insert(vec![]) - .push(worker); - } - Ok(()) - } - - /// Kills the actors within the given world. - /// Returns error if there's no world found in the current simulator. - pub fn kill_world(&mut self, world_name: &str) -> Result<(), SimulatorError> { - let actors = self - .worlds - .remove(world_name) - .ok_or(SimulatorError::WorldNotFound(world_name.to_string()))?; - for actor in actors { - actor.drain_and_stop()?; - } - Ok(()) - } -} - -/// IntoFuture allows users to await the handle. The future resolves when -/// the simulator itself has all of the actor handles stopped. 
-impl IntoFuture for TensorEngineSimulator { - type Output = (); - type IntoFuture = BoxFuture<'static, Self::Output>; - - fn into_future(self) -> Self::IntoFuture { - let future = async move { - self.system_handle.await; - for actors in self.worlds.into_values() { - for actor in actors { - actor.await; - } - } - }; - - future.boxed() - } -} - -#[cfg(test)] -mod tests { - use hyperactor::ActorId; - use hyperactor::ProcId; - use hyperactor::WorldId; - use hyperactor::channel::ChannelAddr; - use hyperactor::simnet; - - #[tracing_test::traced_test] - #[tokio::test] - async fn test_spawn_and_kill_mesh() { - simnet::start(); - - let system_addr = "sim!unix!@system".parse::().unwrap(); - let mut simulator = super::TensorEngineSimulator::new(system_addr.clone()) - .await - .unwrap(); - let mut controller_actor_ids = vec![]; - let mut worker_actor_ids = vec![]; - let n_meshes = 2; - for i in 0..n_meshes { - let controller_world_name = format!("controller_world_{}", i); - let worker_world_name = format!("worker_world_{}", i); - controller_actor_ids.push(ActorId( - ProcId::Ranked(WorldId(controller_world_name), 0), - "root".into(), - 0, - )); - worker_actor_ids.push(ActorId( - ProcId::Ranked(WorldId(worker_world_name.clone()), 0), - "root".into(), - 0, - )); - simulator - .spawn_mesh( - system_addr.clone(), - controller_actor_ids.last().unwrap().clone(), - WorldId(worker_world_name), - 1, - ) - .await - .unwrap(); - } - - assert_eq!(simulator.worlds.len(), n_meshes * 2); - let world_name = controller_actor_ids[0].world_name(); - let controller_actor_handle = simulator.worlds.get(world_name).unwrap().first().unwrap(); - assert_eq!(controller_actor_handle.actor_id().world_name(), world_name); - - simulator.kill_world(world_name).unwrap(); - assert_eq!(simulator.worlds.len(), n_meshes * 2 - 1); - } -} diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs deleted file mode 100644 index 42a79d6fb..000000000 --- a/monarch_simulator/src/worker.rs +++ /dev/null @@ -1,941 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -use std::collections::HashMap; -use std::ops::Add; -use std::sync::Arc; - -use anyhow::Context; -use anyhow::Result; -use anyhow::bail; -use anyhow::ensure; -use async_trait::async_trait; -use dashmap::DashMap; -use hyperactor::Actor; -use hyperactor::ActorRef; -use hyperactor::Named; -use hyperactor::data::Serialized; -use hyperactor::forward; -use hyperactor::reference::ActorId; -use hyperactor::simnet::TorchOpEvent; -use hyperactor::simnet::simnet_handle; -use monarch_messages::controller::ControllerActor; -use monarch_messages::controller::ControllerMessageClient; -use monarch_messages::controller::Seq; -use monarch_messages::controller::WorkerError; -use monarch_messages::wire_value::WireValue; -use monarch_messages::worker::*; -use monarch_tensor_worker::device_mesh::DeviceMesh; -use monarch_types::PyTree; -use ndslice::Slice; -use serde::Deserialize; -use serde::Serialize; -use tokio::sync::Mutex; -use tokio::sync::mpsc; -use tokio::sync::oneshot; -use torch_sys::Device; -use torch_sys::DeviceType; -use torch_sys::Layout; -use torch_sys::RValue; -use torch_sys::ScalarType; -use torch_sys::Tensor; -use torch_sys::TensorCell; -use torch_sys::factory_empty; -use torch_sys::factory_zeros; -use torch_sys_cuda::nccl::NcclConfig; -use torch_sys_cuda::nccl::ReduceOp; -use torch_sys_cuda::nccl::UniqueId; - -use crate::collective_coordinator::CollectiveResult; -use crate::collective_coordinator::activate_mesh; -use crate::collective_coordinator::collect; -use crate::collective_coordinator::is_active; - -type Channel = ( - mpsc::UnboundedSender, - Arc>>, -); - -/// A fake backend network to support sending tensors between nodes. -#[derive(Debug, Deserialize, Serialize)] -pub struct Fabric { - #[serde(skip)] - inputs: DashMap<(StreamRef, usize), Channel>, - #[serde(skip)] - outputs: DashMap<(StreamRef, usize), Channel>, -} - -const PAFT_RECONFIG_FTAR_FCN: &str = "paft.paft_worker.reconfig_ftar"; -const PAFT_RUN_ALLREDUCE_FCN: &str = "paft.paft_worker.run_allreduce"; - -impl Fabric { - pub fn new() -> Self { - Self { - inputs: DashMap::new(), - outputs: DashMap::new(), - } - } - - fn put_input(&self, stream: StreamRef, from: usize, tensor: TensorCell) -> Result<()> { - let sender = self - .inputs - .entry((stream, from)) - .or_insert_with(|| { - let (s, r) = mpsc::unbounded_channel(); - (s, Arc::new(Mutex::new(r))) - }) - .0 - .clone(); - sender.send(tensor)?; - Ok(()) - } - - async fn get_input(&self, stream: StreamRef, from: usize) -> Result { - let recv = self - .inputs - .entry((stream, from)) - .or_insert_with(|| { - let (s, r) = mpsc::unbounded_channel(); - (s, Arc::new(Mutex::new(r))) - }) - .1 - .clone(); - let mut recv = recv.lock().await; - recv.recv().await.context("channel closed") - } - - fn put_output(&self, stream: StreamRef, to: usize, tensor: TensorCell) -> Result<()> { - let sender = self - .outputs - .entry((stream, to)) - .or_insert_with(|| { - let (s, r) = mpsc::unbounded_channel(); - (s, Arc::new(Mutex::new(r))) - }) - .0 - .clone(); - sender.send(tensor)?; - Ok(()) - } - - async fn get_output(&self, stream: StreamRef, to: usize) -> Result { - let recv = self - .outputs - .entry((stream, to)) - .or_insert_with(|| { - let (s, r) = mpsc::unbounded_channel(); - (s, Arc::new(Mutex::new(r))) - }) - .1 - .clone(); - let mut recv = recv.lock().await; - recv.recv().await.context("channel closed") - } -} - -fn reduce_op>( - _op: ReduceOp, - _output: &mut Tensor, - _inputs: &[TensorCell], -) -> Result<()> { - // TODO(agallagher): Do we need an impl for this? 
- Ok(()) -} - -#[derive(Debug)] -#[hyperactor::export( - spawn = true, - handlers = [ - WorkerMessage { cast = true }, - ], -)] -pub struct WorkerActor { - rank: usize, - worker_actor_id: ActorId, - fabric: Arc, - /// factory to use to create fake tensors. - factory: Factory, - /// the dimensions used for fake tensors. - dims: i64, - device_meshes: HashMap, - env: HashMap, - pipes: HashMap, - controller_actor_ref: ActorRef, - worker_error: Option, -} - -#[derive(Clone, Debug, Named, Serialize, Deserialize)] -pub struct MockWorkerParams { - rank: usize, - worker_actor_id: ActorId, - fabric: Arc, - factory: Factory, - dims: i64, - controller_actor_ref: ActorRef, -} - -impl MockWorkerParams { - pub fn new( - rank: usize, - worker_actor_id: ActorId, - fabric: Arc, - factory: Factory, - dims: i64, - controller_actor_ref: ActorRef, - ) -> Self { - Self { - rank, - worker_actor_id, - fabric, - factory, - dims, - controller_actor_ref, - } - } -} - -#[async_trait] -impl Actor for WorkerActor { - type Params = MockWorkerParams; - - async fn new( - MockWorkerParams { - rank, - worker_actor_id, - fabric, - factory, - dims, - controller_actor_ref, - }: Self::Params, - ) -> Result { - Ok(Self { - rank, - worker_actor_id, - fabric, - factory, - dims, - device_meshes: HashMap::new(), - env: HashMap::new(), - pipes: HashMap::new(), - controller_actor_ref, - worker_error: None, - }) - } -} - -#[async_trait] -#[forward(WorkerMessage)] -impl WorkerMessageHandler for WorkerActor { - async fn backend_network_init( - &mut self, - _cx: &hyperactor::Context, - _unique_id: UniqueId, - ) -> Result<()> { - Ok(()) - } - - async fn backend_network_point_to_point_init( - &mut self, - _cx: &hyperactor::Context, - _from_stream: StreamRef, - _to_stream: StreamRef, - ) -> Result<()> { - Ok(()) - } - - async fn call_function( - &mut self, - cx: &hyperactor::Context, - params: CallFunctionParams, - ) -> Result<()> { - tracing::info!("worker received call_function: {:#?}", ¶ms); - match ¶ms.function { - ResolvableFunction::FunctionPath(FunctionPath { path }) => { - tracing::info!("function path: {:#?}", &path); - if path == PAFT_RECONFIG_FTAR_FCN { - let step = match params.kwargs.get("step").unwrap() { - WireValue::PyObject(step) => Some(step.clone()), - _ => None, - }; - let step = step.unwrap(); - let serialized_step = serde_json::to_string(&step).unwrap(); - activate_mesh(self.worker_actor_id.world_name().parse()?, &serialized_step) - .await; - } - if path == PAFT_RUN_ALLREDUCE_FCN { - if !is_active(self.worker_actor_id.world_name().parse()?).await { - // Controller will send supervision failure message to controller. - panic!("worker is killed by user"); - } - let rx = collect().await; - let collective_result = rx.await; - if collective_result.unwrap() == CollectiveResult::PeerUnavailable { - // Send worker error to controller. 
- let worker_error = WorkerError { - backtrace: "AllReduce failed".to_string(), - worker_actor_id: self.worker_actor_id.clone(), - }; - self.worker_error = Some(worker_error); - return Ok(()); - } - } - } - _ => {} - } - for result in params.results.into_iter() { - if let Some(result) = result { - self.env.insert(result, self.mock_tensor()?); - } - } - match ¶ms.function.as_torch_op() { - Some((op, _)) => { - self.call_torch_op(op, params.args, params.kwargs, cx.self_id().clone()) - .await?; - } - _ => { - let _ = self.call_python_fn( - cx, - params.function, - params.args, - params.kwargs, - ¶ms.mutates, - ); - } - } - Ok(()) - } - - async fn send_result_of_actor_call( - &mut self, - _cx: &hyperactor::Context, - _params: ActorCallParams, - ) -> Result<()> { - bail!("unimplemented: send_result_of_actor_call"); - } - - async fn call_actor_method( - &mut self, - _cx: &hyperactor::Context, - _params: ActorMethodParams, - ) -> Result<()> { - bail!("unimplemented: call_actor_method"); - } - - async fn command_group( - &mut self, - cx: &hyperactor::Context, - params: Vec, - ) -> Result<()> { - for msg in params { - WorkerMessageHandler::handle(self, cx, msg).await?; - } - Ok(()) - } - - async fn create_stream( - &mut self, - _cx: &hyperactor::Context, - _result: StreamRef, - _creation_mode: StreamCreationMode, - ) -> Result<()> { - Ok(()) - } - - async fn create_device_mesh( - &mut self, - _cx: &hyperactor::Context, - result: Ref, - names: Vec, - ranks: Slice, - ) -> Result<()> { - self.device_meshes - .insert(result, DeviceMesh::new(names, ranks, self.rank)?); - Ok(()) - } - - async fn create_remote_process_group( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - _device_mesh: Ref, - _dims: Vec, - ) -> Result<()> { - bail!("unimplemented: create_remote_process_group") - } - - async fn borrow_create( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - _borrow_id: u64, - _tensor_ref: Ref, - _from_stream: StreamRef, - _to_stream: StreamRef, - ) -> Result<()> { - bail!("unimplemented: borrow_create") - } - - async fn borrow_first_use( - &mut self, - _cx: &hyperactor::Context, - _borrow: u64, - ) -> Result<()> { - bail!("unimplemented: borrow_first_use") - } - - async fn borrow_last_use( - &mut self, - _cx: &hyperactor::Context, - _borrow: u64, - ) -> Result<()> { - bail!("unimplemented: borrow_last_use") - } - - async fn borrow_drop( - &mut self, - _cx: &hyperactor::Context, - _borrow_id: u64, - ) -> Result<()> { - bail!("unimplemented: borrow_drop") - } - - async fn delete_refs( - &mut self, - _cx: &hyperactor::Context, - _refs: Vec, - ) -> Result<()> { - Ok(()) - } - - async fn request_status( - &mut self, - cx: &hyperactor::Context, - seq: Seq, - controller: bool, - ) -> Result<()> { - ControllerMessageClient::status( - &self.controller_actor_ref, - cx, - seq.next(), - cx.self_id().clone(), - controller, - ) - .await?; - Ok(()) - } - - async fn reduce( - &mut self, - _cx: &hyperactor::Context, - result: Ref, - local_tensor: Ref, - factory: Factory, - source_mesh: Ref, - stream_ref: StreamRef, - dims: Vec, - reduction: Reduction, - _scatter: bool, - in_place: bool, - out: Option, - ) -> Result<()> { - ensure!( - factory == self.factory, - "{:?} != {:?}", - factory, - self.factory - ); - - let mesh = self - .device_meshes - .get(&source_mesh) - .context("no such mesh")?; - let input_cell = self.env.get(&local_tensor).context("no such tensor")?; - - // Push input cell onto fabric. 
- self.fabric - .put_input(stream_ref, self.rank, input_cell.clone())?; - - let ranks_for_group = mesh.get_ranks_for_dim_slice(&dims)?; - - // Currentl impl has first rank doing all the work. - if self.rank == ranks_for_group[0] { - let mut inputs = Vec::new(); - for rank in ranks_for_group.iter() { - inputs.push(self.fabric.get_input(stream_ref, *rank).await?); - } - - // Create space for the result. - let sizes = [&[dims.len() as i64][..], &self.factory.size[..]].concat(); - let mut result = factory_empty( - &sizes, - self.factory.dtype, - self.factory.layout, - self.factory.device, - ); - - match reduction { - Reduction::ReduceOp(op) => match self.factory.dtype { - ScalarType::Float => reduce_op::(op, &mut result, inputs.as_slice())?, - ScalarType::Int => reduce_op::(op, &mut result, inputs.as_slice())?, - _ => bail!("unimplemented reduce op"), - }, - _ => bail!("unimplemented reduction"), - } - - let result_cell = TensorCell::new(result); - for rank in ranks_for_group.iter() { - self.fabric - .put_output(stream_ref, *rank, result_cell.clone())?; - } - } - - // Emit results. - let result_cell = self.fabric.get_output(stream_ref, self.rank).await?; - let result_tensor = result_cell.try_borrow().map_err(anyhow::Error::msg)?; - let output_cell = match (out, in_place) { - (None, false) => TensorCell::new(torch_sys::deep_clone(&result_tensor)), - (None, true) => { - let input = input_cell.try_borrow_mut().map_err(anyhow::Error::msg)?; - // SAFETY: ... - unsafe { - std::ptr::copy_nonoverlapping( - result_tensor.data_ptr(), - input.mut_data_ptr(), - input.nbytes(), - ) - }; - input_cell.clone() - } - _ => bail!("unimplemented output style"), - }; - self.env.insert(result, output_cell); - - Ok(()) - } - - async fn create_pipe( - &mut self, - _cx: &hyperactor::Context, - result: Ref, - _key: String, - _function: ResolvableFunction, - _max_messages: i64, - device_mesh: Ref, - _args: Vec, - _kwargs: HashMap, - ) -> Result<()> { - self.pipes.insert(result, device_mesh); - Ok(()) - } - - async fn send_tensor( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - _from_ranks: Slice, - _to_ranks: Slice, - _tensor: Ref, - _factory: Factory, - _from_stream: StreamRef, - _to_stream: StreamRef, - ) -> Result<()> { - bail!("unimplemented: send_tensor") - } - - async fn exit( - &mut self, - _cx: &hyperactor::Context, - _error: Option<(Option, String)>, - ) -> Result<()> { - Ok(()) - } - - async fn send_value( - &mut self, - cx: &hyperactor::Context, - seq: Seq, - _destination: Option, - _mutates: Vec, - _function: Option, - _args: Vec, - _kwargs: HashMap, - _stream: StreamRef, - ) -> Result<()> { - tracing::info!("worker received send_value"); - if let Some(worker_error) = self.worker_error.take() { - self.controller_actor_ref - .remote_function_failed(cx, seq, worker_error) - .await?; - return Ok(()); - } - - let tensor = factory_zeros( - &[1], - ScalarType::Float, - Layout::Strided, - Device::new(DeviceType { repr: 0 }), - ); - let rvalue = RValue::Tensor(TensorCell::new(tensor)); - let value = PyTree::from(rvalue); - let result = Ok(Serialized::serialize(&value)?); - self.controller_actor_ref - .fetch_result(cx, seq, result) - .await?; - Ok(()) - } - - async fn split_comm( - &mut self, - _cx: &hyperactor::Context, - _dims: Vec, - _device_mesh: Ref, - _stream_ref: StreamRef, - _config: Option, - ) -> Result<()> { - Ok(()) - } - - async fn split_comm_for_process_group( - &mut self, - _cx: &hyperactor::Context, - _remote_process_group_ref: Ref, - _stream_ref: StreamRef, - _config: Option, - ) -> 
Result<()> { - Ok(()) - } - - async fn pipe_recv( - &mut self, - _cx: &hyperactor::Context, - _seq: Seq, - results: Vec>, - pipe: Ref, - _stream: StreamRef, - ) -> Result<()> { - let mesh = self - .device_meshes - .get(self.pipes.get(&pipe).context("missing pipe")?) - .context("missing mesh")?; - ensure!(mesh.sizes().len() as i64 == self.dims); - for result in results.into_iter() { - if let Some(result) = result { - self.env.insert(result, self.mock_tensor()?); - } - } - Ok(()) - } - - async fn set_ref_unit_tests_only( - &mut self, - _cx: &hyperactor::Context, - _reference: Ref, - _value: WireValue, - _stream: StreamRef, - ) -> Result<()> { - bail!("unimplemented: set_ref_unit_tests_only") - } - - async fn get_ref_unit_tests_only( - &mut self, - _cx: &hyperactor::Context, - _ref_id: Ref, - _stream: StreamRef, - ) -> Result>> { - bail!("unimplemented: get_ref_unit_tests_only") - } - - async fn define_recording( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - _nresults: usize, - _nformals: usize, - _commands: Vec, - _ntotal_messages: usize, - _index: usize, - ) -> Result<()> { - unimplemented!() - } - - async fn recording_formal( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - _argument_index: usize, - _stream: StreamRef, - ) -> Result<()> { - unimplemented!() - } - - async fn recording_result( - &mut self, - _cx: &hyperactor::Context, - _result: Ref, - _output_index: usize, - _stream: StreamRef, - ) -> Result<()> { - unimplemented!() - } - - async fn call_recording( - &mut self, - _cx: &hyperactor::Context, - _seq: Seq, - _recording: Ref, - _results: Vec, - _actuals: Vec, - ) -> Result<()> { - unimplemented!() - } -} - -impl WorkerActor { - fn mock_tensor(&self) -> Result { - let sizes = [&[self.dims][..], &self.factory.size[..]].concat(); - let tensor = factory_empty( - &sizes, - self.factory.dtype, - self.factory.layout, - self.factory.device, - ); - Ok(TensorCell::new(tensor)) - } - - async fn call_torch_op( - &self, - op: &str, - args: Vec, - kwargs: HashMap, - actor_id: ActorId, - ) -> Result<()> { - let args_string = args - .iter() - .filter(|&wirevalue| wirevalue.is_ref()) - .map(|wirevalue| wirevalue.as_ref().unwrap().to_string()) - .collect::>() - .join(", "); - - let kwargs_string = kwargs - .iter() - .filter_map(|(k, wirevalue)| { - wirevalue - .is_ref() - .then(|| format!("{}={}", k, wirevalue.as_ref().unwrap())) - }) - .collect::>() - .join(", "); - - let (tx, rx) = oneshot::channel(); - - simnet_handle()? 
- .send_event(TorchOpEvent::new( - op.to_string(), - tx, - args_string, - kwargs_string, - actor_id, - )) - .unwrap(); - - rx.await.unwrap(); - - Ok(()) - } - - fn call_python_fn( - &mut self, - _cx: &hyperactor::Context, - _function: ResolvableFunction, - _args: Vec, - _kwargs: HashMap, - _mutates: &[Ref], - ) -> Result<()> { - bail!("unimplemented: call_python_fn") - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use anyhow::Result; - use futures::future::try_join_all; - use hyperactor::id; - use hyperactor::proc::Proc; - use hyperactor::simnet; - use monarch_types::PyTree; - use torch_sys::Layout; - use torch_sys::RValue; - use torch_sys::TensorCell; - use torch_sys::test_make_tensor; - - use super::*; - - #[tokio::test] - async fn test_all_reduce() -> Result<()> { - let proc = Proc::local(); - let _client = proc.attach("client")?; - - let world_size = 4; - let fabric = Arc::new(Fabric::new()); - let factory = Factory { - size: vec![2, 3], - dtype: ScalarType::Float, - layout: Layout::Strided, - device: "cpu".try_into()?, - }; - - let mut workers = vec![]; - for rank in 0..world_size { - workers.push( - proc.spawn::( - &format!("worker{}", rank), - MockWorkerParams { - rank, - worker_actor_id: id!(worker[0].root), - fabric: fabric.clone(), - factory: factory.clone(), - dims: 2, - controller_actor_ref: ActorRef::attest(id!(controller[0].root)), - }, - ) - .await?, - ); - } - - for worker in workers.into_iter() { - worker.drain_and_stop()?; - worker.await; - } - - Ok(()) - } - - #[tokio::test] - async fn worker_reduce() -> Result<()> { - simnet::start(); - let proc = Proc::local(); - let (client, _handle) = proc.instance("client")?; - - let world_size = 4; - let fabric = Arc::new(Fabric::new()); - let factory = Factory { - size: vec![2, 3], - dtype: ScalarType::Float, - layout: Layout::Strided, - device: "cpu".try_into()?, - }; - - let workers = try_join_all((0..world_size).map(async |rank| { - proc.spawn::( - &format!("worker{}", rank), - MockWorkerParams { - rank, - worker_actor_id: id!(worker[0].root), - fabric: fabric.clone(), - factory: factory.clone(), - dims: 2, - controller_actor_ref: ActorRef::attest(id!(controller[0].root)), - }, - ) - .await - })) - .await?; - - let unique_id = UniqueId::new()?; - let messages = vec![ - WorkerMessage::BackendNetworkInit(unique_id.clone()), - WorkerMessage::CreateStream { - id: 0.into(), - stream_creation: StreamCreationMode::UseDefaultStream, - }, - WorkerMessage::CreateDeviceMesh { - result: 1.into(), - names: vec!["x".into(), "y".into()], - ranks: Slice::new(0, vec![2, 2], vec![2, 1])?, - }, - WorkerMessage::CallFunction(CallFunctionParams { - seq: 0.into(), - results: vec![Some(2.into())], - mutates: vec![], - function: "torch.ops.aten.ones.default".into(), - args: vec![WireValue::IntList(vec![2, 3])], - kwargs: HashMap::from([("device".into(), WireValue::Device("cuda".try_into()?))]), - stream: 0.into(), - remote_process_groups: vec![], - }), - // Test reduce over "x". 
- WorkerMessage::Reduce { - result: 3.into(), - tensor: 2.into(), - factory: factory.clone(), - mesh: 1.into(), - stream: 0.into(), - dims: vec!["x".to_string()], - reduction: Reduction::ReduceOp(ReduceOp::Sum), - scatter: false, - in_place: false, - out: None, - }, - WorkerMessage::CallFunction(CallFunctionParams { - seq: 1.into(), - results: vec![Some(4.into())], - mutates: vec![], - function: "torch.ops.aten.full.default".into(), - args: vec![WireValue::IntList(vec![2, 3]), WireValue::Double(2.0)], - kwargs: HashMap::from([("device".into(), WireValue::Device("cuda".try_into()?))]), - stream: 0.into(), - remote_process_groups: vec![], - }), - WorkerMessage::CallFunction(CallFunctionParams { - seq: 1.into(), - results: vec![Some(5.into())], - mutates: vec![], - function: "torch.ops.aten.allclose.default".into(), - args: vec![WireValue::Ref(3.into()), WireValue::Ref(4.into())], - kwargs: HashMap::new(), - stream: 0.into(), - remote_process_groups: vec![], - }), - // Test reduce over "x" and "y". - WorkerMessage::Reduce { - result: 6.into(), - tensor: 2.into(), - factory, - mesh: 1.into(), - stream: 0.into(), - dims: vec!["x".into(), "y".into()], - reduction: Reduction::ReduceOp(ReduceOp::Sum), - scatter: false, - in_place: false, - out: None, - }, - ]; - - for worker in workers.iter() { - worker.command_group(&client, messages.clone()).await?; - } - - for worker in workers.into_iter() { - worker.drain_and_stop()?; - worker.await; - } - - Ok(()) - } - - #[tokio::test] - async fn test_create_tensor_pytree() -> Result<()> { - // A simple test case to show how to create a tensor pytree. - let tensor = test_make_tensor(); - let rvalue = RValue::Tensor(TensorCell::new(tensor)); - let _pytree = PyTree::from(&rvalue); - Ok(()) - } -} diff --git a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi deleted file mode 100644 index c2d529bd2..000000000 --- a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from typing import final - -@final -class SimulatorClient: - """ - A wrapper around [simulator_client::Simulatorclient] to expose it to python. - It is a client to communicate with the simulator service. - - Arguments: - - `system_addr`: Address of the system. - - `world_size`: Number of workers in a given mesh. - """ - - def __init__(self, system_addr: str, world_size: int) -> None: ... - def kill_world(self, world_name: str) -> None: - """ - Kill the world with the given name. - - Arguments: - - `world_name`: Name of the world to kill. - """ - ... - def spawn_mesh( - self, system_addr: str, controller_actor_id: str, worker_world: str - ) -> None: - """ - Spawn a mesh actor. - - Arguments: - - `system_addr`: Address of the system to spawn the mesh in. - - `controller_actor_id`: Actor id of the controller to spawn the mesh in. - - `worker_world`: World of the worker to spawn the mesh in. - """ - ... - - def set_training_script_state_running(self) -> None: - """ - Let the simulator know that the training script is actively sending - commands to the backend - """ - ... 
- - def set_training_script_state_waiting(self) -> None: - """ - Let the simulator know that the training script is waiting for the - backend to resolve a future - """ - ... diff --git a/python/monarch/_testing.py b/python/monarch/_testing.py index 1c75bd7a9..80e11cd16 100644 --- a/python/monarch/_testing.py +++ b/python/monarch/_testing.py @@ -12,7 +12,6 @@ from contextlib import contextmanager, ExitStack from typing import Any, Callable, Dict, Generator, Literal, Optional -import monarch_supervisor from monarch._src.actor.endpoint import Extent from monarch._src.actor.host_mesh import create_local_host_mesh from monarch._src.actor.proc_mesh import proc_mesh, ProcMesh @@ -21,10 +20,7 @@ from monarch.common.client import Client from monarch.common.device_mesh import DeviceMesh from monarch.common.invocation import DeviceException, RemoteException -from monarch.controller.backend import ProcessBackend from monarch.mesh_controller import spawn_tensor_engine -from monarch.python_local_mesh import PythonLocalContext -from monarch.rust_local_mesh import LoggingLocation, ProcessCache from monarch.simulator.mock_controller import MockController @@ -48,31 +44,8 @@ class TestingContext: def __init__(self): self.cleanup = ExitStack() self._py_process_cache = {} - self._rust_process_cache = None self._proc_mesh_cache: Dict[Any, ProcMesh] = {} - @contextmanager - def _get_context(self, num_hosts, gpu_per_host): - # since we are local, there isn't a lot of latency involved. - # Make the host managers exit if they go 0.5 seconds without - # hearing from supervisor. - monarch_supervisor.HEARTBEAT_INTERVAL = 1 - ctx = PythonLocalContext(N=num_hosts) - store = ProcessBackend._create_store() - processes = ProcessBackend._create_pg( - ctx.ctx, ctx.hosts, gpu_per_host, store, _restartable=True - ) - yield ctx.ctx, ctx.hosts, processes - ctx.shutdown() - - def _processes(self, num_hosts, gpu_per_host): - key = (num_hosts, gpu_per_host) - if key not in self._py_process_cache: - self._py_process_cache[key] = self.cleanup.enter_context( - self._get_context(num_hosts, gpu_per_host) - ) - return self._py_process_cache[key] - @contextmanager def local_engine_on_proc_mesh( self, @@ -129,12 +102,6 @@ def __enter__(self): self._log_dir = self.cleanup.enter_context( tempfile.TemporaryDirectory(prefix="rust_cached_workers.") ) - self._rust_process_cache = self.cleanup.enter_context( - ProcessCache( - logging_location=LoggingLocation.DEFAULT, - logging_dir=self._log_dir, - ) - ) end = time.time() logging.info("started process caches in {:.2f}s".format(end - start)) return self diff --git a/python/monarch/common/constants.py b/python/monarch/common/constants.py index 9e2c869d9..581f84e46 100644 --- a/python/monarch/common/constants.py +++ b/python/monarch/common/constants.py @@ -5,6 +5,3 @@ # LICENSE file in the root directory of this source tree. 
# pyre-strict - -SIM_MESH_CLIENT_TIMEOUT = 5 -SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL = 5 diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 7c06af957..f927e009d 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -67,7 +67,6 @@ from monarch.common.device_mesh import DeviceMesh from monarch.common.future import Future as OldFuture from monarch.common.invocation import DeviceException, RemoteException -from monarch.rust_local_mesh import _get_worker_exec_info logger: Logger = logging.getLogger(__name__) @@ -116,8 +115,6 @@ def stop_mesh(self): def _initialize_env(worker_point: Point, proc_id: str) -> None: worker_rank = worker_point.rank try: - _, worker_env = _get_worker_exec_info() - if "gpus" in worker_point: local_rank = worker_point["gpus"] gpus_per_host = worker_point.size("gpus") @@ -130,7 +127,6 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None: num_worker_procs = worker_point.extent.nelements process_env = { - **worker_env, "CUDA_VISIBLE_DEVICES": str(local_rank), "NCCL_HOSTID": f"{proc_id}_host_{worker_rank // gpus_per_host}", # This is needed to avoid a hard failure in ncclx when we do not diff --git a/python/monarch/rust_local_mesh.py b/python/monarch/rust_local_mesh.py deleted file mode 100644 index e783d3627..000000000 --- a/python/monarch/rust_local_mesh.py +++ /dev/null @@ -1,1401 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import contextlib -import importlib.resources -import logging -import os -import random -import re -import select -import socket -import string -import subprocess -import sys -import tempfile -import threading -import time -import uuid -from enum import Enum -from pathlib import Path -from types import TracebackType -from typing import ( - Callable, - Collection, - Dict, - Generator, - List, - NamedTuple, - Optional, - TextIO, - Tuple, - Type, - TypeVar, -) - -from monarch._rust_bindings.controller.bootstrap import ( - ControllerCommand, - ControllerServerRequest, - ControllerServerResponse, - RunCommand, -) - -from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension - ActorId, -) - -from monarch._rust_bindings.monarch_tensor_worker.bootstrap import ( - WorkerServerRequest, - WorkerServerResponse, -) - -from monarch.common.device_mesh import DeviceMesh -from monarch.common.fake import fake_call -from monarch.common.invocation import DeviceException, RemoteException -from monarch.rust_backend_mesh import ( - IBootstrap, - MeshWorld, - PoolDeviceMeshProvider, - rust_backend_mesh_provider, - rust_backend_meshes, -) - -logger: logging.Logger = logging.getLogger(__name__) -_MONARCH_TENSOR_WORKER_MAIN = "monarch.tensor_worker_main" - -try: - from __manifest__ import fbmake # noqa - - IN_PAR = bool(fbmake.get("par_style")) -except ImportError: - IN_PAR = False - - -class SocketType(Enum): - """Enum representing socket types.""" - - TCP = "tcp" - UNIX = "unix" - - -class LoggingLocation(Enum): - """Enum representing where to flush stderr and stdout.""" - - DEFAULT = "default" - FILE = "file" - - -class SupervisionParams(NamedTuple): - # If system actor does not receive supervision update within this time, - # it will treate this proc as dead. 
- update_timeout_in_sec: int - # How often controller queries supervision status from system actor. - query_interval_in_sec: int - # How often proc actor sends supervision update to system actor. - update_interval_in_sec: int - - -class ControllerParams(NamedTuple): - # How often the controller will poll for operations that have not completed within a timeout duration - # indicating that it may be stuck. - worker_progress_check_interval_in_sec: int - - # How long we will wait for an operation before letting the client know that it may be stuck. - operation_timeout_in_sec: int - - # The number of operations invoked before we proactively check worker progress. If a large number - # of operations are invoked all at once, it is expected that it will take a while for all operations - # to complete so we want to inject progress requests at a higher frequency to check if we are making progress - operations_per_worker_progress_request: int - - # If the controller should propagate a failure to the client if the workers become stuck. - fail_on_worker_timeout: bool - - -_PROC_ENV: dict[str, str] = {} - - -def get_controller_main() -> tuple[Path, dict[str, str]]: - with ( - importlib.resources.as_file( - importlib.resources.files("monarch") / "monarch_controller" - ) as controller_main, - ): - if not controller_main.exists(): - if IN_PAR: - raise ImportError( - "Monarch env not found, please define a custom 'monarch_env' or " - "add '//monarch/python/monarch:default_env-library' to your binary dependencies " - "in TARGETS" - ) - else: - raise ImportError( - "Monarch env not found, please re-run ./scripts/install.sh in fbcode/monarch" - ) - env: dict[str, str] = {} - - # Hack to make exploded wheel workflow work in the face of broken - # build-time RPATHs... - # - # If we're running under a conda env... - if not IN_PAR: - conda_prefix = os.environ.get("CONDA_PREFIX") - if conda_prefix is not None and sys.executable.startswith( - conda_prefix + "/" - ): - # .. and Monarch is coming from "outside" the env, via `PYTHONPATH`s ... - spec = importlib.util.find_spec("monarch") - assert spec is not None - origin = spec.origin - assert origin is not None - monarch_root = str(Path(origin).parent.parent) - if ( - not monarch_root.startswith(conda_prefix + "/") - and monarch_root in sys.path - ): - import torch - - # then assume we're running via exploded .whl, which means - # we need to manually set library paths to find the necessary - # native libs from the conda env. 
- env["LD_LIBRARY_PATH"] = ":".join( - [ - os.path.join(os.path.dirname(torch.__file__), "lib"), - os.path.join(conda_prefix, "lib"), - ] - ) - - return controller_main, env - - -def _create_logging_locations( - logging_dir: str, name: str, logging_location: LoggingLocation -) -> tuple[TextIO | None, TextIO | None]: - if logging_location == LoggingLocation.FILE: - stdout_file: TextIO = open(os.path.join(logging_dir, f"{name}.stdout"), "a+") - stderr_file: TextIO = open(os.path.join(logging_dir, f"{name}.stderr"), "a+") - return stdout_file, stderr_file - elif logging_location == LoggingLocation.DEFAULT: - return None, None - else: - raise ValueError(f"Unknown logging location: {logging_location}") - - -def _get_labels(flag_name: str, labels: Dict[str, str]) -> List[str]: - params = [] - for k, v in labels.items(): - assert k not in params, f"Duplicate label: {k}" - assert "=" not in k, f"Key cannot contain '=': {k}" - params.append(f"--{flag_name}") - params.append(f"{k}={v}") - return params - - -def _start_worker_cmd( - *, - world_uuid: str, - worker_rank: int, - gpus_per_host: int, - num_worker_procs: int, - args: list[str], - env: dict[str, str] | None = None, - stdout: TextIO | None = None, - stderr: TextIO | None = None, - stdin: TextIO | int | None = subprocess.DEVNULL, - pass_fds: Collection[int] = (), -) -> subprocess.Popen[bytes]: - worker_cmd, worker_env = _get_worker_exec_info() - local_rank = worker_rank % gpus_per_host - process_env = { - **(_PROC_ENV | worker_env), - "CUDA_VISIBLE_DEVICES": str(local_rank), - "NCCL_HOSTID": f"{world_uuid}_host_{worker_rank // gpus_per_host}", - # This is needed to avoid a hard failure in ncclx when we do not - # have backend topology info (eg. on RE). - "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true", - "LOCAL_RANK": str(local_rank), - "RANK": str(worker_rank), - "WORLD_SIZE": str(num_worker_procs), - "LOCAL_WORLD_SIZE": str(gpus_per_host), - **os.environ, - } - cmd = [] - cmd.extend(worker_cmd) - cmd.extend(args) - if env is not None: - process_env.update(env) - return subprocess.Popen( - cmd, - env=process_env, - stdout=stdout, - stderr=stderr, - stdin=stdin, - pass_fds=pass_fds, - ) - - -ServerT = TypeVar("ServerT") - - -class ServerInstance: - TIMEOUT = 10.0 - - def __init__( - self, - *, - server: "ServerBase[ServerT]", - ) -> None: - self._server = server - self._terminated: float = 0 - - # TODO - assert self._server._proc is not None - self.pid: int = self._server._proc.pid - - def __enter__(self) -> "ServerInstance": - return self - - def terminate(self) -> None: - # Start the timeout clock now. 
- self._terminated = time.time() - - def kill(self) -> None: - pass - - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - timeout = max(0, self._terminated + self.TIMEOUT - time.time()) - try: - self._server._finish(timeout=timeout) - except Exception as exc: - if exc_type is None: - raise - else: - logger.warning(f"failed waiting for instance to finish: {exc}") - - -class ServerBase(contextlib.AbstractContextManager[ServerT, None]): - def __init__( - self, - *, - name: str, - response_cls: Type[WorkerServerResponse | ControllerServerResponse], - request_cls: Type[WorkerServerRequest | ControllerServerRequest], - ) -> None: - self._name = name - self._response_cls: Type[WorkerServerResponse | ControllerServerResponse] = ( - response_cls - ) - self._request_cls: Type[WorkerServerRequest | ControllerServerRequest] = ( - request_cls - ) - - self._aborted = False - self._shutdown_started = False - self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack() - self._proc: subprocess.Popen[bytes] | None = None - self._pipe: Tuple[TextIO, TextIO] | None = None - self._lock: threading.Lock | None = None - - def _send(self, msg: WorkerServerRequest | ControllerServerRequest) -> None: - logger.debug(f"{self._name}: sending server request: {msg}") - assert not self._aborted - assert self._lock is not None - if not self._lock.acquire(blocking=False): - raise Exception("server in use") - assert self._pipe is not None - self._pipe[1].write(msg.to_json() + "\n") - assert self._pipe is not None - self._pipe[1].flush() - - def _recv( - self, timeout: float | None = None - ) -> WorkerServerResponse | ControllerServerResponse: - assert not self._aborted - assert self._lock is not None - assert self._lock.locked() - assert self._pipe is not None - ready, _, _ = select.select([self._pipe[0]], [], [], timeout) - if not ready: - assert self._proc is not None - assert timeout is not None - raise subprocess.TimeoutExpired(self._proc.args, timeout) - output = ready[0].readline() - logger.info(f"{self._name}: Got response: {output}") - response = self._response_cls.from_json(output) - assert self._lock is not None - self._lock.release() - logger.debug(f"{self._name}: received response: {response}") - return response - - def _launch_server( - self, - read_fd: int, - write_fd: int, - ) -> subprocess.Popen[bytes]: - raise NotImplementedError() - - def __enter__(self) -> ServerT: - assert self._proc is None, "already running" - logger.debug(f"{self._name}: launching worker server") - self._lock = threading.Lock() - send = os.pipe2(0) - recv = os.pipe2(0) - self._proc = self._contexts.enter_context( - self._launch_server( - read_fd=send[0], - write_fd=recv[1], - ), - ) - self._pipe = ( - self._contexts.enter_context(os.fdopen(recv[0], "r")), - self._contexts.enter_context(os.fdopen(send[1], "w")), - ) - os.close(send[0]) - os.close(recv[1]) - # pyre-ignore: Incompatible return type [7] - return self - - def initiate_shutdown(self) -> None: - if not self._shutdown_started and not self._aborted: - assert self._lock is not None - assert not self._lock.locked() - self._shutdown_started = True - self._send(self._request_cls.Exit()) - assert self._pipe is not None - self._pipe[1].close() - - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - if exc_type is not None or self._aborted: - assert self._proc is not None - self._proc.kill() - 
else: - # attempt a clean shutdown - self.initiate_shutdown() - assert self._proc is not None - assert self._proc.wait(timeout=5) == 0 - self._contexts.__exit__(exc_type, exc_val, exc_tb) - - def _finish(self, timeout: float | None = None) -> None: - try: - response = self._recv(timeout=timeout) - assert isinstance(response, self._response_cls.Finished), str(response) - # pyre-ignore: Undefined attribute [16] - assert response.error is None, response.error - except: - self._aborted = True - raise - - def _launch_instance( - self, - *, - msg: WorkerServerRequest | ControllerServerRequest, - ) -> ServerInstance: - self._send(msg) - return ServerInstance(server=self) - - -class ISystemFactory: - def launch( - self, - *, - bootstrap_addr: str, - supervision_params: SupervisionParams, - ) -> ServerInstance | subprocess.Popen[bytes]: - raise NotImplementedError() - - -class IControllerFactory: - def launch( - self, - *, - worker_world: str, - bootstrap_addr: str, - controller_id: ActorId, - num_worker_procs: int, - gpus_per_host: int, - supervision_params: SupervisionParams, - controller_params: ControllerParams, - labels: Dict[str, str], - ) -> subprocess.Popen[bytes] | ServerInstance: - raise NotImplementedError() - - -class ControllerFactoryBase: - def __init__( - self, - *, - logging_location: LoggingLocation, - logging_dir: str, - ) -> None: - self.logging_location = logging_location - self.logging_dir = logging_dir - - self.controller_main: Path - self.controller_env: dict[str, str] - self.controller_main, self.controller_env = get_controller_main() - - -class SystemFactory(ControllerFactoryBase, ISystemFactory): - def launch( - self, - *, - bootstrap_addr: str, - supervision_params: SupervisionParams, - ) -> subprocess.Popen[bytes]: - stdout_location, stderr_location = _create_logging_locations( - self.logging_dir, - "system", - self.logging_location, - ) - return subprocess.Popen( - [ - self.controller_main, - "system", - "--system-addr", - bootstrap_addr, - "--supervision-update-timeout-in-sec", - str(supervision_params.update_timeout_in_sec), - ], - stdout=stdout_location, - stderr=stderr_location, - stdin=subprocess.DEVNULL, - env=_PROC_ENV | self.controller_env, - ) - - -class ControllerFactory(ControllerFactoryBase, IControllerFactory): - def launch( - self, - *, - worker_world: str, - bootstrap_addr: str, - controller_id: ActorId, - num_worker_procs: int, - gpus_per_host: int, - supervision_params: SupervisionParams, - controller_params: ControllerParams, - labels: Dict[str, str], - ) -> subprocess.Popen[bytes]: - stdout_location, stderr_location = _create_logging_locations( - self.logging_dir, - controller_id.world_name, - self.logging_location, - ) - command = [ - self.controller_main, - "controller", - "--worker-world", - worker_world, - "--system-addr", - bootstrap_addr, - "--controller-actor-id", - str(controller_id), - "--world-size", - str(num_worker_procs), - "--num-procs-per-host", - str(gpus_per_host), - "--supervision-query-interval-in-sec", - str(supervision_params.query_interval_in_sec), - "--supervision-update-interval-in-sec", - str(supervision_params.update_interval_in_sec), - "--worker-progress-check-interval-in-sec", - str(controller_params.worker_progress_check_interval_in_sec), - "--operation-timeout-in-sec", - str(controller_params.operation_timeout_in_sec), - "--operations-per-worker-progress-request", - str(controller_params.operations_per_worker_progress_request), - ] - - if controller_params.fail_on_worker_timeout: - 
command.append("--fail-on-worker-timeout") - - return subprocess.Popen( - command + _get_labels("extra-proc-labels", labels), - stdout=stdout_location, - stderr=stderr_location, - stdin=subprocess.DEVNULL, - env=_PROC_ENV | self.controller_env, - ) - - -class ControllerServerBase(ServerBase[ServerT]): - def __init__( - self, - *, - uuid: str, - logging_location: LoggingLocation, - logging_dir: str, - ) -> None: - super().__init__( - name=uuid, - response_cls=ControllerServerResponse, - request_cls=ControllerServerRequest, - ) - self.uuid = uuid - self.logging_location = logging_location - self.logging_dir = logging_dir - - self.controller_main: Path - self.controller_env: dict[str, str] - self.controller_main, self.controller_env = get_controller_main() - - def _launch_server( - self, - read_fd: int, - write_fd: int, - ) -> subprocess.Popen[bytes]: - stdout_location, stderr_location = _create_logging_locations( - self.logging_dir, - self.uuid, - self.logging_location, - ) - return subprocess.Popen( - [ - self.controller_main, - "serve", - str(read_fd), - str(write_fd), - ], - stdout=stdout_location, - pass_fds=(read_fd, write_fd), - stderr=stderr_location, - stdin=subprocess.DEVNULL, - env=_PROC_ENV | self.controller_env | dict(os.environ), - ) - - -class SystemServer(ControllerServerBase["SystemServer"], ISystemFactory): - def launch( - self, - *, - bootstrap_addr: str, - supervision_params: SupervisionParams, - ) -> ServerInstance: - return self._launch_instance( - msg=ControllerServerRequest.Run( - RunCommand.System( - system_addr=bootstrap_addr, - supervision_update_timeout_in_sec=supervision_params.update_timeout_in_sec, - world_eviction_timeout_in_sec=10, - ), - ), - ) - - -class ControllerServer(ControllerServerBase["ControllerServer"], IControllerFactory): - def launch( - self, - *, - worker_world: str, - bootstrap_addr: str, - controller_id: ActorId, - num_worker_procs: int, - gpus_per_host: int, - supervision_params: SupervisionParams, - controller_params: ControllerParams, - labels: Dict[str, str], - ) -> ServerInstance: - return self._launch_instance( - msg=ControllerServerRequest.Run( - RunCommand.Controller( - ControllerCommand( - worker_world=worker_world, - system_addr=bootstrap_addr, - controller_actor_id=str(controller_id), - world_size=num_worker_procs, - num_procs_per_host=gpus_per_host, - worker_name="worker", - program=None, - supervision_query_interval_in_sec=supervision_params.query_interval_in_sec, - supervision_update_interval_in_sec=supervision_params.update_interval_in_sec, - worker_progress_check_interval_in_sec=controller_params.worker_progress_check_interval_in_sec, - operation_timeout_in_sec=controller_params.operation_timeout_in_sec, - operations_per_worker_progress_request=controller_params.operations_per_worker_progress_request, - fail_on_worker_timeout=controller_params.fail_on_worker_timeout, - is_cpu_worker=False, - extra_proc_labels=list(labels.items()), - ), - ), - ), - ) - - -class IWorkerFactory: - def launch( - self, - *, - worker_world: str, - worker_rank: int, - bootstrap_addr: str, - labels: Dict[str, str], - ) -> ServerInstance | subprocess.Popen[bytes]: - raise NotImplementedError() - - -class WorkerFactory(IWorkerFactory): - def __init__( - self, - *, - num_worker_procs: int, - gpus_per_host: int, - logging_location: LoggingLocation, - logging_dir: str, - ) -> None: - self.num_worker_procs = num_worker_procs - self.gpus_per_host = gpus_per_host - self.logging_location = logging_location - self.logging_dir = logging_dir - - def launch( - 
self, - *, - worker_world: str, - worker_rank: int, - bootstrap_addr: str, - labels: Dict[str, str], - ) -> subprocess.Popen[bytes]: - stdout_location, stderr_location = _create_logging_locations( - self.logging_dir, - f"{worker_world}_{worker_rank}", - self.logging_location, - ) - return _start_worker_cmd( - world_uuid=worker_world, - worker_rank=worker_rank, - gpus_per_host=self.gpus_per_host, - num_worker_procs=self.num_worker_procs, - args=[ - "worker", - "--world-id", - worker_world, - "--proc-id", - f"{worker_world}[{worker_rank}]", - "--bootstrap-addr", - bootstrap_addr, - ] - + _get_labels("extra-proc-labels", labels), - stdout=stdout_location, - stderr=stderr_location, - ) - - -class WorkerServer(ServerBase["WorkerServer"]): - def __init__( - self, - *, - uuid: str, - num_worker_procs: int, - gpus_per_host: int, - world_rank: int, - logging_location: LoggingLocation, - logging_dir: str, - ) -> None: - super().__init__( - name=uuid, - response_cls=WorkerServerResponse, - request_cls=WorkerServerRequest, - ) - self.uuid = uuid - self.num_worker_procs = num_worker_procs - self.gpus_per_host = gpus_per_host - self.world_rank = world_rank - self.logging_location = logging_location - self.logging_dir = logging_dir - - def _launch_server( - self, - read_fd: int, - write_fd: int, - ) -> subprocess.Popen[bytes]: - stdout_location, stderr_location = _create_logging_locations( - self.logging_dir, - f"{self.uuid}_{self.world_rank}", - self.logging_location, - ) - return _start_worker_cmd( - world_uuid=self.uuid, - worker_rank=self.world_rank, - gpus_per_host=self.gpus_per_host, - num_worker_procs=self.num_worker_procs, - args=["worker-server", str(read_fd), str(write_fd)], - pass_fds=(read_fd, write_fd), - stdin=subprocess.PIPE, - stdout=stdout_location, - stderr=stderr_location, - ) - - def launch( - self, - *, - worker_world: str, - bootstrap_addr: str, - labels: Dict[str, str], - ) -> ServerInstance: - return self._launch_instance( - msg=WorkerServerRequest.Run( - world_id=worker_world, - proc_id=f"{worker_world}[{self.world_rank}]", - bootstrap_addr=bootstrap_addr, - labels=list(labels.items()), - ) - ) - - -class WorkerServers(IWorkerFactory): - def __init__( - self, - *, - workers: dict[int, WorkerServer], - ) -> None: - self._workers = workers - self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack() - - @staticmethod - def create( - uuid: str, - num_worker_procs: int, - gpus_per_host: int, - logging_location: LoggingLocation, - logging_dir: str, - ) -> "WorkerServers": - return WorkerServers( - workers={ - world_rank: WorkerServer( - uuid=uuid, - num_worker_procs=num_worker_procs, - gpus_per_host=gpus_per_host, - world_rank=world_rank, - logging_location=logging_location, - logging_dir=logging_dir, - ) - for world_rank in range(num_worker_procs) - }, - ) - - def initiate_shutdown(self) -> None: - for worker in self._workers.values(): - worker.initiate_shutdown() - - def __enter__(self) -> "WorkerServers": - for worker in self._workers.values(): - self._contexts.enter_context(worker) - return self - - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - self.initiate_shutdown() - self._contexts.__exit__(exc_type, exc_val, exc_tb) - - def launch( - self, - *, - worker_world: str, - worker_rank: int, - bootstrap_addr: str, - labels: Dict[str, str], - ) -> ServerInstance | subprocess.Popen[bytes]: - return self._workers[worker_rank].launch( - worker_world=worker_world, - 
bootstrap_addr=bootstrap_addr, - labels=labels, - ) - - -class ProcessCache: - def __init__( - self, - *, - logging_location: LoggingLocation, - logging_dir: str, - ) -> None: - self.logging_location: LoggingLocation = logging_location - self.logging_dir: str = logging_dir - - self._system_cache: SystemServer | None = None - self._controller_cache: ControllerServer | None = None - self._worker_cache: dict[Tuple[int, int], WorkerServers] = {} - self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack() - - def __enter__(self) -> "ProcessCache": - return self - - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - if self._system_cache is not None: - self._system_cache.initiate_shutdown() - if self._controller_cache is not None: - self._controller_cache.initiate_shutdown() - for workers in self._worker_cache.values(): - workers.initiate_shutdown() - self._contexts.__exit__(exc_type, exc_val, exc_tb) - - def get_system_server(self) -> SystemServer: - if self._system_cache is None: - system = SystemServer( - uuid="cached_system", - logging_location=self.logging_location, - logging_dir=self.logging_dir, - ) - self._system_cache = self._contexts.enter_context(system) - assert self._system_cache is not None - return self._system_cache - - def get_controller_server(self) -> ControllerServer: - if self._controller_cache is None: - controller = ControllerServer( - uuid="cached_controller", - logging_location=self.logging_location, - logging_dir=self.logging_dir, - ) - self._controller_cache = self._contexts.enter_context(controller) - assert self._controller_cache is not None - return self._controller_cache - - def get_worker_servers( - self, - *, - num_worker_procs: int, - gpus_per_host: int, - ) -> WorkerServers: - key = (num_worker_procs, gpus_per_host) - workers = self._worker_cache.get(key) - if workers is None: - workers = WorkerServers.create( - uuid=f"cached_workers_{num_worker_procs}_{gpus_per_host}", - num_worker_procs=num_worker_procs, - gpus_per_host=gpus_per_host, - logging_location=self.logging_location, - logging_dir=self.logging_dir, - ) - self._worker_cache[key] = self._contexts.enter_context(workers) - return workers - - -class Bootstrap: - def __init__( - self, - *, - meshes: int, - hosts_per_mesh: int, - gpus_per_host: int, - worker_factory: IWorkerFactory | None = None, - controller_factory: IControllerFactory | None = None, - system_factory: ISystemFactory | None = None, - socket_type: SocketType, - logging_location: LoggingLocation, - supervision_params: SupervisionParams | None, - controller_params: ControllerParams | None, - auto_epoch: bool, - controller_labels: Dict[str, str] | None = None, - worker_labels: Dict[str, str] | None = None, - ) -> None: - if supervision_params is None: - supervision_params = SupervisionParams( - update_timeout_in_sec=20, - query_interval_in_sec=2, - update_interval_in_sec=2, - ) - self.supervision_params: SupervisionParams = supervision_params - - if controller_params is None: - controller_params = ControllerParams( - worker_progress_check_interval_in_sec=10, - operation_timeout_in_sec=120, - operations_per_worker_progress_request=100, - fail_on_worker_timeout=False, - ) - self.controller_params: ControllerParams = controller_params - - self.epoch: int | None = 0 if auto_epoch else None - - # hyperactor_telemetry will take the execution id and use it across all processes - execution_id = "rust_local_" + uuid.uuid4().hex - 
os.environ["HYPERACTOR_EXECUTION_ID"] = execution_id - - # Create a temporary directory for logging - self.logging_dir: str = ( - tempfile.mkdtemp(prefix="rust_local_mesh_") - if logging_location == LoggingLocation.FILE - else "N/A" - ) - logger.info( - f"Creating Rust local mesh with {meshes} meshes X {hosts_per_mesh} hosts X {gpus_per_host} gpus.\n" - f"Logging directory: \033[92;1m{self.logging_dir}\033[0m\n" - f"Execution id: {execution_id}" - ) - self.logging_location: LoggingLocation = logging_location - - if controller_factory is None: - controller_factory = ControllerFactory( - logging_location=self.logging_location, - logging_dir=self.logging_dir, - ) - self.controller_factory: IControllerFactory = controller_factory - - if system_factory is None: - system_factory = SystemFactory( - logging_location=self.logging_location, - logging_dir=self.logging_dir, - ) - self.system_factory: ISystemFactory = system_factory - - # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later - if worker_factory is None: - worker_factory = WorkerFactory( - num_worker_procs=hosts_per_mesh * gpus_per_host, - gpus_per_host=gpus_per_host, - logging_location=self.logging_location, - logging_dir=self.logging_dir, - ) - self.worker_factory: IWorkerFactory = worker_factory - - # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later - fake_call(lambda: 0) - - self.bootstrap_addr: str - if socket_type == SocketType.TCP: - with socket.socket() as sock: - sock.bind(("", 0)) - port = sock.getsockname()[1] - self.bootstrap_addr = f"tcp![::1]:{port}" - elif socket_type == SocketType.UNIX: - # provide a random unix socket address - self.bootstrap_addr: str = f"unix!@{''.join(random.choice(string.ascii_lowercase) for _ in range(14))}-system" - else: - raise ValueError(f"Unknown socket type: {socket_type}") - - env = os.environ.copy() - self.env: dict[str, str] = env - - # Launch a single system globally - self.processes: list[subprocess.Popen[bytes] | ServerInstance] = [] - self.processes.append(self._launch_system()) - - self.has_shutdown: bool = False - self.gpus_per_host: int = gpus_per_host - self.num_worker_procs: int = hosts_per_mesh * gpus_per_host - self.controller_ids: list[ActorId] = [] - self.mesh_worlds: dict[ - MeshWorld, list[subprocess.Popen[bytes] | ServerInstance] - ] = {} - - # Create meshes, each of which contains a single controller and multiple workers. - # All of them will connect to the same system. 
- pids: dict[str, list[int]] = {} - for i in range(meshes): - mesh_name: str = f"mesh_{i}" - controller_world: str = f"{mesh_name}_controller" - worker_world: str = f"{mesh_name}_worker" - controller_id: ActorId = ActorId( - world_name=controller_world, - rank=0, - actor_name="controller", - ) - self.mesh_worlds[(worker_world, controller_id)] = [] - self.controller_ids.append(controller_id) - - processes: list[subprocess.Popen[bytes] | ServerInstance] = ( - self.launch_mesh( - controller_id, - worker_world, - controller_labels=controller_labels, - worker_labels=worker_labels, - ) - ) - - self.processes.extend(processes) - pids[mesh_name] = [p.pid for p in processes] - - log_message = ( - f"All processes started successfully:\n system: {self.processes[0].pid}\n" - ) - for mesh, procs in pids.items(): - log_message += f"{mesh}: controller: {procs[0]}, " - worker_messages = [] - for i in range(1, len(procs)): - worker_messages.append(f"{i-1}: {procs[i]}") - log_message += "workers: " + ", ".join(worker_messages) - log_message += "\n" - - self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack() - - logger.info(log_message) - - def _launch_system( - self, - ) -> ServerInstance | subprocess.Popen[bytes]: - logger.info("launching system") - try: - return self.system_factory.launch( - bootstrap_addr=self.bootstrap_addr, - supervision_params=self.supervision_params, - ) - except Exception as e: - logger.error(f"Failed to start system process: {e}") - raise e - - def _launch_controller( - self, - controller_id: ActorId, - worker_world: str, - epoch: str | None = None, - labels: Dict[str, str] | None = None, - ) -> subprocess.Popen[bytes] | ServerInstance: - logger.info("launching controller") - try: - return self.controller_factory.launch( - bootstrap_addr=self.bootstrap_addr, - worker_world=worker_world - if epoch is None - else f"{worker_world}_{epoch}", - controller_id=ActorId.from_string( - ( - f"{controller_id.world_name + '_' + epoch if epoch else controller_id.world_name}" - f"[{controller_id.rank}]." 
- f"{controller_id.actor_name}[{controller_id.pid}]" - ) - ), - num_worker_procs=self.num_worker_procs, - gpus_per_host=self.gpus_per_host, - supervision_params=self.supervision_params, - controller_params=self.controller_params, - labels={} if labels is None else labels, - ) - except Exception as e: - logger.error(f"Failed to start controller process: {e}") - raise e - - def _launch_worker( - self, - worker_world: str, - worker_rank: int, - epoch: str | None = None, - labels: Dict[str, str] | None = None, - ) -> subprocess.Popen[bytes] | ServerInstance: - logger.info("launching worker") - try: - return self.worker_factory.launch( - worker_world=worker_world - if epoch is None - else f"{worker_world}_{epoch}", - worker_rank=worker_rank, - bootstrap_addr=self.bootstrap_addr, - labels={} if labels is None else labels, - ) - except Exception as e: - logger.error(f"Failed to start worker process {worker_rank}: {e}") - raise e - - def get_mesh_worlds(self) -> list[MeshWorld]: - return list(self.mesh_worlds.keys()) - - def kill_mesh(self, mesh_world: MeshWorld) -> None: - logger.info(f"Killing mesh {mesh_world}") - procs = self.mesh_worlds[mesh_world] - procs[-1].kill() - - def spawn_mesh(self, mesh_world: MeshWorld) -> None: - self.launch_mesh(mesh_world[1], mesh_world[0]) - - def launch_mesh( - self, - controller_id: ActorId, - worker_world: str, - controller_labels: Dict[str, str] | None = None, - worker_labels: Dict[str, str] | None = None, - ) -> list[subprocess.Popen[bytes] | ServerInstance]: - """ - Create a single controller and multiple workers for a mesh. - The first process of the return is the controller. - The remaining ones are workers. - """ - logger.info( - f"Launching mesh {worker_world} with controller {controller_id} epoch {self.epoch}" - ) - epoch: str | None = None - if self.epoch is not None: - epoch = f"epoch_{self.epoch}" - self.epoch += 1 - - processes: list[subprocess.Popen[bytes] | ServerInstance] = [] - controller_process = self._launch_controller( - controller_id, - worker_world, - epoch, - controller_labels, - ) - processes.append(controller_process) - - for i in range(self.num_worker_procs): - worker_process = self._launch_worker(worker_world, i, epoch, worker_labels) - processes.append(worker_process) - self.mesh_worlds[(worker_world, controller_id)] = processes - return processes - - def __enter__(self) -> "Bootstrap": - for process in self.processes: - self._contexts.enter_context(process) - return self - - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - for process in self.processes: - process.terminate() - self._contexts.__exit__(exc_type, exc_val, exc_tb) - - -def _local_device_count() -> int: - dev_path = Path("/dev") - pattern = re.compile(r"nvidia\d+$") - nvidia_devices = [dev for dev in dev_path.iterdir() if pattern.match(dev.name)] - return len(nvidia_devices) - - -def _get_worker_exec_info() -> tuple[list[str], dict[str, str]]: - if IN_PAR: - cmd = [sys.argv[0]] - env = { - "PAR_MAIN_OVERRIDE": _MONARCH_TENSOR_WORKER_MAIN, - } - else: - cmd = [sys.executable, "-m", _MONARCH_TENSOR_WORKER_MAIN] - env = {} - - env["MONARCH_TENSOR_WORKER_MAIN"] = _MONARCH_TENSOR_WORKER_MAIN - env["MONARCH_TENSOR_WORKER_EXE"] = cmd[0] - return cmd, env - - -@contextlib.contextmanager -def local_mesh( - *, - hosts: int = 1, - gpus_per_host: int | None = None, - socket_type: SocketType = SocketType.TCP, - logging_location: LoggingLocation = LoggingLocation.FILE, - supervision_params: 
SupervisionParams | None = None, - controller_params: ControllerParams | None = None, - worker_factory: IWorkerFactory | None = None, - controller_factory: IControllerFactory | None = None, - system_factory: ISystemFactory | None = None, -) -> Generator[DeviceMesh, None, None]: - """ - Creates a single local device mesh with the given number of per host. - - Args: - hosts : number of hosts, primarily used for simulating multiple machines locally. - Default: 1 - gpus_per_host : number of gpus per host. - Default: the number of GPUs this machine has. - socket_type : socket type to use for communication between processes. - Default: TCP. - - Example:: - with local_mesh().activate(): - x = torch.rand(3, 4) - local_tensor = fetch_shard(x).result() - """ - with local_meshes( - meshes=1, - hosts_per_mesh=hosts, - gpus_per_host=gpus_per_host, - socket_type=socket_type, - logging_location=logging_location, - supervision_params=supervision_params, - controller_params=controller_params, - worker_factory=worker_factory, - controller_factory=controller_factory, - system_factory=system_factory, - ) as dms: - assert len(dms) == 1 - yield dms[0] - - -@contextlib.contextmanager -def local_meshes( - *, - meshes: int = 1, - hosts_per_mesh: int = 1, - gpus_per_host: int | None = None, - socket_type: SocketType = SocketType.TCP, - logging_location: LoggingLocation = LoggingLocation.FILE, - supervision_params: SupervisionParams | None = None, - controller_params: ControllerParams | None = None, - worker_factory: IWorkerFactory | None = None, - controller_factory: IControllerFactory | None = None, - system_factory: ISystemFactory | None = None, -) -> Generator[list[DeviceMesh], None, None]: - """ - Creates multiple local device meshes. - - Args: - meshes : number of global meshes to create. - Default: 1 - hosts_per_mesh : number of hosts per mesh, primarily used for simulating multiple machines locally. - Default: 1 - gpus_per_host : number of gpus per host. - Default: the number of GPUs this machine has. - socket_type : socket type to use for communication between processes. - Default: TCP. - """ - (dms, bootstrap) = local_meshes_and_bootstraps( - meshes=meshes, - hosts_per_mesh=hosts_per_mesh, - gpus_per_host=gpus_per_host, - socket_type=socket_type, - logging_location=logging_location, - supervision_params=supervision_params, - controller_params=controller_params, - worker_factory=worker_factory, - controller_factory=controller_factory, - system_factory=system_factory, - ) - with bootstrap: - maybe_error = None - try: - yield dms - except Exception as e: - maybe_error = e - raise - finally: - for dm in dms: - dm.exit(maybe_error) - - -def local_meshes_and_bootstraps( - *, - meshes: int = 1, - hosts_per_mesh: int = 1, - gpus_per_host: int | None = None, - socket_type: SocketType = SocketType.TCP, - logging_location: LoggingLocation = LoggingLocation.FILE, - supervision_params: SupervisionParams | None = None, - controller_params: ControllerParams | None = None, - auto_epoch: bool = False, - worker_factory: IWorkerFactory | None = None, - controller_factory: IControllerFactory | None = None, - system_factory: ISystemFactory | None = None, -) -> tuple[list[DeviceMesh], Bootstrap]: - """ - Same as local_meshes, but also returns the bootstrap object. This is - useful in tests where we want to maniputate the bootstrap object. - """ - - if gpus_per_host is None: - gpus_per_host = _local_device_count() - assert ( - 0 < gpus_per_host <= 8 - ), "Number of GPUs must be greater than 0 and at most 8." 
- bootstrap: Bootstrap = Bootstrap( - meshes=meshes, - hosts_per_mesh=hosts_per_mesh, - gpus_per_host=gpus_per_host, - socket_type=socket_type, - logging_location=logging_location, - supervision_params=supervision_params, - controller_params=controller_params, - auto_epoch=auto_epoch, - worker_factory=worker_factory, - controller_factory=controller_factory, - system_factory=system_factory, - ) - - def create_exit( - dm: DeviceMesh, bootstrap: Bootstrap - ) -> Callable[[Optional[RemoteException | DeviceException | Exception]], None]: - def exit( - error: Optional[RemoteException | DeviceException | Exception] = None, - ) -> None: - # We only support one single client proc. - if not bootstrap.has_shutdown: - dm.client.shutdown(True, error) - bootstrap.has_shutdown = True - - # We do not need to shutdown bootstrap and clean up the processes - # as they will be cleaned up with the parent. - return exit - - dms = rust_backend_meshes( - system_addr=bootstrap.bootstrap_addr, - hosts=hosts_per_mesh, - gpus=gpus_per_host, - requested_meshes=meshes, - ) - - for dm in dms: - dm.exit = create_exit(dm, bootstrap) - - return (dms, bootstrap) - - -def local_mesh_provider( - *, - meshes: int = 1, - hosts_per_mesh: int = 1, - gpus_per_host: int | None = None, - socket_type: SocketType = SocketType.TCP, - logging_location: LoggingLocation = LoggingLocation.FILE, - supervision_params: SupervisionParams | None = None, - controller_params: ControllerParams | None = None, - auto_epoch: bool = False, - controller_labels: Dict[str, str] | None = None, - worker_labels: Dict[str, str] | None = None, - worker_factory: IWorkerFactory | None = None, - controller_factory: IControllerFactory | None = None, - system_factory: ISystemFactory | None = None, - # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type. -) -> tuple[PoolDeviceMeshProvider, Bootstrap]: - if gpus_per_host is None: - gpus_per_host = _local_device_count() - assert ( - 0 < gpus_per_host <= 8 - ), "Number of GPUs must be greater than 0 and at most 8." - bootstrap: Bootstrap = Bootstrap( - meshes=meshes, - hosts_per_mesh=hosts_per_mesh, - gpus_per_host=gpus_per_host, - socket_type=socket_type, - logging_location=logging_location, - supervision_params=supervision_params, - controller_params=controller_params, - auto_epoch=auto_epoch, - controller_labels=controller_labels, - worker_labels=worker_labels, - worker_factory=worker_factory, - controller_factory=controller_factory, - system_factory=system_factory, - ) - - provider = rust_backend_mesh_provider( - system_addr=bootstrap.bootstrap_addr, - hosts=hosts_per_mesh, - gpus=gpus_per_host, - ) - return (provider, bootstrap) diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py deleted file mode 100644 index 1f342dad1..000000000 --- a/python/monarch/sim_mesh.py +++ /dev/null @@ -1,350 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-strict - -import importlib.resources -import logging -import os -import random -import string -import subprocess -import tempfile -import time -from pathlib import Path -from typing import ( - Callable, - ContextManager as AbstractContextManager, - Dict, - Generic, - Iterable, - List, - Optional, - Tuple, -) - -from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension - ClientActor, -) - -from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension - SimulatorClient, -) - -from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension - ActorId, - init_proc, - Proc, -) - -from monarch._src.actor.shape import NDSlice -from monarch.common.client import Client -from monarch.common.constants import ( - SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL, - SIM_MESH_CLIENT_TIMEOUT, -) -from monarch.common.device_mesh import DeviceMesh -from monarch.common.fake import fake_call -from monarch.common.future import Future, T -from monarch.common.invocation import DeviceException, RemoteException -from monarch.common.messages import Dims -from monarch.controller.rust_backend.controller import RustController -from monarch.rust_backend_mesh import MeshWorld - - -logger: logging.Logger = logging.getLogger(__name__) - - -def sim_mesh(n_meshes: int, hosts: int, gpus_per_host: int) -> List[DeviceMesh]: - """ - Creates a single simulated device mesh with the given number of per host. - - Args: - n_meshes : number of device meshes to create. - hosts : number of hosts, primarily used for simulating multiple machines locally. - Default: 1 - gpus_per_host : number of gpus per host. - Default: the number of GPUs this machine has. 
- """ - mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = {} - bootstrap: Bootstrap = Bootstrap( - n_meshes, - mesh_world_state, - world_size=hosts * gpus_per_host, - ) - - client_proc_id = "client[0]" - client_proc: Proc = init_proc( - proc_id=client_proc_id, - bootstrap_addr=bootstrap.client_bootstrap_addr, - timeout=SIM_MESH_CLIENT_TIMEOUT, # unused - supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL, - listen_addr=bootstrap.client_listen_addr, - ) - root_client_actor: ClientActor = ClientActor( - proc=client_proc, actor_name="root_client" - ) - - dms = [] - for i in range(n_meshes): - controller_id = ActorId( - world_name=f"mesh_{i}_controller", rank=0, actor_name="root" - ) - # Create a new device mesh - backend_ctrl = RustController( - proc=client_proc, - client_actor=ClientActor(client_proc, "backend_controller"), - controller_id=controller_id, - worker_world_name=f"mesh_{i}_worker", - ) - client = Client(backend_ctrl, hosts * gpus_per_host, gpus_per_host) - dm = SimMesh( - client, - NDSlice(offset=0, sizes=[hosts, gpus_per_host], strides=[gpus_per_host, 1]), - ("host", "gpu"), - bootstrap._simulator_client, - f"mesh_{i}_worker", - ) - dms.append(dm) - - return dms - - -class OriginalFutureWrapper(Generic[T]): - result: Callable[ - [ - Future[T], - float | None, - ], - T, - ] = Future.result - _set_result: Callable[[Future[T], T], None] = Future._set_result - - -class SimMesh(DeviceMesh, Generic[T]): - def __init__( - self, - client: "Client", - processes: "NDSlice", - names: Dims, - simulator_client: SimulatorClient, - mesh_name: str = "default", - ) -> None: - super().__init__(client, processes, names, mesh_name) - self.simulator_client: SimulatorClient = simulator_client - - # monkey patch Future.result and Future._set_result to hook into set_training_script_state_{running,waiting} - def activate(self) -> AbstractContextManager[DeviceMesh]: - def sim_result(fut: Future[T], timeout: float | None = None) -> T: - self.simulator_client.set_training_script_state_waiting() - return OriginalFutureWrapper.result(fut, timeout) - - def sim_set_result(fut: Future[T], result: T) -> None: - self.simulator_client.set_training_script_state_running() - return OriginalFutureWrapper._set_result(fut, result) - - # pyre-ignore - Future.result = sim_result - Future._set_result = sim_set_result - - return super().activate() - - # restore Future.result and Future._set_result to their previous values - def exit( - self, - error: Optional[RemoteException | DeviceException | Exception] = None, - ) -> None: - self.client.shutdown(True, error) - # pyre-ignore - Future.result = OriginalFutureWrapper._result - Future._set_result = OriginalFutureWrapper._set_result - - -def _random_id(length: int = 14) -> str: - """ - A simple random id generator. - """ - return "".join(random.choice(string.ascii_lowercase) for _ in range(length)) - - -class Bootstrap: - def __init__( - self, - num_meshes: int, - mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]], - world_size: int = 1, - ) -> None: - """ - Bootstraps a SimMesh. - Args: - num_meshes: int - number of meshes to create. - mesh_world_state: a state of the meshes. Keys are the MeshWorld and values are boolean indicating if this mesh is active. 
- """ - # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later - fake_call(lambda: 0) - - env = os.environ.copy() - self.env: dict[str, str] = env - - self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state - - self.bootstrap_addr: str = "sim!unix!@system" - self.client_listen_addr = "sim!unix!@client" - self.client_bootstrap_addr = "sim!unix!@client,unix!@system" - - self._simulator_client = SimulatorClient(self.bootstrap_addr, world_size) - for i in range(num_meshes): - mesh_name: str = f"mesh_{i}" - controller_world: str = f"{mesh_name}_controller" - worker_world: str = f"{mesh_name}_worker" - controller_id: ActorId = ActorId( - world_name=controller_world, - rank=0, - actor_name="root", - ) - mesh_world = (worker_world, controller_id) - self._mesh_world_state[mesh_world] = None - self.spawn_mesh(mesh_world) - # sleep for 10 sec for the worker and controller tasks to be spawned and ready. - time.sleep(10) - - def get_mesh_worlds(self) -> List[MeshWorld]: - return [] - - def kill_mesh(self, mesh_world: MeshWorld) -> None: - pass - - def spawn_mesh(self, mesh_world: MeshWorld) -> None: - worker_world, controller_id = mesh_world - controller_world = controller_id.world_name - self._simulator_client.spawn_mesh( - self.bootstrap_addr, - f"{controller_world}[0].root", - worker_world, - ) - - -def _validate_proccesses_end( - processes: Iterable[subprocess.Popen[bytes]], - timeout_in_sec: int = 1, - raise_on_abnormal_exit: bool = True, -) -> list[int]: - """ - Check if processes have ended properly. Raise errors immediately - if any process has ended with a non-zero return code. - Return a list of process indices that have not ended yet. - """ - running = [] - start_time = time.time() - for i, process in enumerate(processes): - try: - current_time = time.time() - elapsed_time = current_time - start_time - # The processes are running in parallel. No need to wait for - # `timeout_in_sec` for each process. Only count the remaining ones. - wait_in_sec = max(0, timeout_in_sec - elapsed_time) - return_code = process.wait(timeout=wait_in_sec) - if return_code != 0: - error_message: str = ( - f"Process[{i}] {process.pid} exited with " - f"return code {return_code}. Command:\n " - f"{process.args!r}" - ) - if raise_on_abnormal_exit: - raise RuntimeError(error_message) - else: - logger.error(error_message) - except subprocess.TimeoutExpired: - running.append(i) - - return running - - -class PoolDeviceMeshProvider: - def __init__( - self, - hosts_per_mesh: int, - gpus_per_host: int, - client_proc: Proc, - mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]], - simulator_client: SimulatorClient, - ) -> None: - self._hosts_per_mesh = hosts_per_mesh - self._gpus_per_host = gpus_per_host - self._client_proc = client_proc - self._root_client_actor: ClientActor = ClientActor( - proc=client_proc, actor_name="root_client" - ) - self._mesh_world_state = mesh_world_state - self._simulator_client = simulator_client - # Keep track of this to create unique controller ids. 
- self._num_meshes_created = 0 - - def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh: - mesh_world_to_create = next( - ( - mesh_world - for mesh_world, is_created in self._mesh_world_state.items() - if not is_created - ), - None, - ) - assert mesh_world_to_create is not None, "No mesh world to create" - - worker_world, controller_id = mesh_world_to_create - # Create a new device mesh - backend_ctrl = RustController( - proc=self._client_proc, - client_actor=ClientActor( - self._client_proc, f"backend_controller_{self._num_meshes_created}" - ), - controller_id=controller_id, - worker_world_name=worker_world, - ) - self._num_meshes_created += 1 - client = Client( - backend_ctrl, - self._hosts_per_mesh * self._gpus_per_host, - self._gpus_per_host, - ) - dm = SimMesh( - client, - NDSlice( - offset=0, - sizes=[self._hosts_per_mesh, self._gpus_per_host], - strides=[self._gpus_per_host, 1], - ), - ("host", "gpu"), - self._simulator_client, - worker_world, - ) - self._mesh_world_state[mesh_world_to_create] = dm - - return dm - - -def sim_mesh_provider( - num_meshes: int, hosts_per_mesh: int, gpus_per_host: int -) -> Tuple[PoolDeviceMeshProvider, Bootstrap]: - mesh_world_state = {} - bootstrap = Bootstrap(num_meshes, mesh_world_state) - - client_proc_id = "client[0]" - client_proc: Proc = init_proc( - proc_id=client_proc_id, - bootstrap_addr=bootstrap.client_bootstrap_addr, - timeout=SIM_MESH_CLIENT_TIMEOUT, # unused - supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL, - listen_addr=bootstrap.client_listen_addr, - ) - dm_provider = PoolDeviceMeshProvider( - hosts_per_mesh, - gpus_per_host, - client_proc, - mesh_world_state, - bootstrap._simulator_client, - ) - return (dm_provider, bootstrap) diff --git a/python/monarch/timer/README.md b/python/monarch/timer/README.md index 61dad3f77..2c41c2417 100644 --- a/python/monarch/timer/README.md +++ b/python/monarch/timer/README.md @@ -44,13 +44,13 @@ We provide an example of CudaTimer within Monarch workloads at [example_monarch. 
``` import torch from monarch import inspect, remote -from monarch.rust_local_mesh import local_mesh +from monarch.actor import this_host cuda_timer_start = remote("monarch.timer.remote_cuda_timer.cuda_timer_start", propagate="inspect") cuda_timer_stop = remote("monarch.timer.remote_cuda_timer.cuda_timer_stop", propagate="inspect") def main(): - mesh = local_mesh(hosts=1, gpus_per_host=1) + mesh = this_host().spawn_procs(per_host={"hosts": 1, "gpus": 1}) with mesh.activate(): a = torch.randn(1000, 1000, device="cuda") @@ -63,6 +63,5 @@ def main(): cuda_average_ms = get_cuda_timer_average_ms() local_cuda_avg_ms = inspect(cuda_average_ms).item() - mesh.exit() print(f"average time w/ CudaTimer: {local_cuda_avg_ms:.4f} (ms)") ``` diff --git a/python/monarch/timer/example_monarch.py b/python/monarch/timer/example_monarch.py index a07f481cb..3bdf67596 100644 --- a/python/monarch/timer/example_monarch.py +++ b/python/monarch/timer/example_monarch.py @@ -17,7 +17,7 @@ import torch from monarch import inspect, remote -from monarch.rust_local_mesh import local_mesh +from monarch.actor import this_host logger = logging.getLogger(__name__) @@ -42,33 +42,32 @@ def main() -> None: - with local_mesh(hosts=1, gpus_per_host=1) as mesh: - with mesh.activate(): - num_iterations = 5 + mesh = this_host().spawn_procs(per_host={"hosts": 1, "gpus": 1}) + with mesh.activate(): + num_iterations = 5 - a = torch.randn(1000, 1000, device="cuda") - b = torch.randn(1000, 1000, device="cuda") - torch.matmul(a, b) + a = torch.randn(1000, 1000, device="cuda") + b = torch.randn(1000, 1000, device="cuda") + torch.matmul(a, b) - total_dt = torch.zeros(1, dtype=torch.float64) + total_dt = torch.zeros(1, dtype=torch.float64) - for _ in range(num_iterations): - t0 = get_time_perfcounter() - torch.matmul(a, b) - total_dt += get_time_perfcounter() - t0 + for _ in range(num_iterations): + t0 = get_time_perfcounter() + torch.matmul(a, b) + total_dt += get_time_perfcounter() - t0 - for _ in range(num_iterations): - execution_timer_start() - torch.matmul(a, b) - execution_timer_stop() + for _ in range(num_iterations): + execution_timer_start() + torch.matmul(a, b) + execution_timer_stop() - cuda_average_ms = get_execution_timer_average_ms() - local_total_dt = inspect(total_dt) - local_cuda_avg_ms = inspect(cuda_average_ms) + cuda_average_ms = get_execution_timer_average_ms() + local_total_dt = inspect(total_dt) + local_cuda_avg_ms = inspect(cuda_average_ms) local_total_dt = local_total_dt.item() local_cuda_avg_ms = local_cuda_avg_ms.item() - mesh.exit() avg_perfcounter_ms = local_total_dt / num_iterations * 1000 print(f"average time w/ perfcounter: {avg_perfcounter_ms:.4f} (ms)") print(f"average time w/ ExecutionTimer: {local_cuda_avg_ms:.4f} (ms)") diff --git a/python/tests/test_controller.py b/python/tests/test_controller.py deleted file mode 100644 index d24d5ff1f..000000000 --- a/python/tests/test_controller.py +++ /dev/null @@ -1,835 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe -import itertools -import logging -import re -import sys -import traceback -from contextlib import contextmanager - -import monarch -import monarch.random -import pytest - -import torch - -from monarch import fetch_shard, grad_function, grad_generator, Stream, Tensor - -from monarch._testing import TestingContext -from monarch.common.controller_api import LogMessage -from monarch.common.device_mesh import DeviceMesh, no_mesh -from monarch.common.invocation import DeviceException -from monarch.common.remote import remote -from monarch.common.tree import flattener -from monarch.rust_local_mesh import ( - ControllerParams, - local_mesh, - local_meshes_and_bootstraps, - LoggingLocation, - SocketType, - SupervisionParams, -) -from monarch_supervisor.logging import fix_exception_lines - - -def custom_excepthook(exc_type, exc_value, exc_traceback): - tb_lines = fix_exception_lines( - traceback.format_exception(exc_type, exc_value, exc_traceback) - ) - print("\n".join(tb_lines), file=sys.stderr) - - -sys.excepthook = custom_excepthook - - -@pytest.fixture(scope="module", autouse=True) -def testing_context(): - global local - with TestingContext() as local: - yield - - -@contextmanager -def local_rust_device_mesh( - hosts, - gpu_per_host, - activate: bool = True, - controller_params: ControllerParams | None = None, -): - with local_mesh( - hosts=hosts, - gpus_per_host=gpu_per_host, - socket_type=SocketType.UNIX, - logging_location=LoggingLocation.FILE, - controller_params=controller_params, - ) as dm: - try: - if activate: - with dm.activate(): - yield dm - else: - yield dm - dm.exit() - except Exception: - dm.client._shutdown = True - raise - - -panic = remote("__test_panic", propagate="inspect") - -remote_sleep = remote("time.sleep", propagate="inspect") - - -@pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="Not enough GPUs, this test requires at least 2 GPUs", -) -# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times -# out is not counted as a failure, so we set a more restrictive timeout to -# ensure we see a hard failure in CI. 
-@pytest.mark.timeout(120) -class TestController: - @classmethod - def local_device_mesh( - cls, - N, - gpu_per_host, - activate=True, - ): - return local.local_device_mesh( - N, - gpu_per_host, - activate, - ) - - def test_errors(self): - t = torch.rand(3, 4) - with self.local_device_mesh(2, 2) as device_mesh: - y = torch.rand(3, 4) - with pytest.raises(TypeError, match="LOCAL_TENSOR"): - t.add(y) - with pytest.raises(TypeError, match="WRONG_MESH"): - sm = device_mesh.slice(host=0) - with sm.activate(): - x = torch.rand(3, 4) - x.add(y) - - other = Stream("other") - t = torch.rand(10).cuda() - with pytest.raises(TypeError, match="WRONG_STREAM"): - with other.activate(): - t = t.reduce("host", "sum") - - def test_sub_mesh(self): - with self.local_device_mesh(2, 2) as device_mesh: - h0 = device_mesh.slice(host=0) - h1 = device_mesh.slice(host=1) - with h0.activate(): - _ = torch.rand(3, 4) - with h1.activate(): - _ = torch.rand(3, 4) - # Runs on a different mesh but should still work - - def test_fetch_result_device(self): - with self.local_device_mesh(2, 2): - on_gpu = torch.ones(2, 3, device="cuda") - on_cpu = torch.ones(2, 3, device="cpu") - - on_gpu_local = fetch_shard(on_gpu).result() - on_cpu_local = fetch_shard(on_cpu).result() - - assert on_gpu_local.device == torch.device("cpu") - assert on_cpu_local.device == torch.device("cpu") - - def test_dim1_mesh(self): - with self.local_device_mesh(2, 2, activate=False) as device_mesh: - mesh3d = device_mesh.split(host=("oh", "ih"), ih=1) - with mesh3d.activate(): - x = torch.ones(3, 4) - local_x = fetch_shard(x).result() - - assert torch.equal(local_x, torch.ones(3, 4)) - - def test_sub_mesh_use_only_one(self): - with self.local_device_mesh(2, 2, activate=False) as device_mesh: - h0 = device_mesh.slice(host=0) - - with h0.activate(): - x = torch.ones(3, 4) - local_x = fetch_shard(x) - - local_x = local_x.result(timeout=20) - assert torch.equal(local_x, torch.ones(3, 4)) - - def test_sub_mesh_process_grop(self): - with self.local_device_mesh(2, 2, activate=False) as device_mesh: - h0 = device_mesh.slice(host=0) - pg0 = h0.process_group(("gpu",)) - pg1 = h0.process_group(("gpu",)) - # Is there a way to functionally test that these two PG's aren't - # the same in the backend? 
- assert pg0 != pg1 - - def test_reduce(self): - with self.local_device_mesh(2, 2) as device_mesh: - x = ( - 12 * 2 * device_mesh.rank("host") - + 12 * device_mesh.rank("gpu") - + torch.arange(12, device="cuda").reshape(3, 4) - ) - y = x.reduce("gpu", "sum") - g = x.reduce("gpu", "stack") - with pytest.raises(TypeError, match="When scattering"): - x = x.reduce("gpu", "sum", scatter=True) - x = x.reshape(2, 6) - atoa = x.reduce("gpu", "stack", scatter=True) - rs = x.reduce("gpu", "sum", scatter=True) - rad = x.reduce((), "sum") - rade = x.reduce(("gpu", "host"), "sum") - with pytest.raises( - ValueError, match="is not valid for multiple dimensions" - ): - x.reduce((), "sum", scatter=True) - with pytest.raises( - ValueError, match="is not valid for multiple dimensions" - ): - x.reduce((), "stack") - with pytest.raises( - ValueError, match="is not valid for multiple dimensions" - ): - x.reduce((), "stack", scatter=True) - y_local = fetch_shard(y).result() - g_local = fetch_shard(g).result() - # TODO compute the expected values to compare agains in the below section - _ = fetch_shard(atoa).result() - _ = fetch_shard(rs).result() - rad_local = fetch_shard(rad).result() - rade_local = fetch_shard(rade).result() - - xs = { - (h, g): 12 * 2 * h + 12 * g + torch.arange(12, device="cpu").reshape(3, 4) - for h, g in itertools.product(range(2), range(2)) - } - - y_expected = xs[(0, 0)] + xs[(0, 1)] - g_expected = torch.stack([xs[(0, 0)], xs[(0, 1)]]) - assert torch.equal(y_local, y_expected) - assert torch.equal(g_local, g_expected) - rad_expected = (xs[(0, 0)] + xs[(0, 1)] + xs[(1, 0)] + xs[(1, 1)]).reshape( - rad_local.shape - ) - assert torch.equal(rad_local, rad_expected) - assert torch.equal(rade_local, rad_expected) - - # test is run on 4 GPUs, can't have mesh with 3 non-trivial dimensions - with self.local_device_mesh(2, 2, activate=False) as mesh2d: - device_mesh = mesh2d.split(host=("oh", "ih"), ih=1) - with device_mesh.activate(): - x = ( - 12 * 2 * device_mesh.rank("oh") - + 12 * device_mesh.rank("gpu") - + torch.arange(12, device="cuda").reshape(3, 4) - ) - y = x.reduce(("ih", "gpu"), "sum") - y_local = fetch_shard(y).result() - z = x.reduce(("oh", "gpu"), "sum") - z_local = fetch_shard(z).result() - - assert torch.equal(y_local, y_expected) - assert torch.equal(z_local, rad_expected.reshape(z_local.shape)) - - def test_reduce_out(self): - with self.local_device_mesh(2, 2): - inp = torch.rand(2, 4, device="cuda") - out_incorrect = torch.rand(2, 4, device="cuda") - out = torch.rand(4, device="cuda") - - with pytest.raises( - ValueError, match="Reduce expects the shape to be torch.Size." 
- ): - _ = inp.reduce("host", reduction="sum", scatter=True, out=out_incorrect) - - reduce_out = inp.reduce("host", reduction="sum", scatter=True) - local_out = fetch_shard(out).result() - local_reduce_out = fetch_shard(reduce_out).result() - assert out._fake is not reduce_out._fake - with no_mesh.activate(): - assert not torch.equal(local_out, local_reduce_out) - - reduce_out = inp.reduce("host", reduction="sum", scatter=True, out=out) - local_out = fetch_shard(out).result() - local_reduce_out = fetch_shard(reduce_out).result() - assert out._fake is reduce_out._fake - with no_mesh.activate(): - assert torch.equal(local_out, local_reduce_out) - - def test_fetch(self): - with self.local_device_mesh(2, 2) as device_mesh: - h = device_mesh.rank("host") - g = device_mesh.rank("gpu") - for hi in range(2): - for gi in range(2): - x, y = fetch_shard((h, g), {"host": hi, "gpu": gi}).result() - with no_mesh.activate(): - assert (hi, gi) == (x.item(), y.item()) - - def test_mutate(self): - with self.local_device_mesh(2, 2): - x = torch.rand(3, 4).cuda() - x.abs_() - s = Stream("other") - b, drop = s.borrow(x) - with pytest.raises(TypeError, match="would be mutated"): - x.abs_() - with s.activate(): - _ = b.add(b) - drop.drop() - x.abs_() - b, drop = s.borrow(x, mutable=True) - with s.activate(): - b.abs_() - drop.drop() - # del b - x.abs_() - - def test_movement(self): - with self.local_device_mesh(2, 2) as device_mesh: - sm0 = device_mesh.slice(host=0) - sm1 = device_mesh.slice(host=1) - - with sm0.activate(): - x = torch.rand(3, 4, device="cuda") - _ = x.to_mesh(sm1) - - a = torch.rand(3, 4, device="cuda") - - b = a.slice_mesh(host=0) - _ = b.to_mesh(sm0) - _ = b.to_mesh(sm1) - - def test_broadcast_one(self): - with self.local_device_mesh(2, 2) as device_mesh: - for dim in ("host", "gpu"): - subset = device_mesh.slice(**{dim: 1}) - with subset.activate(): - x = torch.rand(3, device="cuda") - y = x.to_mesh(device_mesh) - - with subset.activate(): - a = monarch.inspect(x) - with device_mesh.activate(): - b = monarch.inspect(y.reduce(dim, reduction="stack")) - with no_mesh.activate(): - assert torch.allclose(a.expand(2, -1), b, rtol=0, atol=0) - - def test_broadcast_two(self): - with self.local_device_mesh(2, 2) as device_mesh: - subset = device_mesh.slice(host=1, gpu=1) - with subset.activate(): - x = torch.rand(3, device="cuda") - y = x.to_mesh(device_mesh) - - with subset.activate(): - a = monarch.inspect(x) - with device_mesh.activate(): - b = monarch.inspect( - y.reduce("host", reduction="stack").reduce("gpu", reduction="stack") - ) - with no_mesh.activate(): - assert torch.allclose(a.expand(2, 2, -1), b, rtol=0, atol=0) - - def test_autograd(self): - with self.local_device_mesh(2, 2) as device_mesh: - x = torch.rand(3, 4, requires_grad=True) - y = torch.rand(4, 3, requires_grad=True) - z = torch.rand(3, requires_grad=True) - - foo = (x @ y + z).sum() - with no_mesh.activate(): - # check backward restores forward mesh - for t in grad_generator(foo, [z, y, x]): - with device_mesh.activate(): - fetch_shard(t).result() - - def test_mesh_semantics(self): - with self.local_device_mesh(2, 2) as device_mesh: - host0 = device_mesh.slice(host=0) - host1 = device_mesh.slice(host=1) - with host0.activate(): - x = torch.randn(5) - y = x * 5 - with host1.activate(): - a = torch.randn(5) - b = a * 5 - x.cos() - y.cos() - b.cos() - - def test_autograd_multi_mesh(self): - @grad_function - def to_mesh(x: Tensor, mesh: DeviceMesh): - omesh = x.mesh - - def backward(grad_x: Tensor): - print(grad_x.mesh, omesh) - 
return grad_x.to_mesh(omesh), None - - return x.to_mesh(mesh), backward - - with self.local_device_mesh(2, 2) as device_mesh: - host0 = device_mesh.slice(host=0) - host1 = device_mesh.slice(host=1) - with host0.activate(): - x = torch.rand(3, 4, requires_grad=True, device="cuda") - y = torch.rand(4, 3, requires_grad=True, device="cuda") - t = x @ y - t = to_mesh(t, host1) - with host1.activate(): - z = torch.rand(3, requires_grad=True, device="cuda") - foo = (t + z).sum() - - for r in grad_generator(foo, [z, y, x]): - with r.mesh.activate(): - print(fetch_shard(r).result()) - - def test_many(self): - with self.local_device_mesh(2, 2): - x = torch.rand(3, 4) - for _ in range(2048): - x = x + torch.rand(3, 4) - fetch_shard(x).result() - - def test_flattener(self): - e = (8, 9, {"a": 10, "b": 11}) - flatten = flattener(e) - e2 = (0, 1, {"a": 2, "b": 3}) - assert [0, 1, 2, 3] == flatten(e2) - - def test_torch_tensor(self): - with self.local_device_mesh(2, 2): - t = torch.tensor([1, 2, 4]) - tc = torch.tensor([1, 2, 4], device="cuda") - t2 = fetch_shard(t).result() - tc2 = fetch_shard(tc).result() - assert torch.allclose(t2, torch.tensor([1, 2, 4])) - assert torch.allclose(tc2, torch.tensor([1, 2, 4], device="cpu")) - - def test_to_mesh_aliasing(self): - with self.local_device_mesh(2, 2) as mesh: - p2p_stream = Stream("p2p_stream") - - ppmesh = mesh.flatten("all").split( - all=( - "dp", - "pp", - ), - pp=2, - ) - pp_meshes = [ppmesh.slice(pp=i) for i in range(2)] - - with ppmesh.activate(): - with pp_meshes[0].activate(): - x = torch.randn((3, 3), device="cuda") - x_borrowed_tensor, x_borrow = p2p_stream.borrow(x) - with p2p_stream.activate(): - y_on_mesh_1_p2p_stream = x_borrowed_tensor.to_mesh(pp_meshes[1]) - - with pp_meshes[1].activate(): - x_borrow.drop() - y_on_mesh_1_default_stream, y_borrow = ( - monarch.get_active_stream().borrow(y_on_mesh_1_p2p_stream) - ) - - monarch.inspect(y_on_mesh_1_default_stream) - y_borrow.drop() - - def test_to_mesh_cow(self): - with self.local_device_mesh(2, 2) as mesh: - t = torch.zeros((), device="cuda") - t2 = t.to_mesh(mesh) - t.add_(1) - assert monarch.inspect(t2).item() == 0 - assert monarch.inspect(t).item() == 1 - - def test_to_mesh_stream(self): - other = monarch.Stream("other") - with self.local_device_mesh(2, 2) as mesh: - m0 = mesh.slice(host=0) - m1 = mesh.slice(host=1) - with m0.activate(): - t2 = torch.rand(3, 4, device="cuda").to_mesh(m1, stream=other) - with m1.activate(), other.activate(): - # assert doesn't fail - monarch.inspect(t2 + t2) - - def test_dropped_trace(self): - with self.local_device_mesh(2, 2) as _: - x = torch.rand(4, 4).cuda() - s = Stream("other") - b, drop = s.borrow(x) - drop.drop() - with s.activate(): - pattern = re.compile( - ".*tensor.*is dropped at.*.*drop.drop().*", flags=re.DOTALL - ) - with pytest.raises(TypeError, match=pattern): - _ = b.abs() - - def test_sub_mesh_reduce(self): - with self.local_device_mesh(2, 2) as device_mesh: - host1 = device_mesh.slice(host=1) - with host1.activate(): - myrank = ( - (device_mesh.rank("host") + 1) * 2 + device_mesh.rank("gpu") + 1 - ) - x = torch.ones((3, 4), device="cuda") * myrank - reduce = x.reduce("gpu", "sum") - local_reduce = fetch_shard(reduce).result() - - assert torch.equal(local_reduce, torch.ones(3, 4) * 11) - - def test_size(self): - with self.local_device_mesh(2, 2) as device_mesh: - assert device_mesh.size(["host", "gpu"]) == 4 - - def test_random_state(self): - with self.local_device_mesh(2, 2) as device_mesh: - monarch.random.make_deterministic() - for 
device in ("cpu", "cuda"): - a = monarch.random.get_state() - monarch.inspect(a) - first = torch.rand(1, device=device) - monarch.random.set_state(a) - second = torch.rand(1, device=device) - f, s = monarch.inspect((first, second)) - with no_mesh.activate(): - assert torch.allclose(f, s, atol=0, rtol=1) - seed = device_mesh.rank(["host", "gpu"]) + 4 - s2 = monarch.random.new_state(seed) - s3 = monarch.random.new_state(seed) - monarch.random.set_state(s2) - r0 = torch.rand(1, device=device) - if device == "cuda": - for d in ("host", "gpu"): - r0 = r0.reduce(d, reduction="stack") - monarch.random.set_state(s3) - r1 = torch.rand(1, device=device) - if device == "cuda": - for d in ("host", "gpu"): - r1 = r1.reduce(d, reduction="stack") - r2, r3 = monarch.inspect((r0, r1)) - monarch.random.set_state(a) - with no_mesh.activate(): - assert torch.allclose(r2, r3, atol=0, rtol=0) - assert not torch.allclose(r2, f, atol=0, rtol=0) - - def test_torch_op_with_optional_tensors(self): - """ - This test ensures that for torch ops like LayerNorm, which allow for - optional tensor arguments, the controller serializes monarch tensors - correctly as Refs instead of as IValues. - """ - with self.local_device_mesh(2, 2): - x = torch.rand(3, 4, device="cuda") - # When bias and elementwise_affine are true, extra tensors are passed through optional - # fields inside LayerNorm. When they are false, None is passed to the same optional fields. - # If we are handling serialization correctly, there shouldn't be a crash in either case. - layer_norm_with_vals = torch.nn.LayerNorm( - 4, device="cuda", bias=True, elementwise_affine=True - ) - layer_norm_with_none = torch.nn.LayerNorm( - 4, device="cuda", bias=False, elementwise_affine=False - ) - monarch.inspect(layer_norm_with_vals(x)) - monarch.inspect(layer_norm_with_none(x)) - - def test_reduce_pytree(self): - with self.local_device_mesh(2, 2) as device_mesh: - a = device_mesh.rank(("gpu", "host")) + torch.zeros((1,), device="cuda") - b = device_mesh.rank(("gpu", "host")) + torch.ones((1,), device="cuda") - - tensor_dict = {"a": a, "b": b} - _ = monarch.reduce_(tensor_dict, dims=("gpu", "host"), reduction="sum") - reduced_tensor_dict = monarch.reduce( - tensor_dict, dims=("gpu", "host"), reduction="sum" - ) - reduced_a = fetch_shard(reduced_tensor_dict["a"]).result() - reduced_b = fetch_shard(reduced_tensor_dict["b"]).result() - reduced_a_inplace = fetch_shard(tensor_dict["a"]).result() - reduced_b_inplace = fetch_shard(tensor_dict["b"]).result() - - assert torch.equal(reduced_a_inplace, torch.tensor([6.0])) - assert torch.equal(reduced_b_inplace, torch.tensor([10.0])) - assert torch.equal(reduced_a, torch.tensor([24.0])) - assert torch.equal(reduced_b, torch.tensor([40.0])) - - def test_to_mesh_pytree(self): - with self.local_device_mesh(2, 2) as device_mesh: - host0 = device_mesh.slice(host=0) - host1 = device_mesh.slice(host=1) - - with host0.activate(): - a = torch.zeros((1,), device="cuda") - b = torch.ones((1,), device="cuda") - tensor_dict = {"a": a, "b": b} - moved_tensor_dict = monarch.to_mesh(tensor_dict, host1) - - with host1.activate(): - moved_tensor_dict["a"].add_(1) - moved_tensor_dict["b"].add_(1) - - moved_tensor_a = monarch.inspect(moved_tensor_dict["a"]) - moved_tensor_b = monarch.inspect(moved_tensor_dict["b"]) - - host0.exit() - host1.exit() - - assert torch.equal(moved_tensor_a, torch.tensor([1.0])) - assert torch.equal(moved_tensor_b, torch.tensor([2.0])) - - def test_hanging_error(self): - with self.local_device_mesh(2, 2) as device_mesh: - 
remote(lambda: torch.rand(3) + torch.rand(4), propagate=lambda: None)() - - with pytest.raises(Exception, match="The size of tensor"): - device_mesh.client.shutdown() - - def test_slice_mesh_pytree(self): - with self.local_device_mesh(2, 2) as device_mesh: - a = device_mesh.rank(("host")) + torch.zeros((1,), device="cuda") - b = device_mesh.rank(("host")) + torch.ones((1,), device="cuda") - - tensor_dict = {"a": a, "b": b} - host0_slices = monarch.slice_mesh(tensor_dict, host=0) - host1_slices = monarch.slice_mesh(tensor_dict, host=1) - - host0 = device_mesh.slice(host=0) - host1 = device_mesh.slice(host=1) - - host0_tensors = monarch.to_mesh(host0_slices, host0) - host1_tensors = monarch.to_mesh(host1_slices, host1) - - with host0.activate(): - _ = monarch.reduce_(host0_tensors, dims=("gpu"), reduction="sum") - host0_a = fetch_shard(host0_tensors["a"]).result() - host0_b = fetch_shard(host0_tensors["b"]).result() - - with host1.activate(): - _ = monarch.reduce_(host1_tensors, dims=("gpu"), reduction="sum") - host1_a = fetch_shard(host1_tensors["a"]).result() - host1_b = fetch_shard(host1_tensors["b"]).result() - - host0.exit() - host1.exit() - - assert torch.equal(host0_a, torch.tensor([0.0])) - assert torch.equal(host0_b, torch.tensor([2.0])) - assert torch.equal(host1_a, torch.tensor([2.0])) - assert torch.equal(host1_b, torch.tensor([4.0])) - - -def test_panicking_worker(): - with pytest.raises(DeviceException, match="__test_panic called"): - with local_rust_device_mesh(1, 1) as _: - panic() - # induce a sync to allow the panic to propagate back - _ = fetch_shard(torch.ones(2, 3)).result() - - -# TODO - re-enable after resolving T232206970 -@pytest.mark.oss_skip -def test_timeout_warning(caplog): - timeout = 3 - with local_rust_device_mesh( - 1, - 2, - True, - controller_params=ControllerParams(1, timeout, 100, False), - ) as dm: - for _ in range(3): - dm.client.new_node([], []) - - assert dm.client.inner.next_message(timeout * 3) is None - - remote_sleep(timeout * 2) - for _ in range(3): - dm.client.new_node([], []) - - with caplog.at_level(logging.WARNING, logger=dm.client.__module__): - has_message = dm.client.handle_next_message(120) - assert has_message - assert ( - f"ranks 1, 0 have operations that have not completed after {timeout} seconds" - in caplog.text - ) or ( - f"ranks 0, 1 have operations that have not completed after {timeout} seconds" - in caplog.text - ) - - -def test_timeout_failure(): - timeout = 3 - with local_rust_device_mesh( - 1, - 1, - True, - controller_params=ControllerParams(1, timeout, 100, True), - ) as dm: - for _ in range(3): - dm.client.new_node([], []) - - assert dm.client.inner.next_message(timeout * 3) is None - - remote_sleep(timeout * 2) - for _ in range(3): - dm.client.new_node([], []) - - for _ in range(5): - result = dm.client.inner.next_message(1) - if result is None: - continue - if isinstance(result, LogMessage): - continue - if result.error is None: - continue - assert isinstance(result.error, DeviceException) - assert "crashed" in result.error.message in result.error.message - assert "mesh_0_worker[0].worker[0]" in result.error.message - assert ( - f"ranks 0 have operations that have not completed after {timeout} seconds" - in result.error.frames[0].name - ) - - -def test_supervision_heartbeat_failure(): - (dms, bootstrap) = local_meshes_and_bootstraps( - meshes=1, - hosts_per_mesh=1, - gpus_per_host=2, - socket_type=SocketType.UNIX, - logging_location=LoggingLocation.DEFAULT, - supervision_params=SupervisionParams( - # Set a low timeout so 
heatbeat failure can be detected faster. - update_timeout_in_sec=10, - query_interval_in_sec=1, - update_interval_in_sec=1, - ), - ) - assert len(dms) == 1 - dm = dms[0] - - # Kill a process of a worker actor. This should trigger supervision - # heartbeat failure event. - # Index 0 and 1 are system process and controller process respectively. - process = bootstrap.processes[2] - process.kill() - - for _ in range(20): - # poll the next message in order to get the supervision failure - result = dm.client.inner.next_message(3) - if result is None: - continue - if result.error is None: - continue - assert isinstance(result.error, DeviceException) - assert "crashed" in result.error.message - return - - dm.exit() - raise AssertionError("Should have failed supervision health check") - - -def test_supervision_system_actor_down(): - (dms, bootstrap) = local_meshes_and_bootstraps( - meshes=1, - hosts_per_mesh=1, - gpus_per_host=2, - socket_type=SocketType.UNIX, - logging_location=LoggingLocation.DEFAULT, - supervision_params=SupervisionParams( - # Set a low timeout so heatbeat failure can be detected faster. - update_timeout_in_sec=10, - query_interval_in_sec=1, - update_interval_in_sec=1, - ), - ) - assert len(dms) == 1 - dm = dms[0] - - # Index 0 is system process - process = bootstrap.processes[0] - process.kill() - - try: - for _ in range(20): - # poll the next message in order to get the supervision failure - dm.client.inner.next_message(3) - except RuntimeError as e: - assert "actor has been stopped" in str(e) - return - - dm.exit() - raise AssertionError("Should have failed supervision health check") - - -def test_supervision_controller_actor_down(): - (dms, bootstrap) = local_meshes_and_bootstraps( - meshes=1, - hosts_per_mesh=1, - gpus_per_host=2, - socket_type=SocketType.UNIX, - logging_location=LoggingLocation.DEFAULT, - supervision_params=SupervisionParams( - # Set a low timeout so heatbeat failure can be detected faster. - update_timeout_in_sec=10, - query_interval_in_sec=1, - update_interval_in_sec=1, - ), - ) - assert len(dms) == 1 - dm = dms[0] - - # Index 1 is controller process - process = bootstrap.processes[1] - process.kill() - - for _ in range(20): - # poll the next message in order to get the supervision failure - result = dm.client.inner.next_message(3) - if result is None: - continue - if result.error is None: - continue - assert isinstance(result.error, DeviceException) - assert "mesh_0_controller[0].controller[0] crashed" in result.error.message - return - - dm.exit() - raise AssertionError("Should have failed supervision health check") - - -def a_function_called_by_a_live_function(x): - return 2 * x - - -def a_live_function_call_by_a_live_function(x): - return 3 * x - - -def test_delete_refs(): - with local_mesh( - hosts=2, - gpus_per_host=2, - socket_type=SocketType.UNIX, - logging_location=LoggingLocation.DEFAULT, - ) as dm: - dm.client.delete_ref(dm, 1) - dm.client.delete_ref(dm, 2) - assert len(dm.client._pending_del[dm]) == 2 - dm.client.flush_deletes() - assert len(dm.client._pending_del[dm]) == 0 diff --git a/python/tests/test_fault_tolerance.py b/python/tests/test_fault_tolerance.py deleted file mode 100644 index a31e6f332..000000000 --- a/python/tests/test_fault_tolerance.py +++ /dev/null @@ -1,383 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe - -import random -import time -from typing import Optional - -import pytest -import torch - -try: - from later.unittest import TestCase -except ModuleNotFoundError: - from unittest import TestCase -from monarch import fetch_shard, remote -from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus, no_mesh -from monarch.common.invocation import DeviceException, RemoteException -from monarch.rust_backend_mesh import MeshWorld, PoolDeviceMeshProvider -from monarch.rust_local_mesh import ( - Bootstrap, - local_mesh_provider, - local_meshes_and_bootstraps, - LoggingLocation, - SocketType, - SupervisionParams, -) - - -def _do_bogus_tensor_work( - x: torch.Tensor, y: torch.Tensor, fail_rank: Optional[int] = None -) -> torch.Tensor: - return x + y # real function actually does x @ y - - -do_bogus_tensor_work = remote( - "monarch.worker._testing_function.do_bogus_tensor_work", - propagate=_do_bogus_tensor_work, -) - - -def mesh_provider( - meshes: int = 2, - hosts_per_mesh: int = 1, - gpus_per_host: int = 1, - # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type. -) -> tuple[PoolDeviceMeshProvider, Bootstrap]: - return local_mesh_provider( - meshes=meshes, - hosts_per_mesh=hosts_per_mesh, - gpus_per_host=gpus_per_host, - socket_type=SocketType.UNIX, - logging_location=LoggingLocation.DEFAULT, - supervision_params=SupervisionParams( - update_timeout_in_sec=10, # Fail fast - query_interval_in_sec=1, - update_interval_in_sec=1, - ), - auto_epoch=True, - ) - - -def local_meshes( - meshes: int = 2, - hosts_per_mesh: int = 1, - gpus_per_host: int = 1, -) -> tuple[list[DeviceMesh], Bootstrap]: - return local_meshes_and_bootstraps( - meshes=meshes, - hosts_per_mesh=hosts_per_mesh, - gpus_per_host=gpus_per_host, - socket_type=SocketType.UNIX, - logging_location=LoggingLocation.DEFAULT, - supervision_params=SupervisionParams( - update_timeout_in_sec=10, # Fail fast - query_interval_in_sec=1, - update_interval_in_sec=1, - ), - auto_epoch=True, - ) - - -# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times -# out is not counted as a failure, so we set a more restrictive timeout to -# ensure we see a hard failure in CI. -# The timeout is set to 250s as the failover takes longer than other tests. 
-@pytest.mark.timeout(250) -class TestFaultTolerance(TestCase): - def test_mesh_provider(self) -> None: - # Create multiple meshes using mesh provider - replicas = 4 - provider, bootstrap = mesh_provider(meshes=replicas) - meshes: list[DeviceMesh] = [] - while len(meshes) < replicas: - dm = provider.new_mesh() - meshes.append(dm) - - statuses = provider._root_client.world_status() - for _, status in statuses.items(): - assert ( - DeviceMeshStatus(status) != DeviceMeshStatus.UNHEALTHY - ), f"unexpected unhealthy mesh; world status: {statuses}" - - # Check that all meshes are initially live - for mesh in meshes: - with mesh.activate(): - t = torch.ones(1) - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(1)) - - # Simulate a failure by killing one of the processes - bootstrap.processes[-1].kill() - - # Find unhealthy mesh - # Mix user and device errors - unhealthy_meshes = [] - for mesh in meshes: - with mesh.activate(): - # Send a call to trigger a failure - x = torch.rand(3, 4) - y = torch.rand(3, 4) - z = do_bogus_tensor_work(x, y) - try: - _ = fetch_shard(z).result() - except RemoteException: - pass - except DeviceException as e: - # Device error - unhealthy_meshes.append(mesh) - mesh.exit(e) - - self.assertEqual(len(unhealthy_meshes), 1) - - for _ in range(20): - statuses = provider._root_client.world_status() - # The 4th worker and controller worlds should have been evicted. - if len(statuses) != 6: - time.sleep(1) - continue - else: - for status in statuses.values(): - assert DeviceMeshStatus(status) == DeviceMeshStatus.LIVE - return - raise RuntimeError(f"Unexpected world statuses: {statuses}") - - def test_worker_supervision_failure(self) -> None: - meshes, bootstrap = local_meshes(meshes=1) - assert len(meshes) == 1 - mesh = meshes[0] - - # Check the mesh initially functional - with mesh.activate(): - t = torch.ones(1) - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(1)) - - # Simulate a failure by killing one of the processes - bootstrap.processes[-1].kill() - - # A device error will be raised - with mesh.activate(): - t = torch.ones(1) - with self.assertRaisesRegex(DeviceException, r"crashed"): - local_t = fetch_shard(t).result() - - def test_multi_mesh_failure_isolation(self) -> None: - replicas = 4 - provider, bootstrap = mesh_provider(meshes=replicas) - meshes: list[DeviceMesh] = [] - while len(meshes) < replicas: - dm = provider.new_mesh() - meshes.append(dm) - - # Check the meshes initially functional - for mesh in meshes: - with mesh.activate(): - t = torch.ones(1) - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(1)) - - initial_size = len(provider._root_client.world_status()) - - # Simulate a failure by killing one of the processes - bootstrap.processes[-1].kill() - - # Mix user and device errors - healthy_meshes = [] - unhealthy_meshes = [] - for mesh in meshes: - with mesh.activate(): - # Send a call to trigger a failure - x = torch.rand(3, 4) - y = torch.rand(3, 4) - z = do_bogus_tensor_work(x, y) - try: - _ = fetch_shard(z).result() - except RemoteException: - # User error - fetch_shard(x).result() - healthy_meshes.append(mesh) - except DeviceException as e: - # Device error - unhealthy_meshes.append(mesh) - mesh.exit(e) - - self.assertEqual(len(healthy_meshes), replicas - 1) - self.assertEqual(len(unhealthy_meshes), 1) - - while True: - size = len(provider._root_client.world_status()) - if size == initial_size - 2: - break - - # The healthy meshes should still be functional - for mesh 
in healthy_meshes: - with mesh.activate(): - t = torch.ones(1) - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(1)) - - def test_out_of_order_receive(self) -> None: - meshes, _ = local_meshes(meshes=8) - - # Check the meshes initially functional - ts = [] - for i, mesh in enumerate(meshes): - with mesh.activate(): - t = torch.ones(i + 1) - ts.append(t) - - # Shuffle the meshes to makes sure the client is able to dispatch results in order - indices = list(range(len(meshes))) - shuffled_meshes = list(zip(indices, meshes, ts)) - random.shuffle(shuffled_meshes) - for i, mesh, t in shuffled_meshes: - with mesh.activate(): - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(i + 1)) - - def test_mesh_shrink_and_grow(self) -> None: - # Create multiple meshes using mesh provider - replicas = 4 - provider, bootstrap = mesh_provider(meshes=replicas) - meshes: list[DeviceMesh] = [] - while len(meshes) < replicas: - dm = provider.new_mesh() - meshes.append(dm) - - worlds = len(provider._root_client.world_status()) - assigned_meshes = provider._mesh_map - assert len(assigned_meshes) == replicas - - # Happy path - for i, mesh in enumerate(meshes): - with mesh.activate(): - t = torch.ones(i + 1) - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(i + 1)) - - # Kill a worker - mesh_to_kill: MeshWorld = list(bootstrap.mesh_worlds.keys())[1] - procs = bootstrap.mesh_worlds[mesh_to_kill] - assert len(procs) == 2 - procs[-1].kill() - - # The killed mesh will become unhealthy - healthy_meshes = [] - unhealthy_meshes = [] - for i, mesh in enumerate(meshes): - with mesh.activate(): - try: - t = torch.ones(i + 1) - local_t = fetch_shard(t).result() - with no_mesh.activate(): - assert torch.equal(local_t, torch.ones(i + 1)) - healthy_meshes.append(mesh) - except DeviceException as e: - unhealthy_meshes.append(mesh) - mesh.exit(e) - assert len(healthy_meshes) == replicas - 1 - assert len(unhealthy_meshes) == 1 - - # Restart the mesh - for proc in procs: - proc.kill() - - # We should be able to acquire a new mesh without waiting for the old mesh to be evicted - (worker_world, controller_id) = mesh_to_kill - bootstrap.launch_mesh(controller_id=controller_id, worker_world=worker_world) - - dm = provider.new_mesh() - healthy_meshes.append(dm) - - # We could have 4 or 5 meshes depending on if the unhealthy mesh is evicted - assigned_meshes = provider._mesh_map - assert len(assigned_meshes) >= replicas - - # We are happy again - assert len(healthy_meshes) == replicas - for i, mesh in enumerate(healthy_meshes): - with mesh.activate(): - t = torch.ones(i + 1) - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(i + 1)) - - # Old world should be evicted and new world should be spawned. So we ended up with the same number of worlds. - while len((provider._root_client.world_status())) != worlds: - # We expect to evict both controller and worker worlds from the same mesh. 
- time.sleep(1) - - # Eventually, we only have 4 healthy meshes - assigned_meshes = provider._mesh_map - while len(assigned_meshes) != replicas: - with self.assertRaisesRegex( - TimeoutError, r"Could not find a healthy world" - ): - _ = provider.new_mesh(timeout_in_sec=1) - assigned_meshes = provider._mesh_map - time.sleep(1) - - def test_kill_controller(self) -> None: - # Create multiple meshes using mesh provider - replicas = 2 - provider, bootstrap = mesh_provider(meshes=replicas) - meshes: list[DeviceMesh] = [] - while len(meshes) < replicas: - dm = provider.new_mesh() - meshes.append(dm) - - # Happy path - for i, mesh in enumerate(meshes): - with mesh.activate(): - t = torch.ones(i + 1) - local_t = fetch_shard(t).result() - assert torch.equal(local_t, torch.ones(i + 1)) - - # Kill a controller - mesh_to_kill: MeshWorld = list(bootstrap.mesh_worlds.keys())[1] - procs = bootstrap.mesh_worlds[mesh_to_kill] - assert len(procs) == 2 - procs[0].kill() - - # We should be able to detect the failure - healthy_meshes = [] - detected_failure = False - for i, mesh in enumerate(meshes): - with mesh.activate(): - try: - t = torch.ones(i + 1) - local_t = fetch_shard(t).result() - with no_mesh.activate(): - assert torch.equal(local_t, torch.ones(i + 1)) - healthy_meshes.append(mesh) - except DeviceException: - detected_failure = True - assert len(healthy_meshes) == replicas - 1 - assert detected_failure - - def test_late_client_attaching(self) -> None: - provider, _ = mesh_provider(meshes=1) - - # Wait for the meshes to be healthy - healthy_meshes = 0 - while healthy_meshes < 2: - healthy_meshes = 0 - statuses = provider._root_client.world_status() - for _, status in statuses.items(): - if DeviceMeshStatus(status) == DeviceMeshStatus.LIVE: - healthy_meshes += 1 - time.sleep(1) - - # Sleep long enough to allow those "hidden messages" to be sent - time.sleep(15) - - # Those "hidden messages" should not cause a trouble before a client is ready - mesh = provider.new_mesh() - with mesh.activate(): - t = torch.ones(1) - fetch_shard(t).result() diff --git a/python/tests/test_sim_backend.py b/python/tests/test_sim_backend.py deleted file mode 100644 index 5d4c704ae..000000000 --- a/python/tests/test_sim_backend.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - -from contextlib import contextmanager -from typing import Generator -from unittest import TestCase - -import pytest - -import torch -from monarch import fetch_shard -from monarch.common.device_mesh import DeviceMesh -from monarch.sim_mesh import sim_mesh - - -@contextmanager -def local_sim_mesh( - hosts: int = 1, - # TODO: support multiple gpus in a mesh. - gpu_per_host: int = 1, - activate: bool = True, -) -> Generator[DeviceMesh, None, None]: - dms = sim_mesh(n_meshes=1, hosts=hosts, gpus_per_host=gpu_per_host) - dm = dms[0] - try: - if activate: - with dm.activate(): - yield dm - else: - yield dm - dm.exit() - except Exception: - dm.client._shutdown = True - raise - - -# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited -@pytest.mark.oss_skip -class TestSimBackend(TestCase): - def test_local_mesh_setup(self): - with local_sim_mesh(): - t = torch.zeros(3, 4) - t.add_(1) - local_t = fetch_shard(t).result() - # consider support specifying the return value in the mock worker. 
- assert local_t is not None diff --git a/scripts/build_monarch_for_docs.sh b/scripts/build_monarch_for_docs.sh index acba6429a..6e749cd83 100755 --- a/scripts/build_monarch_for_docs.sh +++ b/scripts/build_monarch_for_docs.sh @@ -35,8 +35,7 @@ import sys modules = [ 'monarch.fetch', 'monarch.gradient_generator', - 'monarch.notebook', - 'monarch.rust_local_mesh' + 'monarch.notebook' ] failed_modules = [] diff --git a/scripts/generate_cargo_deps_graph.py b/scripts/generate_cargo_deps_graph.py new file mode 100644 index 000000000..3780e708b --- /dev/null +++ b/scripts/generate_cargo_deps_graph.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Generate a Graphviz dependency graph from Rust Cargo.toml files. + +This script scans a directory tree for Cargo.toml files, extracts dependencies +between local crates (those with 'path' dependencies), and generates a .dot file +visualizing the dependency relationships. +""" + +import argparse +import os +import re +import sys +from pathlib import Path +from typing import Dict, List, Tuple, TypedDict + + +class DependencyInfo(TypedDict): + name: str + path: str + type: str + + +class TomlParseResult(TypedDict): + package_name: str | None + dependencies: List[DependencyInfo] + + +def parse_toml_simple(content: str) -> TomlParseResult: + """ + Simple TOML parser that extracts package name and path dependencies. + This is a minimal parser focused on what we need for dependency extraction. + """ + result: TomlParseResult = {"package_name": None, "dependencies": []} + + in_dependencies = False + in_dev_dependencies = False + in_build_dependencies = False + + for line in content.split("\n"): + line = line.strip() + + # Skip comments and empty lines + if not line or line.startswith("#"): + continue + + # Check for section headers + if line.startswith("["): + in_dependencies = line == "[dependencies]" + in_dev_dependencies = line == "[dev-dependencies]" + in_build_dependencies = line == "[build-dependencies]" + + # Extract package name + if line.startswith("[package]"): + in_dependencies = False + continue + + # Extract package name + if "name =" in line and result["package_name"] is None: + match = re.search(r'name\s*=\s*"([^"]+)"', line) + if match: + result["package_name"] = match.group(1) + + # Extract path dependencies + if in_dependencies or in_dev_dependencies or in_build_dependencies: + # Match patterns like: crate_name = { path = "../path" } + # Note: crate names can contain hyphens, so we use [\w-]+ instead of \w+ + path_match = re.search(r'([\w-]+)\s*=\s*\{[^}]*path\s*=\s*"([^"]+)"', line) + if path_match: + dep_name = path_match.group(1) + dep_path = path_match.group(2) + dep_type = ( + "dev" + if in_dev_dependencies + else ("build" if in_build_dependencies else "normal") + ) + result["dependencies"].append( + {"name": dep_name, "path": dep_path, "type": dep_type} + ) + + return result + + +def find_cargo_tomls(root_dir: Path) -> List[Path]: + """Find all Cargo.toml files in the directory tree.""" + cargo_files = [] + for dirpath, _dirnames, filenames in os.walk(root_dir): + if "Cargo.toml" in filenames: + cargo_files.append(Path(dirpath) / "Cargo.toml") + return cargo_files + + +def extract_dependencies( + root_dir: Path, + include_dev: bool = False, + include_build: bool = False, + verbose: bool = False, +) -> Tuple[Dict[str, str], 
Dict[str, List[Tuple[str, str]]]]: + """ + Extract dependency information from all Cargo.toml files. + + Returns: + - crate_map: Dict mapping crate names to their directory paths + - dependencies: Dict mapping crate names to lists of (dependency_name, dep_type) tuples + """ + cargo_files = find_cargo_tomls(root_dir) + + # Map crate name to its directory + crate_map: Dict[str, str] = {} + # Map crate name to list of its dependencies + dependencies: Dict[str, List[Tuple[str, str]]] = {} + + for cargo_path in cargo_files: + try: + with open(cargo_path, "r") as f: + content = f.read() + + parsed = parse_toml_simple(content) + + # Skip workspace manifests and files without package names + package_name = parsed["package_name"] + if not package_name: + if verbose: + print( + f"Skipping {cargo_path} (no package name, likely workspace manifest)", + file=sys.stderr, + ) + continue + + # Type narrowing: package_name is guaranteed to be str here + package_dir = str(cargo_path.parent.relative_to(root_dir)) + + crate_map[package_name] = package_dir + + if verbose: + print(f"Found crate '{package_name}' at {package_dir}", file=sys.stderr) + + # Collect dependencies + deps = [] + for dep in parsed["dependencies"]: + dep_type = dep["type"] + + # Filter based on dependency type + if dep_type == "dev" and not include_dev: + if verbose: + print( + f" Skipping dev-dependency: {dep['name']}", file=sys.stderr + ) + continue + if dep_type == "build" and not include_build: + if verbose: + print( + f" Skipping build-dependency: {dep['name']}", + file=sys.stderr, + ) + continue + + deps.append((dep["name"], dep_type)) + if verbose: + print( + f" Found {dep_type}-dependency: {dep['name']}", file=sys.stderr + ) + + if deps: + dependencies[package_name] = deps + + except Exception as e: + print(f"Warning: Failed to parse {cargo_path}: {e}", file=sys.stderr) + import traceback + + if verbose: + traceback.print_exc() + continue + + return crate_map, dependencies + + +def filter_local_dependencies( + crate_map: Dict[str, str], dependencies: Dict[str, List[Tuple[str, str]]] +) -> Dict[str, List[Tuple[str, str]]]: + """ + Filter dependencies to only include those that are local crates (present in crate_map). 
+ """ + filtered = {} + for crate, deps in dependencies.items(): + local_deps = [ + (dep_name, dep_type) for dep_name, dep_type in deps if dep_name in crate_map + ] + if local_deps: + filtered[crate] = local_deps + return filtered + + +def generate_graphviz( + crate_map: Dict[str, str], + dependencies: Dict[str, List[Tuple[str, str]]], + output_file: Path, + title: str = "Cargo Dependencies", +): + """Generate a Graphviz .dot file from the dependency information.""" + + # Collect all crates (even those without dependencies) + all_crates = set(crate_map.keys()) + + with open(output_file, "w") as f: + f.write("digraph cargo_dependencies {\n") + f.write(" rankdir=LR;\n") + f.write(" node [shape=box, style=rounded];\n") + f.write(' labelloc="t";\n') + f.write(f' label="{title}";\n') + f.write(" \n") + + # Define nodes + f.write(" // Nodes\n") + for crate_name in sorted(all_crates): + f.write(f' "{crate_name}";\n') + + f.write(" \n") + + # Define edges + f.write(" // Dependencies\n") + for crate, deps in sorted(dependencies.items()): + for dep_name, dep_type in sorted(deps): + # Use different colors/styles for different dependency types + if dep_type == "dev": + style = " [color=blue, style=dashed]" + elif dep_type == "build": + style = " [color=green, style=dotted]" + else: + style = "" + + f.write(f' "{crate}" -> "{dep_name}"{style};\n') + + f.write("}\n") + + print(f"Generated {output_file}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate a Graphviz dependency graph from Rust Cargo.toml files", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate graph for current directory (includes build-dependencies by default) + %(prog)s + + # Generate graph for specific directory + %(prog)s /path/to/rust/project + + # Include dev dependencies and exclude build dependencies + %(prog)s --include-dev --exclude-build + + # Custom output file + %(prog)s -o my_deps.dot + + # Enable verbose output for debugging + %(prog)s -v + """, + ) + + parser.add_argument( + "directory", + nargs="?", + default=".", + help="Root directory to scan for Cargo.toml files (default: current directory)", + ) + + parser.add_argument( + "-o", + "--output", + default="cargo_deps.dot", + help="Output .dot file path (default: cargo_deps.dot)", + ) + + parser.add_argument( + "--include-dev", + action="store_true", + help="Include dev-dependencies in the graph (shown as dashed blue lines)", + ) + + parser.add_argument( + "--exclude-build", + action="store_true", + help="Exclude build-dependencies from the graph (by default they are included as dotted green lines)", + ) + + parser.add_argument( + "--title", + default="Cargo Dependencies", + help='Graph title (default: "Cargo Dependencies")', + ) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose output for debugging", + ) + + args = parser.parse_args() + + root_dir = Path(args.directory).resolve() + + if not root_dir.exists(): + print(f"Error: Directory {root_dir} does not exist", file=sys.stderr) + sys.exit(1) + + print(f"Scanning {root_dir} for Cargo.toml files...") + + # Extract dependencies + crate_map, dependencies = extract_dependencies( + root_dir, + include_dev=args.include_dev, + include_build=not args.exclude_build, + verbose=args.verbose, + ) + + print(f"Found {len(crate_map)} crates") + + # Filter to only local dependencies + local_deps = filter_local_dependencies(crate_map, dependencies) + + print(f"Found {sum(len(deps) for deps in local_deps.values())} local 
dependencies") + + # Generate graphviz file + output_path = Path(args.output) + generate_graphviz(crate_map, local_deps, output_path, args.title) + + print("\nTo visualize the graph, run:") + print(f" dot -Tpng {output_path} -o {output_path.stem}.png") + print(f" dot -Tsvg {output_path} -o {output_path.stem}.svg") + print(f" dot -Tpdf {output_path} -o {output_path.stem}.pdf") + + +if __name__ == "__main__": + main()