From f2ea6b4ddaedf474a5e65ebe644c6cdbd1b14d22 Mon Sep 17 00:00:00 2001 From: a120092009 Date: Tue, 11 Nov 2025 14:48:57 +0800 Subject: [PATCH] Fix Context Parallelism doc --- docs/source/en/training/distributed_inference.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index f9756e1a67aa..cfd397f4aba3 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -253,7 +253,7 @@ try: device = torch.device("cuda", rank % torch.cuda.device_count()) torch.cuda.set_device(device) - transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2)) + transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2), device_map="cuda") pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda") pipeline.transformer.set_attention_backend("flash") @@ -289,4 +289,4 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`]. ```py pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2)) -``` \ No newline at end of file +```