214 changes: 19 additions & 195 deletions monarch_tensor_worker/src/lib.rs
@@ -30,8 +30,6 @@
mod borrow;
mod comm;
pub mod device_mesh;
pub mod pipe;
pub mod py_pipe;
pub mod stream;
pub mod test_util;

@@ -71,7 +69,6 @@ use monarch_messages::controller::Seq;
use monarch_messages::wire_value::WireValue;
use monarch_messages::worker::ActorCallParams;
use monarch_messages::worker::ActorMethodParams;
use monarch_messages::worker::CallFunctionError;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Factory;
use monarch_messages::worker::Reduction;
@@ -84,8 +81,6 @@ use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
use monarch_types::PyTree;
use ndslice::Slice;
use pipe::PipeActor;
use pipe::PipeParams;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use serde::Deserialize;
@@ -173,8 +168,6 @@ pub struct WorkerActor {
borrows: HashMap<u64, Borrow>,
comm: Option<ActorHandle<NcclCommActor>>,
controller_actor: ActorRef<ControllerActor>,
/// Pipes created for the worker.
pipes: HashMap<Ref, ActorHandle<PipeActor>>,
/// Remember the process groups "created" via `CreateRemoteProcessGroup` for
/// subsequent `CallFunction` calls, as this is where the actual allocation
/// will happen.
@@ -244,7 +237,6 @@ impl Actor for WorkerActor {
borrows: HashMap::new(),
comm: None,
controller_actor,
pipes: HashMap::new(),
remote_process_groups: HashMap::new(),
send_recv_comms: HashMap::new(),
recordings: HashMap::new(),
@@ -648,47 +640,18 @@ impl WorkerMessageHandler for WorkerActor {

async fn create_pipe(
&mut self,
cx: &hyperactor::Context<Self>,
result: Ref,
_cx: &hyperactor::Context<Self>,
_result: Ref,
// TODO(agallagher): This is used in the python impl to name the socket
// path to use for comms, but we don't currently use a named socket.
_key: String,
function: ResolvableFunction,
max_messages: i64,
device_mesh: Ref,
args: Vec<WireValue>,
kwargs: HashMap<String, WireValue>,
_function: ResolvableFunction,
_max_messages: i64,
_device_mesh: Ref,
_args: Vec<WireValue>,
_kwargs: HashMap<String, WireValue>,
) -> Result<()> {
println!("CREATE PIPE1 {}", result);
let args: Vec<PyTree<RValue>> = args
.into_iter()
.map(|object| RValue::PyObject(object.into_py_object().unwrap()).into())
.collect();
let kwargs: HashMap<_, PyTree<RValue>> = kwargs
.into_iter()
.map(|(k, object)| (k, RValue::PyObject(object.into_py_object().unwrap()).into()))
.collect();
let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
})?;
println!("CREATE PIPE2 {}", result);
// TODO(agallagher): Fix error prop. (When pipe is read from the pipes dict if it had an error it should cause a dependent error in send_value not an actor error as it does now)
let pipe = PipeActor::spawn(
cx,
PipeParams {
function,
max_messages,
ranks: device_mesh.0.ranks(),
sizes: device_mesh.0.sizes(),
args,
kwargs,
},
)
.await?;
println!("AFTER CREATE PIPE {}", result);

self.pipes.insert(result, pipe);
Ok(())
panic!("create_pipe is no longer implemented")
}
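The replacement body panics as soon as a `CreatePipe` message is handled. If a recoverable failure is preferable to panicking inside the handler, a minimal alternative sketch (assuming the `Result` return type here is `anyhow::Result`, consistent with the `anyhow::anyhow!` calls elsewhere in this file) could look like:

async fn create_pipe(
    &mut self,
    _cx: &hyperactor::Context<Self>,
    _result: Ref,
    _key: String,
    _function: ResolvableFunction,
    _max_messages: i64,
    _device_mesh: Ref,
    _args: Vec<WireValue>,
    _kwargs: HashMap<String, WireValue>,
) -> Result<()> {
    // Reject the message with an error instead of panicking in the handler.
    anyhow::bail!("create_pipe is no longer implemented")
}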

async fn send_tensor(
@@ -818,18 +781,11 @@
.collect()
};

let pipe = if let Some(destination) = destination {
let pipe = self
.pipes
.get(&destination)
.ok_or_else(|| anyhow::anyhow!("invalid pipe id: {:#?}", destination))?
.port();
Some(pipe)
} else {
None
};
// Resolve the value on the stream, then send the value to the pipe if provided,
// or back to the controller if not.
if destination.is_some() {
panic!("send_value with pipe destination is no longer implemented")
}

// Resolve the value on the stream, then send the value back to the controller.
stream
.send_value(
cx,
@@ -840,7 +796,6 @@
args,
kwargs,
device_meshes,
pipe,
)
.await
}
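With the pipe branch removed, `send_value` only supports resolving a value and returning it to the controller. A hypothetical client-side message, mirroring the `SendValue` construction in the `pipe_send_recv` test that this diff deletes below (where `resolve_value_arg` is the pickled argument built in that test), would now have to leave the destination empty:

WorkerMessage::SendValue {
    seq: 1.into(),
    // Some(pipe_ref) now panics; only None is supported.
    destination: None,
    mutates: vec![],
    function: Some(
        "monarch.monarch_tensor_worker.test_utils.resolve_value".into(),
    ),
    args: vec![resolve_value_arg.into()],
    kwargs: HashMap::new(),
    stream: 0.into(),
},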
@@ -971,24 +926,13 @@

async fn pipe_recv(
&mut self,
cx: &hyperactor::Context<Self>,
seq: Seq,
results: Vec<Option<Ref>>,
pipe: Ref,
stream: StreamRef,
_cx: &hyperactor::Context<Self>,
_seq: Seq,
_results: Vec<Option<Ref>>,
_pipe: Ref,
_stream: StreamRef,
) -> Result<()> {
self.maybe_add_stream_to_recording(cx, stream).await?;

// Get a port for the pipe
let pipe = self
.pipes
.get(&pipe)
.ok_or_else(|| anyhow::anyhow!("ref not found: {}", pipe))?;
let pipe = pipe.port();
// Resolve the stream.
let stream = self.try_get_stream(stream)?;
// Push result into the stream.
stream.set_value(cx, seq, results, pipe).await
panic!("pipe_recv is no longer implemented")
}

async fn set_ref_unit_tests_only(
Expand Down Expand Up @@ -2186,126 +2130,6 @@ mod tests {
Ok(())
}

#[async_timed_test(timeout_secs = 60)]
async fn pipe_send_recv() -> Result<()> {
test_setup()?;

let proc = Proc::local();
let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

let handle = proc
.spawn::<WorkerActor>(
"worker",
WorkerParams {
world_size: 1,
rank: 0,
device_index: None,
controller_actor: controller_ref,
},
)
.await
.unwrap();
let (resolve_value_arg, torch_eq_arg1, torch_eq_arg2): (
PickledPyObject,
PickledPyObject,
PickledPyObject,
) = Python::with_gil(|py| {
PyResult::Ok((
PyList::new(py, [2, 3])?.into_any().try_into()?,
Ref { id: 2 }.into_bound_py_any(py)?.try_into()?,
Ref { id: 4 }.into_bound_py_any(py)?.try_into()?,
))
})?;

handle
.command_group(
&client,
vec![
WorkerMessage::CreateStream {
id: 0.into(),
stream_creation: StreamCreationMode::UseDefaultStream,
},
WorkerMessage::CreateDeviceMesh {
result: 1.into(),
names: vec!["x".into()],
ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
},
// Create a tensor value which we'll send through the pipe.
WorkerMessage::CallFunction(CallFunctionParams {
seq: 0.into(),
results: vec![Some(2.into())],
mutates: vec![],
function: "torch.ops.aten.ones.default".into(),
args: vec![WireValue::IntList(vec![2, 3])],
kwargs: HashMap::new(),
stream: 0.into(),
remote_process_groups: vec![],
}),
WorkerMessage::CreatePipe {
result: 3.into(),
key: "unused".into(),
function: "monarch.monarch_tensor_worker.test_utils.handler".into(),
max_messages: 1,
mesh: 1.into(),
args: vec![],
kwargs: HashMap::new(),
},
WorkerMessage::SendValue {
seq: 1.into(),
destination: Some(3.into()),
mutates: vec![],
function: Some(
"monarch.monarch_tensor_worker.test_utils.resolve_value".into(),
),
args: vec![resolve_value_arg.into()],
kwargs: HashMap::new(),
stream: 0.into(),
},
WorkerMessage::PipeRecv {
seq: 2.into(),
results: vec![Some(4.into())],
pipe: 3.into(),
stream: 0.into(),
},
WorkerMessage::CallFunction(CallFunctionParams {
seq: 0.into(),
results: vec![Some(5.into())],
mutates: vec![],
function: "torch.equal".into(),
args: vec![torch_eq_arg1.into(), torch_eq_arg2.into()],
kwargs: HashMap::new(),
stream: 0.into(),
remote_process_groups: vec![],
}),
],
)
.await
.unwrap();

let matches: bool = handle
.get_ref_unit_tests_only(&client, 5.into(), 0.into())
.await
.unwrap()
.unwrap()
.unwrap()
.try_into()
.unwrap();
assert!(matches);

handle.drain_and_stop()?;
assert_matches!(handle.await, ActorStatus::Stopped);

let responses = controller_rx.drain();
assert_eq!(
responses.len(),
0,
"Expected one response, got: {:#?}",
responses
);

Ok(())
}

fn get_random_channel_addr() -> ChannelAddr {
let random_string = rand::thread_rng()
.sample_iter(&Alphanumeric)