Commit fd53cde

Update Python version to 3.12, Update WAN to match new Flax version (#244)
* match new flax version
* flax version>=0.11.0 requires python>=3.11
* update default install python version to 3.12
* update readme
1 parent c3bf323 commit fd53cde

7 files changed: +63 -13 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ MaxDiffusion supports

 We recommend starting with a single TPU host and then moving to multihost.

-Minimum requirements: Ubuntu Version 22.04, Python 3.10 and Tensorflow >= 2.12.0.
+Minimum requirements: Ubuntu Version 22.04, Python 3.12 and Tensorflow >= 2.12.0.

 ## Getting Started:

docs/getting_started/first_run.md

Lines changed: 11 additions & 3 deletions
@@ -10,19 +10,27 @@ multiple hosts.

 1. [Create and SSH to a single-host TPU (v6-8). ](https://cloud.google.com/tpu/docs/users-guide-tpu-vm#creating_a_cloud_tpu_vm_with_gcloud)
 * You can find here [here](https://cloud.google.com/tpu/docs/regions-zones) the list of zones that support the v6(Trillium) TPUs
-* We recommend using the base VM image "v2-alpha-tpuv6e", which meets the version requirements: Ubuntu Version 22.04, Python 3.10 and Tensorflow >= 2.12.0
+* We recommend using the base VM image "v2-alpha-tpuv6e", which meets the version requirements: Ubuntu Version 22.04, Python 3.12 and Tensorflow >= 2.12.0

 1. Clone MaxDiffusion in your TPU VM.
-```bash
+```
 git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
 cd maxdiffusion
 ```

 1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running:
-```bash
+```
+# If a Python 3.12+ virtual environment doesn't already exist, you'll need to run the install command twice.
 bash setup.sh MODE=stable DEVICE=tpu
 ```

+1. Active your virtual environment:
+```
+# Replace with your virtual environment name if not using this default name
+venv_name="maxdiffusion_venv"
+source ~/$venv_name/bin/activate
+```
+
 ## Getting Starting: Multihost development

 [GKE, recommended] [Running MaxDiffusion with xpk](run_maxdiffusion_via_xpk.md) - Quick Experimentation and Production support

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ grain
 google-cloud-storage>=2.17.0
 absl-py
 datasets
-flax>=0.10.2
+flax>=0.11.0
 optax>=0.2.3
 torch>=2.6.0
 torchvision>=0.20.1

setup.sh

Lines changed: 40 additions & 0 deletions
@@ -23,6 +23,46 @@
 set -e
 export DEBIAN_FRONTEND=noninteractive

+echo "Checking Python version..."
+# This command will fail if the Python version is less than 3.12
+if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; then
+  # If the command fails, print an error
+  CURRENT_VERSION=$(python3 --version 2>&1) # Get the full version string
+  echo -e "\n\e[31mERROR: Outdated Python Version! You are currently using $CURRENT_VERSION, but MaxDiffusion requires Python version 3.12 or higher.\e[0m"
+  # Ask the user if they want to create a virtual environment with uv
+  read -p "Would you like to create a Python 3.12 virtual environment using uv? (y/n) " -n 1 -r
+  echo # Move to a new line after input
+  if [[ $REPLY =~ ^[Yy]$ ]]; then
+    # Check if uv is installed first; if not, install uv
+    if ! command -v uv &> /dev/null; then
+      pip install uv
+    fi
+    maxdiffusion_dir=$(pwd)
+    cd
+    # Ask for the venv name
+    read -p "Please enter a name for your new virtual environment (default: maxdiffusion_venv): " venv_name
+    # Use a default name if the user provides no input
+    if [ -z "$venv_name" ]; then
+      venv_name="maxdiffusion_venv"
+      echo "No name provided. Using default name: '$venv_name'"
+    fi
+    echo "Creating virtual environment '$venv_name' with Python 3.12..."
+    uv venv --python 3.12 "$venv_name" --seed
+    printf '%s\n' "$(realpath -- "$venv_name")" >> /tmp/venv_created
+    echo -e "\n\e[32mVirtual environment '$venv_name' created successfully!\e[0m"
+    echo "To activate it, run the following command:"
+    echo -e "\e[33m source ~/$venv_name/bin/activate\e[0m"
+    echo "After activating the environment, please re-run this script."
+    cd $maxdiffusion_dir
+  else
+    echo "Exiting. Please upgrade your Python environment to continue."
+  fi
+  # Exit the script since the initial Python check failed
+  exit 1
+fi
+echo "Python version check passed. Continuing with script."
+echo "--------------------------------------------------"
+
 (sudo bash || bash) <<'EOF'
 mkdir -p /etc/needrestart/conf.d
 echo '$nrconf{restart} = "a";' > /etc/needrestart/conf.d/99-noninteractive.conf
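
The version gate that `setup.sh` runs through `python3 -c` can also be expressed as a small standalone function, which makes the comparison easier to exercise in isolation. This is a minimal sketch only: `MIN_PYTHON` and `python_is_supported` are hypothetical names chosen for illustration and are not part of MaxDiffusion.

```python
import sys

# Assumed minimum, taken from the check added to setup.sh in this commit.
MIN_PYTHON = (3, 12)


def python_is_supported(version_info=None, minimum=MIN_PYTHON):
  """Return True when the (major, minor) version is at least `minimum`.

  `version_info` defaults to the running interpreter; passing a tuple
  makes the comparison easy to test without switching interpreters.
  """
  current = sys.version_info if version_info is None else version_info
  # Compare only (major, minor), like the assert in setup.sh.
  return tuple(current)[:2] >= tuple(minimum)


if __name__ == "__main__":
  if not python_is_supported():
    raise SystemExit("MaxDiffusion requires Python 3.12 or higher.")
```

Tuple comparison is lexicographic, so `(3, 10)` compares as lower than `(3, 12)` without any string parsing.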

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
@@ -795,8 +795,8 @@ def __init__(

 self.drop_out = nnx.Dropout(dropout)

-self.norm_q = None
-self.norm_k = None
+self.norm_q = nnx.data(None)
+self.norm_k = nnx.data(None)
 if qk_norm is not None:
 self.norm_q = nnx.RMSNorm(
 num_features=self.inner_dim,

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 7 additions & 5 deletions
@@ -225,7 +225,7 @@ def __init__(
 ):
 self.dim = dim
 self.mode = mode
-self.time_conv = None
+self.time_conv = nnx.data(None)

 if mode == "upsample2d":
 self.resample = nnx.Sequential(
@@ -554,8 +554,8 @@ def __init__(
 precision=precision,
 )
 )
-self.attentions = attentions
-self.resnets = resnets
+self.attentions = nnx.data(attentions)
+self.resnets = nnx.data(resnets)

 def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
 x = self.resnets[0](x, feat_cache, feat_idx)
@@ -601,10 +601,10 @@ def __init__(
 )
 )
 current_dim = out_dim
-self.resnets = resnets
+self.resnets = nnx.data(resnets)

 # Add upsampling layer if needed.
-self.upsamplers = None
+self.upsamplers = nnx.data(None)
 if upsample_mode is not None:
 self.upsamplers = [
 WanResample(
@@ -710,6 +710,7 @@ def __init__(
 )
 )
 scale /= 2.0
+self.down_blocks = nnx.data(self.down_blocks)

 # middle_blocks
 self.mid_block = WanMidBlock(
@@ -873,6 +874,7 @@ def __init__(
 # Update scale for next iteration
 if upsample_mode is not None:
 scale *= 2.0
+self.up_blocks = nnx.data(self.up_blocks)

 # output blocks
 self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ def __init__(
 inner_dim = int(dim * mult)
 dim_out = dim_out if dim_out is not None else dim

-self.act_fn = None
+self.act_fn = nnx.data(None)
 if activation_fn == "gelu-approximate":
 self.act_fn = ApproximateGELU(
 rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision
