Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions monarch_extension/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,21 +330,6 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
to_stream: p.parseStreamRef("to_stream")?,
})
});
m.insert(key("CreatePipe"), |p| {
let function = p.parseFunction("function")?;
let args = p.parse("args")?;
let kwargs = p.parse("kwargs")?;
let (args, kwargs) = func_call_args_to_wire_values(Some(&function), &args, &kwargs)?;
Ok(WorkerMessage::CreatePipe {
result: p.parseRef("result")?,
key: p.parse("key")?,
function,
max_messages: p.parse("max_messages")?,
mesh: p.parseRef("device_mesh")?,
args,
kwargs,
})
});
m.insert(key("SendValue"), |p| {
let function = p.parseOptionalFunction("function")?;
let args: Bound<'_, PyTuple> = p.parse("args")?;
Expand Down
123 changes: 13 additions & 110 deletions monarch_messages/src/wire_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,7 @@ use crate::worker::ResolvableFunction;
// out for refs. And IValue is the same as RValue, but with real tensors and
// C++ types. I wonder if there is a nicer way to express this relationship.
// TODO extend this to support other types of values, like bytes, dicts etc.
#[derive(
Serialize,
Deserialize,
Debug,
Clone,
TryInto,
Named,
From,
EnumAsInner
)]
#[derive(Serialize, Deserialize, Debug, Clone, TryInto, Named, From)]
pub enum WireValue {
// Make sure boolean goes ealier than int as bool is a subclass of int.
// Otherwise, bool will be converted to int.
Expand Down Expand Up @@ -165,116 +156,28 @@ impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
}
}

impl From<PyObject> for WireValue {
fn from(obj: PyObject) -> Self {
Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap()))
}
}
impl<'py> IntoPyObject<'py> for WireValue {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
type Error = PyErr;

impl WireValue {
fn from_pyobject_with_torch_op_arg_type(
obj: Bound<'_, PyAny>,
type_: &torch_sys::call_op::TypePtr,
num_elements: i32,
allow_nums_as_tensors: bool,
) -> PyResult<Self> {
if type_.is_tensor() || type_.is_optional_tensor() {
if type_.is_optional_tensor() && obj.is_none() {
return Ok(WireValue::None(()));
} else if let Ok(ref_) = Ref::from_py_object(&obj) {
return Ok(WireValue::Ref(ref_));
}
}
if type_.is_tensor_list() || type_.is_optional_tensor_list() {
if type_.is_optional_tensor_list() && obj.is_none() {
return Ok(WireValue::None(()));
}
let list = obj.downcast::<PyList>()?;
let len = list.len();
if len == 0 {
return Ok(WireValue::RefList(vec![]));
}
// SAFETY: We know it is within bounds
let item = unsafe { list.get_item_unchecked(0) };
if let Ok(ref_) = Ref::from_py_object(&item) {
let mut ref_list = Vec::with_capacity(len);
ref_list.push(ref_);
for item in list.iter().skip(1) {
ref_list.push(Ref::from_py_object(&item).map_err(|_| {
PyValueError::new_err(format!(
"Expected homogeneous list of refs got: {:?}",
list
))
})?);
}
return Ok(WireValue::RefList(ref_list));
}
}
OpaqueIValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
.map(WireValue::IValue)
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
unsafe { self.try_to_object_unsafe(py) }
}
}

pub fn func_call_args_to_wire_values(
func: Option<&ResolvableFunction>,
args: &Bound<'_, PyTuple>,
kwargs: &Bound<'_, PyDict>,
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
if let Some((op, overload)) = func.and_then(|func| func.as_torch_op()) {
torch_op_args_to_wire_values(&op, &overload, args, kwargs)
} else {
python_func_args_to_wire_value(args, kwargs)
impl From<PyObject> for WireValue {
fn from(obj: PyObject) -> Self {
Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap()))
}
}

fn torch_op_args_to_wire_values(
op: &str,
overload: &str,
pub fn func_call_args_to_wire_values(
_func: Option<&ResolvableFunction>,
args: &Bound<'_, PyTuple>,
kwargs: &Bound<'_, PyDict>,
) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
let args_info = torch_sys::call_op::get_schema_args_info(op, overload).map_err(|err| {
PyValueError::new_err(format!(
"Failed to get the operator schema for {}::{}: {}",
op, overload, err
))
})?;

let args = args
.iter()
.zip(&args_info)
.map(|(arg, arg_info)| {
WireValue::from_pyobject_with_torch_op_arg_type(
arg,
arg_info.type_,
arg_info.num_elements,
arg_info.allows_number_as_tensor,
)
})
.collect::<Result<Vec<_>, _>>()?;
let kwargs = kwargs
.iter()
.map(|(k, v)| {
let key = k.extract::<String>()?;
let arg_info = args_info
.iter()
.find(|arg_info| arg_info.name == key)
.ok_or_else(|| {
PyValueError::new_err(format!(
"Torch op {}::{} does not support kwarg {}",
op, overload, key
))
})?;
let val = WireValue::from_pyobject_with_torch_op_arg_type(
v,
arg_info.type_,
arg_info.num_elements,
arg_info.allows_number_as_tensor,
)?;
Ok((key, val))
})
.collect::<Result<HashMap<_, _>, PyErr>>()?;
Ok((args, kwargs))
python_func_args_to_wire_value(args, kwargs)
}

fn python_func_args_to_wire_value(
Expand Down
32 changes: 0 additions & 32 deletions monarch_messages/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,21 +340,6 @@ impl ResolvableFunction {
}
}

pub fn as_torch_op<'a>(&'a self) -> Option<(String, String)> {
match self {
Self::FunctionPath(func) => match func.path.split(".").collect::<Vec<_>>().as_slice() {
["torch", "ops", namespace, op_name, "default"] => {
Some((format!("{}::{}", namespace, op_name), String::new()))
}
["torch", "ops", namespace, op_name, overload] => {
Some((format!("{}::{}", namespace, op_name), overload.to_string()))
}
_ => None,
},
_ => None,
}
}

/// For testing: this is a special remote function path that induces a panic
/// when called.
pub fn panic_if_requested(&self) {
Expand All @@ -367,13 +352,6 @@ impl ResolvableFunction {
_ => (),
}
}

pub fn supports_pytree_args(&self) -> bool {
match self {
Self::Cloudpickle(_) => true,
Self::FunctionPath(_) => self.as_torch_op().is_none(),
}
}
}

impl<T: Into<String>> From<T> for ResolvableFunction {
Expand Down Expand Up @@ -800,16 +778,6 @@ pub enum WorkerMessage {
to_stream: StreamRef,
},

CreatePipe {
result: Ref,
key: String,
function: ResolvableFunction,
max_messages: i64,
mesh: Ref,
args: Vec<WireValue>,
kwargs: HashMap<String, WireValue>,
},

SendValue {
seq: Seq,
/// Pipe to send value to. If `None`, value is sent to controller.
Expand Down
2 changes: 1 addition & 1 deletion monarch_tensor_worker/src/borrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ mod tests {
.err()
.context("expected error")?;
assert!(
error.contains("torch operator error"),
error.contains("failed to resolve function"),
"If a borrowed value contains an error, downstream calls should propagate that error (unexpected error string: {})",
error,
);
Expand Down
Loading