From c5980103542ecf4b3ac62e4c596c05e66d261b83 Mon Sep 17 00:00:00 2001 From: sippycoder Date: Tue, 20 May 2025 16:34:30 +0000 Subject: [PATCH] slurm scripts --- apps/Castor/modules/component.py | 21 +++- diagnose_flash_attn.py | 144 +++++++++++++++++++++++++++ setup_shared_env.sh | 165 +++++++++++++++++++++++++++++++ submit_pretrain.sh | 43 ++++++++ train_castor.slurm | 130 ++++++++++++++++++++++++ 5 files changed, 501 insertions(+), 2 deletions(-) create mode 100755 diagnose_flash_attn.py create mode 100755 setup_shared_env.sh create mode 100644 submit_pretrain.sh create mode 100755 train_castor.slurm diff --git a/apps/Castor/modules/component.py b/apps/Castor/modules/component.py index 029c87d..5456883 100644 --- a/apps/Castor/modules/component.py +++ b/apps/Castor/modules/component.py @@ -16,8 +16,25 @@ from liger_kernel.transformers import LigerSwiGLUMLP, LigerRMSNorm, liger_rotary_pos_emb from types import SimpleNamespace -# fa3 -from flash_attn_interface import flash_attn_varlen_func +# fa3 - try multiple import paths for better compatibility +try: + # First try the normal path that should work with flash-attn package + from flash_attn.flash_attn_interface import flash_attn_varlen_func +except ImportError: + try: + # Try the compatibility layer (top-level flash_attn_interface shim module) + from flash_attn_interface import flash_attn_varlen_func + except ImportError: + # Create a dummy placeholder that will raise a more descriptive error when used + def flash_attn_varlen_func(*args, **kwargs): + raise RuntimeError( + "flash_attn_varlen_func could not be imported. " + "Please install flash-attention properly or check the import paths." + ) + warnings.warn( + "flash_attn_varlen_func could not be imported from either " + "flash_attn.flash_attn_interface or flash_attn_interface. Using dummy placeholder." 
+ ) flex_attention_comp = torch.compile(flex_attention) diff --git a/diagnose_flash_attn.py b/diagnose_flash_attn.py new file mode 100755 index 0000000..ccf37be --- /dev/null +++ b/diagnose_flash_attn.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +""" +Diagnostic script for flash-attention issues +This script helps identify issues with flash-attention installation and imports. +""" + +import os +import sys +import importlib.util +import subprocess +from pathlib import Path + +def print_header(title): + print("\n" + "=" * 80) + print(f" {title} ".center(80, "=")) + print("=" * 80) + +def check_module_exists(module_name): + return importlib.util.find_spec(module_name) is not None + +def run_cmd(cmd): + print(f"Running: {cmd}") + try: + output = subprocess.check_output(cmd, shell=True, universal_newlines=True, stderr=subprocess.STDOUT) + return True, output + except subprocess.CalledProcessError as e: + return False, e.output + +def main(): + print_header("FLASH ATTENTION DIAGNOSTIC TOOL") + + # Check Python version + print(f"Python version: {sys.version}") + + # Check CUDA availability + print("\nCUDA Environment:") + cuda_home = os.environ.get("CUDA_HOME", "Not set") + print(f"CUDA_HOME: {cuda_home}") + + # Check for PyTorch installation + print("\nChecking PyTorch installation:") + if check_module_exists("torch"): + import torch + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA available: {torch.cuda.is_available()}") + if torch.cuda.is_available(): + print(f"CUDA version: {torch.version.cuda}") + print(f"Number of GPUs: {torch.cuda.device_count()}") + for i in range(torch.cuda.device_count()): + print(f" Device {i}: {torch.cuda.get_device_name(i)}") + else: + print("PyTorch is not installed") + + # Check flash attention installation + print("\nChecking flash-attention installation:") + if check_module_exists("flash_attn"): + try: + import flash_attn + print(f"flash-attn version: {flash_attn.__version__}") + + # Check if the interface module exists + 
print("\nChecking flash_attn.flash_attn_interface:") + if check_module_exists("flash_attn.flash_attn_interface"): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + print("flash_attn.flash_attn_interface.flash_attn_varlen_func is available") + else: + print("flash_attn.flash_attn_interface module not found") + + # Check compatibility layer + print("\nChecking flash_attn_interface compatibility layer:") + if check_module_exists("flash_attn_interface"): + import flash_attn_interface + if hasattr(flash_attn_interface, "flash_attn_varlen_func"): + print("flash_attn_interface.flash_attn_varlen_func is available") + else: + print("flash_attn_interface module exists but flash_attn_varlen_func is not available") + else: + print("flash_attn_interface compatibility layer not found") + + # Check flash-attention module file locations + flash_attn_path = Path(flash_attn.__file__).parent + print(f"\nflash-attn installation path: {flash_attn_path}") + + interface_files = list(flash_attn_path.glob("*interface*")) + if interface_files: + print("Interface files found:") + for f in interface_files: + print(f" {f}") + else: + print("No interface files found in flash-attn directory") + + except ImportError as e: + print(f"Error importing flash-attn modules: {e}") + else: + print("flash-attn is not installed") + + # Check for common issues + print_header("CHECKING FOR COMMON ISSUES") + + # Check Castor component.py import + castor_component = "apps/Castor/modules/component.py" + if os.path.exists(castor_component): + print(f"\nChecking {castor_component} for import issues:") + success, output = run_cmd(f"grep -n 'flash_attn' {castor_component} | head -5") + if success: + print(output) + else: + print(f"Error checking {castor_component}: {output}") + else: + print(f"\n{castor_component} not found") + + # Run pip check to look for dependency issues + print("\nChecking for dependency issues:") + success, output = run_cmd("pip check") + if success: + print("No dependency 
issues found") + else: + print("Dependency issues found:") + print(output) + + # Provide recommendations + print_header("RECOMMENDATIONS") + print(""" +1. If flash-attn is not installed: + - Run: pip install flash-attn --no-build-isolation + +2. If the compatibility layer is missing: + - Create a file at site-packages/flash_attn_interface.py with: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + +3. If Castor's component.py has the wrong import: + - Update it to use: from flash_attn.flash_attn_interface import flash_attn_varlen_func + - Or create the compatibility layer as in step 2 + +4. If there are CUDA errors: + - Check CUDA version compatibility with installed PyTorch + - Check if CUDA_HOME is set correctly + +For more detailed diagnosis, check the flash-attention repository at: +https://github.com/Dao-AILab/flash-attention + """) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/setup_shared_env.sh b/setup_shared_env.sh new file mode 100755 index 0000000..99092dc --- /dev/null +++ b/setup_shared_env.sh @@ -0,0 +1,165 @@ +#!/bin/bash +set -e + +ROOT_DIR=/mnt/pollux + +# Create shared directories +mkdir -p $ROOT_DIR/environments +mkdir -p $ROOT_DIR/compiled_packages + +# Setup conda environment in shared location +if [ ! -d "$ROOT_DIR/environments/pollux_env" ]; then + echo "Creating conda environment in shared location..." 
+ conda create -y -p $ROOT_DIR/environments/pollux_env python=3.12.9 + + # Activate the environment and install PyTorch and other dependencies + source $(conda info --base)/etc/profile.d/conda.sh + conda activate $ROOT_DIR/environments/pollux_env + + # Install PyTorch with CUDA 12.8 + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 + + # Install xformers and other basic dependencies + pip install xformers ninja packaging + pip install --requirement requirements.txt + + # Install CLIP + pip install git+https://github.com/openai/CLIP.git + + # Install optional dependencies + pip install timm torchmetrics +else + echo "Conda environment already exists at $ROOT_DIR/environments/pollux_env" + source $(conda info --base)/etc/profile.d/conda.sh + conda activate $ROOT_DIR/environments/pollux_env +fi + +# Function to create the compatibility layer +create_compatibility_layer() { + echo "Creating flash_attn_interface compatibility module..." + cat > $ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/flash_attn_interface.py << EOF +# Compatibility layer for flash_attn_interface imports +from flash_attn.flash_attn_interface import flash_attn_varlen_func +# Export other functions that might be needed +from flash_attn.flash_attn_interface import _flash_attn_varlen_forward, _flash_attn_varlen_backward +# Export any other necessary functions +EOF +} + +# Build Flash Attention v3 if not already built +if [ ! -f "$ROOT_DIR/compiled_packages/flash-attention/hopper/build/lib.linux-x86_64-cpython-312/flash_attn_3.so" ]; then + echo "Compiling Flash Attention v3..." 
+ cd $ROOT_DIR/compiled_packages + + # Configure DNS servers as a fallback in case of DNS issues + export DNS_BACKUP="8.8.8.8 8.8.4.4 1.1.1.1" + echo "Setting up DNS fallback to: $DNS_BACKUP" + cat > /tmp/resolv.conf.new << EOF +nameserver 8.8.8.8 +nameserver 8.8.4.4 +nameserver 1.1.1.1 +EOF + export HOSTALIASES=/tmp/resolv.conf.new + + # First try direct pip install as fallback (might be easier than building from source) + echo "Trying direct pip install of flash-attention..." + if pip install flash-attn --no-build-isolation; then + echo "Successfully installed flash-attention via pip" + create_compatibility_layer + echo "Flash Attention installed via pip" + else + echo "Pip install failed, attempting to build from source..." + + # Retry git clone with exponential backoff + MAX_RETRIES=5 + retry_count=0 + clone_success=false + + while [ $retry_count -lt $MAX_RETRIES ] && [ "$clone_success" = false ]; do + if [ ! -d "flash-attention" ]; then + echo "Attempt $(($retry_count + 1))/$MAX_RETRIES: Cloning flash-attention repository..." + if git clone https://github.com/Dao-AILab/flash-attention.git; then + clone_success=true + cd flash-attention + git checkout v2.7.4.post1 + else + retry_count=$((retry_count + 1)) + if [ $retry_count -lt $MAX_RETRIES ]; then + sleep_time=$((2 ** retry_count)) + echo "Clone failed. Retrying in $sleep_time seconds..." + sleep $sleep_time + else + echo "Failed to clone after $MAX_RETRIES attempts." + exit 1 + fi + fi + else + clone_success=true + cd flash-attention + git checkout v2.7.4.post1 + fi + done + + # Try a different approach if the main build method fails + cd hopper/ + echo "Building Flash Attention with MAX_JOBS=24..." + if ! MAX_JOBS=24 python setup.py build; then + echo "Default build method failed, trying alternative approach..." + # Try alternative build approach + if ! TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" pip install -e ..; then + echo "Alternative build method failed, trying pip install again with force..." 
+ cd ../../ + pip install flash-attn --force-reinstall + fi + fi + + # Copy built libraries to the environment if available + if [ -d "build/lib.linux-x86_64-cpython-312" ]; then + echo "Copying built Flash Attention libraries to Python environment..." + cp -r build/lib.linux-x86_64-cpython-312/* $ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/ + fi + + create_compatibility_layer + echo "Flash Attention v3 has been compiled and installed to the shared environment" + fi +else + echo "Flash Attention v3 is already compiled" + + # Always ensure the compatibility layer exists + if [ ! -f "$ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/flash_attn_interface.py" ]; then + create_compatibility_layer + fi +fi + +# Install COSMOS Tokenizer VAE +if [ ! -d "$ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/cosmos_tokenizer" ]; then + echo "Installing COSMOS Tokenizer VAE..." + cd $(dirname "$0") # Go to the directory where this script is located + + # Check if Cosmos-Tokenizer directory exists + if [ -d "apps/Cosmos-Tokenizer" ]; then + cd apps/Cosmos-Tokenizer + pip install -e . + cd ../.. + echo "COSMOS Tokenizer VAE installed." + else + echo "Warning: apps/Cosmos-Tokenizer directory not found, skipping installation." + echo "If you need COSMOS Tokenizer, please make sure the repository is properly cloned." + # Optional: Clone the repository if it doesn't exist + # git clone https://github.com/your-org/Cosmos-Tokenizer.git apps/Cosmos-Tokenizer + fi +fi + +# Final verification of Flash Attention installation +python -c " +try: + import flash_attn + print(f'Flash Attention {flash_attn.__version__} successfully installed') + import flash_attn_interface + print('Flash Attention Interface compatibility layer is working') +except ImportError as e: + print(f'Error importing Flash Attention: {e}') + exit(1) +" + +echo "Environment setup complete. 
Activate with: conda activate $ROOT_DIR/environments/pollux_env" diff --git a/submit_pretrain.sh b/submit_pretrain.sh new file mode 100644 index 0000000..e6c4b49 --- /dev/null +++ b/submit_pretrain.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Helper script to submit and monitor SLURM jobs + +# Make sure scripts are executable +chmod +x setup_shared_env.sh +chmod +x train_castor.slurm +chmod +x diagnose_flash_attn.py + +# Check if diagnostic mode is requested +if [ "$1" == "diagnose" ]; then + echo "Running flash-attention diagnostic tool..." + source $(conda info --base)/etc/profile.d/conda.sh + conda activate /mnt/pollux/environments/pollux_env + python diagnose_flash_attn.py + exit 0 +fi + +# Check if a custom partition is provided +if [ -n "$1" ] && [ "$1" != "diagnose" ]; then + PARTITION="$1" + echo "Using custom partition: $PARTITION" + # Submit the job with the specified partition + JOB_ID=$(sbatch --parsable --partition="$PARTITION" train_castor.slurm) +else + # Submit the job with the default partition in the script + JOB_ID=$(sbatch --parsable train_castor.slurm) +fi + +echo "Job submitted with ID: $JOB_ID" +echo "Monitor with: squeue -j $JOB_ID" +echo "View logs with: tail -f castor_train_${JOB_ID}.out" +echo "" +echo "Quick commands:" +echo "---------------" +echo "View job status: scontrol show job $JOB_ID" +echo "Cancel job: scancel $JOB_ID" +echo "View resource usage: sstat --format=AveCPU,AveRSS,AveVMSize,MaxRSS,MaxVMSize -j $JOB_ID" +echo "Run diagnostic tool: ./submit_pretrain.sh diagnose" + +# Watch the job status +echo "" +echo "Watching job status (press Ctrl+C to exit):" +watch -n 10 squeue -j "$JOB_ID" \ No newline at end of file diff --git a/train_castor.slurm b/train_castor.slurm new file mode 100755 index 0000000..4164142 --- /dev/null +++ b/train_castor.slurm @@ -0,0 +1,130 @@ +#!/bin/bash +#SBATCH --job-name=castor_train # Job name +#SBATCH --nodes=1 # Number of nodes +#SBATCH --ntasks-per-node=8 # Number of tasks per node (1 per GPU) +#SBATCH --gpus-per-node=8 # 
Request 8 GPUs per node (H100s) +#SBATCH --cpus-per-task=8 # CPU cores per task +#SBATCH --mem=0 # Request all memory on the node +#SBATCH --time=72:00:00 # Time limit (72 hours) +#SBATCH --output=%x_%j.out # Standard output log +#SBATCH --error=%x_%j.err # Standard error log +#SBATCH --partition=debug # Partition with H100 GPUs (adjust if needed) + +# Set up environment variables for CUDA +export CUDA_HOME=/usr/local/cuda-12.8 +export PATH=$CUDA_HOME/bin:$PATH +export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH +export ROOT_DIR=/mnt/pollux + +# Configure DNS servers as a fallback in case of DNS issues +export DNS_BACKUP="8.8.8.8 8.8.4.4 1.1.1.1" +echo "Setting up DNS fallback to: $DNS_BACKUP" +cat > /tmp/resolv.conf.new << EOF +nameserver 8.8.8.8 +nameserver 8.8.4.4 +nameserver 1.1.1.1 +EOF +export HOSTALIASES=/tmp/resolv.conf.new + +# Check if the shared environment exists and set it up if not +if [ "$SLURM_PROCID" == "0" ]; then + echo "Checking if shared environment setup is needed..." + bash setup_shared_env.sh + + # Verify flash_attn_interface.py exists + if [ ! -f "$ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/flash_attn_interface.py" ]; then + echo "Creating flash_attn_interface compatibility module..." + cat > $ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/flash_attn_interface.py << EOF +# Compatibility layer for flash_attn_interface imports +from flash_attn.flash_attn_interface import flash_attn_varlen_func +# Export other functions that might be needed +from flash_attn.flash_attn_interface import _flash_attn_varlen_forward, _flash_attn_varlen_backward +# Export any other necessary functions +EOF + fi + + # Patch the Castor component.py file to fix flash_attn import if needed + COMPONENT_FILE="apps/Castor/modules/component.py" + if [ -f "$COMPONENT_FILE" ]; then + echo "Checking if Castor component.py needs patching..." 
+ if grep -q "from flash_attn_interface import" "$COMPONENT_FILE"; then + echo "Patching $COMPONENT_FILE to use correct flash_attn import..." + # Create backup + cp "$COMPONENT_FILE" "${COMPONENT_FILE}.bak" + # Replace direct import with proper module path + sed -i 's/from flash_attn_interface import flash_attn_varlen_func/from flash_attn.flash_attn_interface import flash_attn_varlen_func/' "$COMPONENT_FILE" + echo "Patched $COMPONENT_FILE" + else + echo "$COMPONENT_FILE doesn't need patching" + fi + fi +fi + +# Wait for rank 0 to finish setting up the environment +sleep 5 +srun --ntasks=1 --nodes=1 --ntasks-per-node=1 --wait=0 true + +# Load the shared conda environment +source $(conda info --base)/etc/profile.d/conda.sh +conda activate $ROOT_DIR/environments/pollux_env + +# Verify critical packages are installed +echo "Verifying critical packages..." +python -c " +try: + import torch + print(f'PyTorch version: {torch.__version__}') + import flash_attn + print(f'Flash Attention version: {flash_attn.__version__}') + + # Attempt to import from both paths to test compatibility + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + print('Original flash_attn.flash_attn_interface module is available') + except ImportError: + print('Warning: Could not import from flash_attn.flash_attn_interface') + + try: + import flash_attn_interface + print('Compatibility layer flash_attn_interface is available') + except ImportError: + print('Warning: Could not import flash_attn_interface') + +except ImportError as e: + print(f'Error: {e}') + print('Warning: Some required packages are missing but continuing anyway') +" + +# Setup distributed training environment variables +export PYTHONPATH=$PYTHONPATH:$(pwd) +export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE)) +export RANK=$SLURM_PROCID +export 
LOCAL_RANK=$SLURM_LOCALID + +# Print environment information +echo "MASTER_ADDR: $MASTER_ADDR" +echo "MASTER_PORT: $MASTER_PORT" +echo "WORLD_SIZE: $WORLD_SIZE" +echo "SLURM_PROCID: $SLURM_PROCID" +echo "SLURM_LOCALID: $SLURM_LOCALID" +echo "SLURM_NODEID: $SLURM_NODEID" + +# Set PyTorch options for better performance +export TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" +export TORCH_DISTRIBUTED_DEBUG=DETAIL +export TORCH_EXTENSIONS_DIR=$ROOT_DIR/torch_extensions # Share compiled extensions + +# NCCL settings for optimal performance +export NCCL_DEBUG=INFO +export NCCL_IB_DISABLE=0 +export NCCL_IB_HCA=mlx5 +export NCCL_NET_GDR_LEVEL=2 +export NCCL_SOCKET_IFNAME=^lo,docker0 + +# Set PyTorch to use high precision +export TORCH_FLOAT32_MATMUL_PRECISION=high + +# Launch training using srun for optimal resource allocation +srun python -m apps.Castor.train config=apps/Castor/configs/train_bucket_256_Castor_flux_qwen_fixed_siglip2.yaml