Skip to content

Commit 5a206a1

Browse files
committed
fixing "stop" and adding test to verify multithreading
1 parent f907be0 commit 5a206a1

File tree

3 files changed

+85
-21
lines changed

3 files changed

+85
-21
lines changed

src/lib.rs

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,6 @@ pub(crate) struct PyCommand {
3737
responder: oneshot::Sender<Result<Value, String>>,
3838
}
3939

40-
/// A boxed, send-able future that resolves to a PyRunnerResult.
41-
type Task = Box<dyn FnOnce(&Runtime) -> Result<Value, PyRunnerError> + Send>;
42-
43-
/// A lazily-initialized worker thread for handling synchronous function calls.
44-
/// This thread has its own private Tokio runtime to safely block on async operations
45-
/// without interfering with any existing runtime the user might be in.
46-
static SYNC_WORKER: Lazy<std_mpsc::Sender<Task>> = Lazy::new(|| {
47-
let (tx, rx) = std_mpsc::channel::<Task>();
48-
49-
thread::spawn(move || {
50-
let rt = Runtime::new().expect("Failed to create Tokio runtime for sync worker");
51-
// When the sender (tx) is dropped, rx.recv() will return an Err, ending the loop.
52-
while let Ok(task) = rx.recv() {
53-
let _ = task(&rt); // The result is sent back via a channel inside the task.
54-
}
55-
});
56-
tx
57-
});
5840
/// Custom error types for the `PyRunner`.
5941
#[derive(Error, Debug, Clone)]
6042
pub enum PyRunnerError {
@@ -86,6 +68,27 @@ pub fn print_path_for_python(path: &PathBuf) -> String {
8668
}
8769
}
8870

71+
/// A boxed, send-able future that resolves to a PyRunnerResult.
72+
type Task = Box<dyn FnOnce(&Runtime) -> Result<Value, PyRunnerError> + Send>;
73+
74+
/// A lazily-initialized worker thread for handling synchronous function calls
75+
/// to functions that otherwise return a future. It will only be engaged when
76+
/// calling `..._sync()` functions of PyRunner.
77+
/// This thread has its own private Tokio runtime to safely block on async operations
78+
/// without interfering with any existing runtime the user might be in.
79+
static SYNC_WORKER: Lazy<std_mpsc::Sender<Task>> = Lazy::new(|| {
80+
let (tx, rx) = std_mpsc::channel::<Task>();
81+
82+
thread::spawn(move || {
83+
let rt = Runtime::new().expect("Failed to create Tokio runtime for sync worker");
84+
// When the sender (tx) is dropped, rx.recv() will return an Err, ending the loop.
85+
while let Ok(task) = rx.recv() {
86+
let _ = task(&rt); // The result is sent back via a channel inside the task.
87+
}
88+
});
89+
tx
90+
});
91+
8992
/// Manages a dedicated thread for executing Python code asynchronously.
9093
#[derive(Clone)]
9194
pub struct PyRunner {
@@ -734,4 +737,59 @@ result = mymodule.my_func()
734737

735738
assert_eq!(result, Value::String("hello from venv".to_string()));
736739
}
737-
}
740+
741+
#[tokio::test]
742+
async fn test_pyrunner_thread_safety() {
743+
let runner = PyRunner::new();
744+
runner.run("x = 0").await.unwrap();
745+
746+
let mut handles = vec![];
747+
748+
for i in 0..5 {
749+
let runner_clone = runner.clone();
750+
let handle = tokio::spawn(async move {
751+
// Each task increments 'x' in the shared Python interpreter
752+
runner_clone.run(&format!("x += {}", i)).await.unwrap();
753+
});
754+
handles.push(handle);
755+
}
756+
757+
// Wait for all tasks to complete
758+
for handle in handles {
759+
handle.await.unwrap();
760+
}
761+
762+
let final_x = runner.read_variable("x").await.unwrap();
763+
// Expected: 0 + 1 + 2 + 3 + 4 = 10
764+
assert_eq!(final_x, Value::Number(10.into()));
765+
}
766+
767+
#[test]
768+
fn test_pyrunner_std_thread_safety() {
769+
let runner = PyRunner::new();
770+
runner.run_sync("x = 0").unwrap();
771+
772+
let mut handles = vec![];
773+
774+
for i in 0..5 {
775+
let runner_clone = runner.clone();
776+
// Spawn a new OS thread
777+
let handle = thread::spawn(move || {
778+
// Use the _sync version for non-async contexts
779+
runner_clone.run_sync(&format!("x += {}", i)).unwrap();
780+
});
781+
handles.push(handle);
782+
}
783+
784+
// Wait for all threads to complete
785+
for handle in handles {
786+
handle.join().unwrap();
787+
}
788+
789+
let final_x = runner.read_variable_sync("x").unwrap();
790+
// Expected: 0 + 1 + 2 + 3 + 4 = 10
791+
assert_eq!(final_x, Value::Number(10.into()));
792+
793+
runner.stop_sync().unwrap();
794+
}
795+
}

src/pyo3_runner.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ pub(crate) async fn python_thread_main(mut receiver: mpsc::Receiver<PyCommand>)
5454
Err(e) => Err(e),
5555
}
5656
}
57-
CmdType::Stop => return receiver.close(),
57+
CmdType::Stop => {
58+
receiver.close();
59+
Ok(Value::Null)
60+
}
5861
};
5962

6063
// Convert PyErr to a string representation to avoid exposing it outside this module.

src/rustpython_runner.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ pub(crate) fn python_thread_main(mut receiver: mpsc::Receiver<PyCommand>) {
5151
dbg!(name, args);
5252
unimplemented!("Async functions are not supported yet in RustPython")
5353
}
54-
CmdType::Stop => break,
54+
CmdType::Stop => {
55+
receiver.close();
56+
Ok(Value::Null)
57+
}
5558
};
5659
let response = result.map_err(|err| {
5760
format!(

0 commit comments

Comments
 (0)