Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nemo_run/run/torchx_backend/schedulers/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,17 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
command = [app.roles[0].entrypoint] + app.roles[0].args
# Use Ray template from executor configuration
ray_template_name = executor.ray_template
ray_template_dir = None
if os.path.isabs(ray_template_name) or os.path.dirname(ray_template_name):
ray_template_name = os.path.basename(ray_template_name)
ray_template_dir = os.path.dirname(os.path.abspath(executor.ray_template))
req = SlurmRayRequest(
name=app.roles[0].name,
launch_cmd=["sbatch", "--requeue", "--parsable"],
command=" ".join(command),
cluster_dir=os.path.join(executor.tunnel.job_dir, Path(job_dir).name, "ray"),
template_name=ray_template_name,
template_dir=ray_template_dir,
executor=executor,
workdir=f"/{RUNDIR_NAME}/code",
nemo_run_dir=os.path.join(executor.tunnel.job_dir, Path(job_dir).name),
Expand Down
18 changes: 18 additions & 0 deletions test/run/torchx_backend/schedulers/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,9 @@ def test_ray_template_executor(slurm_scheduler, slurm_executor, temp_dir):
roles=[Role(name="test_role", image="", entrypoint="python", args=["script.py"])],
metadata={USE_WITH_RAY_CLUSTER_KEY: True},
)
custom_template_path = os.path.join(temp_dir, "custom_ray.sub.j2")
with open(custom_template_path, "w", encoding="utf-8") as f:
f.write("#!/bin/bash\n# Custom template")

with (
mock.patch.object(SlurmTunnelScheduler, "_initialize_tunnel"),
Expand Down Expand Up @@ -533,6 +536,21 @@ def test_ray_template_executor(slurm_scheduler, slurm_executor, temp_dir):
assert isinstance(dryrun_info.request, SlurmRayRequest)
assert dryrun_info.request.template_name == "ray_enroot.sub.j2"

path_executor = SlurmExecutor(
account="test_account",
job_dir=temp_dir,
nodes=1,
ntasks_per_node=1,
tunnel=LocalTunnel(job_dir=temp_dir),
ray_template=custom_template_path,
)
with mock.patch("nemo_run.core.execution.utils.fill_template") as mock_fill:
mock_fill.return_value = "#!/bin/bash\n# Mock script"
dryrun_info = slurm_scheduler._submit_dryrun(app_def, path_executor)
assert isinstance(dryrun_info.request, SlurmRayRequest)
assert dryrun_info.request.template_name == "custom_ray.sub.j2"
assert dryrun_info.request.template_dir == temp_dir


def test_heterogeneous_ray_cluster_run_as_group(slurm_scheduler, temp_dir):
"""Test that run_as_group is automatically set for heterogeneous Ray clusters."""
Expand Down
Loading