Commit ff35c30

Remove noop spmd_mode check, correct type annotations in attention flax
Signed-off-by: Kunjan patel <kunjanp@google.com>
1 parent 7284ca0 commit ff35c30

File tree: 2 files changed (+42, -43 lines)

  src/maxdiffusion/models/attention_flax.py
  src/maxdiffusion/train_utils.py

src/maxdiffusion/models/attention_flax.py

Lines changed: 21 additions & 21 deletions

@@ -52,17 +52,17 @@ def _maybe_aqt_einsum(quant: Quant):
 class AttentionOp(nn.Module):
   mesh: Mesh
   attention_kernel: str
-  scale: int
+  scale: float
   heads: int
   dim_head: int
   use_memory_efficient_attention: bool = False
   split_head_dim: bool = False
   float32_qk_product: bool = True
   flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
+  flash_block_sizes: BlockSizes | None = None
   dtype: DType = jnp.float32
-  quant: Quant = None
+  quant: Quant | None = None

   def setup(self):
     if self.attention_kernel == "cudnn_flash_te":
@@ -79,7 +79,7 @@ def setup(self):
           dtype=self.dtype,
           # float32_logits=self.float32_logits,
           qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-          scale_factor=self.scale,
+          scale_factor=float(self.scale),
           transpose_batch_sequence=False,
       )

@@ -415,15 +415,15 @@ class FlaxFluxAttention(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
   query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
   qkv_bias: bool = False

   def setup(self):
@@ -619,16 +619,16 @@ class FlaxAttention(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
   query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
-  precision: jax.lax.Precision = None
-  quant: Quant = None
+  precision: jax.lax.Precision | None = None
+  quant: Quant | None = None

   def setup(self):

@@ -762,10 +762,10 @@ class FlaxBasicTransformerBlock(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
-  precision: jax.lax.Precision = None
-  quant: Quant = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
+  precision: jax.lax.Precision | None = None
+  quant: Quant | None = None

   def setup(self):
     # self attention (or cross_attention if only_cross_attention is True)
@@ -890,12 +890,12 @@ class FlaxTransformer2DModel(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   norm_num_groups: int = 32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
   hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV)
-  quant: Quant = (None,)
+  quant: Quant | tuple[None] = (None,)

   def setup(self):
     self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)
@@ -1019,7 +1019,7 @@ class FlaxFeedForward(nn.Module):
   dropout: float = 0.0
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None

   def setup(self):
     # The second linear layer needs to be called
@@ -1051,7 +1051,7 @@ class FlaxGEGLU(nn.Module):
   dropout: float = 0.0
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None

   def setup(self):
     inner_dim = self.dim * 4
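The annotation changes above all follow the same pattern: a field whose default is None is typed as `T | None` rather than bare `T`, and `scale` becomes a `float`. A minimal, self-contained sketch of that pattern (illustrative only; `TinyAttention` and the simplified `BlockSizes` below are not code from this repo):

from dataclasses import dataclass

import jax
import jax.numpy as jnp
from flax import linen as nn


@dataclass
class BlockSizes:
  # Illustrative stand-in for a flash-attention block-size config.
  block_q: int = 512
  block_kv: int = 512


class TinyAttention(nn.Module):
  # Fields that default to None are annotated `T | None` so the default
  # agrees with the declared type; `scale` is a float, not an int.
  scale: float = 1.0
  flash_block_sizes: BlockSizes | None = None
  precision: jax.lax.Precision | None = None

  @nn.compact
  def __call__(self, q, k, v):
    logits = jnp.einsum("...qd,...kd->...qk", q, k) * self.scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v, precision=self.precision)

Leaving `flash_block_sizes` or `precision` unset stays valid at runtime either way; the `| None` form just makes the annotation consistent with the `None` default, which is what the attention_flax.py hunks change.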

src/maxdiffusion/train_utils.py

Lines changed: 21 additions & 22 deletions

@@ -22,8 +22,7 @@


 def get_first_step(state):
-  with jax.spmd_mode("allow_all"):
-    return int(state.step)
+  return int(state.step)


 def load_next_batch(train_iter, example_batch, config):
@@ -101,27 +100,27 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step

 def write_metrics_to_tensorboard(writer, metrics, step, config):
   """Writes metrics to tensorboard"""
-  with jax.spmd_mode("allow_all"):
-    if jax.process_index() == 0:
-      for metric_name in metrics.get("scalar", []):
-        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
-      for metric_name in metrics.get("scalars", []):
-        writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
-
-    full_log = step % config.log_period == 0
-    if jax.process_index() == 0:
-      max_logging.log(
-          "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
-              step,
-              metrics["scalar"]["perf/step_time_seconds"],
-              metrics["scalar"]["perf/per_device_tflops_per_sec"],
-              float(metrics["scalar"]["learning/loss"]),
-          )
-      )

-    if full_log and jax.process_index() == 0:
-      max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
-      writer.flush()
+  if jax.process_index() == 0:
+    for metric_name in metrics.get("scalar", []):
+      writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
+    for metric_name in metrics.get("scalars", []):
+      writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
+
+  full_log = step % config.log_period == 0
+  if jax.process_index() == 0:
+    max_logging.log(
+        "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
+            step,
+            metrics["scalar"]["perf/step_time_seconds"],
+            metrics["scalar"]["perf/per_device_tflops_per_sec"],
+            float(metrics["scalar"]["learning/loss"]),
+        )
+    )
+
+  if full_log and jax.process_index() == 0:
+    max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
+    writer.flush()


 def get_params_to_save(params):
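For reference, a minimal sketch of the simplified `get_first_step` behaviour, assuming a standard `flax.training.train_state.TrainState` (the `nn.Dense` model and `optax.sgd` optimizer below are illustrative, not from this repo):

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

model = nn.Dense(features=4)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables["params"], tx=optax.sgd(1e-3)
)

# int(state.step) works directly whether step is a Python int or a concrete
# scalar jax.Array, so the jax.spmd_mode("allow_all") wrapper removed above
# added nothing.
first_step = int(state.step)
assert first_step == 0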
