21 changes: 19 additions & 2 deletions apps/Castor/modules/component.py
@@ -16,8 +16,25 @@
from liger_kernel.transformers import LigerSwiGLUMLP, LigerRMSNorm, liger_rotary_pos_emb
from types import SimpleNamespace

# fa3
from flash_attn_interface import flash_attn_varlen_func
# fa3 - try multiple import paths for better compatibility
try:
    # First try the standard path provided by the flash-attn package
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    try:
        # Fall back to the flash_attn_interface compatibility layer (the original import path)
        from flash_attn_interface import flash_attn_varlen_func
    except ImportError:
        import warnings

        # Create a dummy placeholder that raises a more descriptive error when used
        def flash_attn_varlen_func(*args, **kwargs):
            raise RuntimeError(
                "flash_attn_varlen_func could not be imported. "
                "Please install flash-attention properly or check the import paths."
            )

        warnings.warn(
            "flash_attn_varlen_func could not be imported from either "
            "flash_attn.flash_attn_interface or flash_attn_interface. Using a dummy placeholder."
        )


flex_attention_comp = torch.compile(flex_attention)
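To see which of these import paths actually resolves in a given environment, a quick standalone check can be run from the shell. This is a sketch, not part of the change itself:

python - << 'EOF'
import importlib.util
# Report which flash-attn interface modules are importable in the active environment
for name in ("flash_attn.flash_attn_interface", "flash_attn_interface"):
    try:
        found = importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        found = False
    print(f"{name}: {'available' if found else 'missing'}")
EOF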
144 changes: 144 additions & 0 deletions diagnose_flash_attn.py
@@ -0,0 +1,144 @@
#!/usr/bin/env python
"""
Diagnostic script for flash-attention issues
This script helps identify issues with flash-attention installation and imports.
"""

import os
import sys
import importlib.util
import subprocess
from pathlib import Path

def print_header(title):
    print("\n" + "=" * 80)
    print(f" {title} ".center(80, "="))
    print("=" * 80)

def check_module_exists(module_name):
    return importlib.util.find_spec(module_name) is not None

def run_cmd(cmd):
    print(f"Running: {cmd}")
    try:
        output = subprocess.check_output(cmd, shell=True, universal_newlines=True, stderr=subprocess.STDOUT)
        return True, output
    except subprocess.CalledProcessError as e:
        return False, e.output

def main():
    print_header("FLASH ATTENTION DIAGNOSTIC TOOL")

    # Check Python version
    print(f"Python version: {sys.version}")

    # Check CUDA availability
    print("\nCUDA Environment:")
    cuda_home = os.environ.get("CUDA_HOME", "Not set")
    print(f"CUDA_HOME: {cuda_home}")

    # Check for PyTorch installation
    print("\nChecking PyTorch installation:")
    if check_module_exists("torch"):
        import torch
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA version: {torch.version.cuda}")
            print(f"Number of GPUs: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f"  Device {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("PyTorch is not installed")

    # Check flash attention installation
    print("\nChecking flash-attention installation:")
    if check_module_exists("flash_attn"):
        try:
            import flash_attn
            print(f"flash-attn version: {flash_attn.__version__}")

            # Check if the interface module exists
            print("\nChecking flash_attn.flash_attn_interface:")
            if check_module_exists("flash_attn.flash_attn_interface"):
                from flash_attn.flash_attn_interface import flash_attn_varlen_func
                print("flash_attn.flash_attn_interface.flash_attn_varlen_func is available")
            else:
                print("flash_attn.flash_attn_interface module not found")

            # Check compatibility layer
            print("\nChecking flash_attn_interface compatibility layer:")
            if check_module_exists("flash_attn_interface"):
                import flash_attn_interface
                if hasattr(flash_attn_interface, "flash_attn_varlen_func"):
                    print("flash_attn_interface.flash_attn_varlen_func is available")
                else:
                    print("flash_attn_interface module exists but flash_attn_varlen_func is not available")
            else:
                print("flash_attn_interface compatibility layer not found")

            # Check flash-attention module file locations
            flash_attn_path = Path(flash_attn.__file__).parent
            print(f"\nflash-attn installation path: {flash_attn_path}")

            interface_files = list(flash_attn_path.glob("*interface*"))
            if interface_files:
                print("Interface files found:")
                for f in interface_files:
                    print(f"  {f}")
            else:
                print("No interface files found in flash-attn directory")

        except ImportError as e:
            print(f"Error importing flash-attn modules: {e}")
    else:
        print("flash-attn is not installed")

    # Check for common issues
    print_header("CHECKING FOR COMMON ISSUES")

    # Check Castor component.py import
    castor_component = "apps/Castor/modules/component.py"
    if os.path.exists(castor_component):
        print(f"\nChecking {castor_component} for import issues:")
        success, output = run_cmd(f"grep -n 'flash_attn' {castor_component} | head -5")
        if success:
            print(output)
        else:
            print(f"Error checking {castor_component}: {output}")
    else:
        print(f"\n{castor_component} not found")

    # Run pip check to look for dependency issues
    print("\nChecking for dependency issues:")
    success, output = run_cmd("pip check")
    if success:
        print("No dependency issues found")
    else:
        print("Dependency issues found:")
        print(output)

    # Provide recommendations
    print_header("RECOMMENDATIONS")
    print("""
1. If flash-attn is not installed:
   - Run: pip install flash-attn --no-build-isolation

2. If the compatibility layer is missing:
   - Create a file at site-packages/flash_attn_interface.py with:
     from flash_attn.flash_attn_interface import flash_attn_varlen_func

3. If Castor's component.py has the wrong import:
   - Update it to use: from flash_attn.flash_attn_interface import flash_attn_varlen_func
   - Or create the compatibility layer as in step 2

4. If there are CUDA errors:
   - Check CUDA version compatibility with installed PyTorch
   - Check if CUDA_HOME is set correctly

For more detailed diagnosis, check the flash-attention repository at:
https://github.com/Dao-AILab/flash-attention
""")

if __name__ == "__main__":
    main()
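To run the diagnostic against the shared environment, the commands below mirror the diagnose branch of submit_pretrain.sh (the conda path assumes the layout created by setup_shared_env.sh):

source $(conda info --base)/etc/profile.d/conda.sh
conda activate /mnt/pollux/environments/pollux_env
python diagnose_flash_attn.py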
165 changes: 165 additions & 0 deletions setup_shared_env.sh
@@ -0,0 +1,165 @@
#!/bin/bash
set -e

ROOT_DIR=/mnt/pollux

# Create shared directories
mkdir -p $ROOT_DIR/environments
mkdir -p $ROOT_DIR/compiled_packages

# Setup conda environment in shared location
if [ ! -d "$ROOT_DIR/environments/pollux_env" ]; then
echo "Creating conda environment in shared location..."
conda create -y -p $ROOT_DIR/environments/pollux_env python=3.12.9

# Activate the environment and install PyTorch and other dependencies
source $(conda info --base)/etc/profile.d/conda.sh
conda activate $ROOT_DIR/environments/pollux_env

# Install PyTorch with CUDA 12.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

# Install xformers and other basic dependencies
pip install xformers ninja packaging
pip install --requirement requirements.txt

# Install CLIP
pip install git+https://github.com/openai/CLIP.git

# Install optional dependencies
pip install timm torchmetrics
else
echo "Conda environment already exists at $ROOT_DIR/environments/pollux_env"
source $(conda info --base)/etc/profile.d/conda.sh
conda activate $ROOT_DIR/environments/pollux_env
fi

# Function to create the compatibility layer
create_compatibility_layer() {
echo "Creating flash_attn_interface compatibility module..."
cat > $ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/flash_attn_interface.py << EOF
# Compatibility layer for flash_attn_interface imports
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Export other functions that might be needed
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward, _flash_attn_varlen_backward
# Export any other necessary functions
EOF
}
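# Optional sanity check once the shim exists (run manually; not invoked by this script):
#   python -c "from flash_attn_interface import flash_attn_varlen_func; print('compatibility layer OK')"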

# Build Flash Attention v3 if not already built
if [ ! -f "$ROOT_DIR/compiled_packages/flash-attention/hopper/build/lib.linux-x86_64-cpython-312/flash_attn_3.so" ]; then
echo "Compiling Flash Attention v3..."
cd $ROOT_DIR/compiled_packages

# Stage fallback DNS servers in case of DNS resolution issues while cloning
export DNS_BACKUP="8.8.8.8 8.8.4.4 1.1.1.1"
echo "Staging DNS fallback servers: $DNS_BACKUP"
cat > /tmp/resolv.conf.new << EOF
nameserver 8.8.8.8
nameserver 8.8.4.4
nameserver 1.1.1.1
EOF
# Note: HOSTALIASES cannot override nameservers; applying this fallback requires root,
# e.g. cp /tmp/resolv.conf.new /etc/resolv.conf

# First try direct pip install as fallback (might be easier than building from source)
echo "Trying direct pip install of flash-attention..."
if pip install flash-attn --no-build-isolation; then
echo "Successfully installed flash-attention via pip"
create_compatibility_layer
echo "Flash Attention installed via pip"
else
echo "Pip install failed, attempting to build from source..."

# Retry git clone with exponential backoff
MAX_RETRIES=5
retry_count=0
clone_success=false

while [ $retry_count -lt $MAX_RETRIES ] && [ "$clone_success" = false ]; do
if [ ! -d "flash-attention" ]; then
echo "Attempt $(($retry_count + 1))/$MAX_RETRIES: Cloning flash-attention repository..."
if git clone https://github.com/Dao-AILab/flash-attention.git; then
clone_success=true
cd flash-attention
git checkout v2.7.4.post1
else
retry_count=$((retry_count + 1))
if [ $retry_count -lt $MAX_RETRIES ]; then
sleep_time=$((2 ** retry_count))
echo "Clone failed. Retrying in $sleep_time seconds..."
sleep $sleep_time
else
echo "Failed to clone after $MAX_RETRIES attempts."
exit 1
fi
fi
else
clone_success=true
cd flash-attention
git checkout v2.7.4.post1
fi
done

# Try a different approach if the main build method fails
cd hopper/
echo "Building Flash Attention with MAX_JOBS=24..."
if ! MAX_JOBS=24 python setup.py build; then
echo "Default build method failed, trying alternative approach..."
# Try alternative build approach
if ! TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" pip install -e ..; then
echo "Alternative build method failed, trying pip install again with force..."
cd ../../
pip install flash-attn --force-reinstall
fi
fi

# Copy built libraries to the environment if available
if [ -d "build/lib.linux-x86_64-cpython-312" ]; then
echo "Copying built Flash Attention libraries to Python environment..."
cp -r build/lib.linux-x86_64-cpython-312/* $ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/
fi

create_compatibility_layer
echo "Flash Attention v3 has been compiled and installed to the shared environment"
fi
else
echo "Flash Attention v3 is already compiled"

# Always ensure the compatibility layer exists
if [ ! -f "$ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/flash_attn_interface.py" ]; then
create_compatibility_layer
fi
fi

# Install COSMOS Tokenizer VAE
if [ ! -d "$ROOT_DIR/environments/pollux_env/lib/python3.12/site-packages/cosmos_tokenizer" ]; then
echo "Installing COSMOS Tokenizer VAE..."
cd $(dirname "$0") # Go to the directory where this script is located

# Check if Cosmos-Tokenizer directory exists
if [ -d "apps/Cosmos-Tokenizer" ]; then
cd apps/Cosmos-Tokenizer
pip install -e .
cd ../..
echo "COSMOS Tokenizer VAE installed."
else
echo "Warning: apps/Cosmos-Tokenizer directory not found, skipping installation."
echo "If you need COSMOS Tokenizer, please make sure the repository is properly cloned."
# Optional: Clone the repository if it doesn't exist
# git clone https://github.com/your-org/Cosmos-Tokenizer.git apps/Cosmos-Tokenizer
fi
fi
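# Optional sanity check (module name inferred from the site-packages path checked above):
#   python -c "import cosmos_tokenizer; print('cosmos_tokenizer import OK')"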

# Final verification of Flash Attention installation
python -c "
try:
import flash_attn
print(f'Flash Attention {flash_attn.__version__} successfully installed')
import flash_attn_interface
print('Flash Attention Interface compatibility layer is working')
except ImportError as e:
print(f'Error importing Flash Attention: {e}')
exit(1)
"

echo "Environment setup complete. Activate with: conda activate $ROOT_DIR/environments/pollux_env"
43 changes: 43 additions & 0 deletions submit_pretrain.sh
@@ -0,0 +1,43 @@
#!/bin/bash
# Helper script to submit and monitor SLURM jobs

# Make sure scripts are executable
chmod +x setup_shared_env.sh
chmod +x train_castor.slurm
chmod +x diagnose_flash_attn.py

# Check if diagnostic mode is requested
if [ "$1" == "diagnose" ]; then
echo "Running flash-attention diagnostic tool..."
source $(conda info --base)/etc/profile.d/conda.sh
conda activate /mnt/pollux/environments/pollux_env
python diagnose_flash_attn.py
exit 0
fi

# Check if a custom partition is provided
if [ -n "$1" ] && [ "$1" != "diagnose" ]; then
PARTITION="$1"
echo "Using custom partition: $PARTITION"
# Submit the job with the specified partition
JOB_ID=$(sbatch --parsable --partition="$PARTITION" train_castor.slurm)
else
# Submit the job with the default partition in the script
JOB_ID=$(sbatch --parsable train_castor.slurm)
fi

echo "Job submitted with ID: $JOB_ID"
echo "Monitor with: squeue -j $JOB_ID"
echo "View logs with: tail -f ${JOB_ID}.out"
echo ""
echo "Quick commands:"
echo "---------------"
echo "View job status: scontrol show job $JOB_ID"
echo "Cancel job: scancel $JOB_ID"
echo "View resource usage: sstat --format=AveCPU,AveRSS,AveVMSize,MaxRSS,MaxVMSize -j $JOB_ID"
echo "Run diagnostic tool: ./submit_job.sh diagnose"

# Watch the job status
echo ""
echo "Watching job status (press Ctrl+C to exit):"
watch -n 10 squeue -j "$JOB_ID"
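Typical invocations (the partition name below is only an example):

./submit_pretrain.sh                 # submit train_castor.slurm with its default partition
./submit_pretrain.sh <partition>     # submit to a specific SLURM partition
./submit_pretrain.sh diagnose        # run the flash-attention diagnostic instead of submitting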