@@ -52,17 +52,17 @@ def _maybe_aqt_einsum(quant: Quant):
 class AttentionOp(nn.Module):
   mesh: Mesh
   attention_kernel: str
-  scale: int
+  scale: float
   heads: int
   dim_head: int
   use_memory_efficient_attention: bool = False
   split_head_dim: bool = False
   float32_qk_product: bool = True
   flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
+  flash_block_sizes: BlockSizes | None = None
   dtype: DType = jnp.float32
-  quant: Quant = None
+  quant: Quant | None = None

   def setup(self):
     if self.attention_kernel == "cudnn_flash_te":
@@ -79,7 +79,7 @@ def setup(self):
           dtype=self.dtype,
           # float32_logits=self.float32_logits,
           qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-          scale_factor=self.scale,
+          scale_factor=float(self.scale),
           transpose_batch_sequence=False,
       )

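A brief hedged aside on the two `scale` changes above: the attention scale in modules like this is typically `dim_head ** -0.5`, which is a float rather than an int, and the explicit `float(...)` cast normalizes whatever the caller passed before it reaches the Transformer Engine `scale_factor` argument. The snippet below is an illustrative sketch of that reasoning only, not code from `attention_flax.py`:

```python
# Illustrative sketch (not part of attention_flax.py): why `scale` is a float and
# why wrapping it in float(...) before handing it to a kernel is a cheap safety net.
import numpy as np


def attention_scale(dim_head: int) -> float:
  # Standard scaled-dot-product attention scaling; always a non-integer float.
  return dim_head ** -0.5


scale = attention_scale(128)
assert isinstance(scale, float)  # ~0.0884, so `scale: int` was a mis-annotation

# Callers might still pass an int (e.g. 1) or a NumPy scalar; float(...) turns both
# into the plain Python float that a strict keyword-argument check expects.
for raw in (1, np.float32(0.0883883), scale):
  scale_factor = float(raw)
  assert isinstance(scale_factor, float)
```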
@@ -415,15 +415,15 @@ class FlaxFluxAttention(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
   query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
   qkv_bias: bool = False

   def setup(self):
@@ -619,16 +619,16 @@ class FlaxAttention(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
   query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
-  precision: jax.lax.Precision = None
-  quant: Quant = None
+  precision: jax.lax.Precision | None = None
+  quant: Quant | None = None

   def setup(self):

@@ -762,10 +762,10 @@ class FlaxBasicTransformerBlock(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
-  precision: jax.lax.Precision = None
-  quant: Quant = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
+  precision: jax.lax.Precision | None = None
+  quant: Quant | None = None

   def setup(self):
     # self attention (or cross_attention if only_cross_attention is True)
@@ -890,12 +890,12 @@ class FlaxTransformer2DModel(nn.Module):
   split_head_dim: bool = False
   attention_kernel: str = "dot_product"
   flash_min_seq_length: int = 4096
-  flash_block_sizes: BlockSizes = None
-  mesh: jax.sharding.Mesh = None
+  flash_block_sizes: BlockSizes | None = None
+  mesh: jax.sharding.Mesh | None = None
   norm_num_groups: int = 32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
   hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV)
-  quant: Quant = (None,)
+  quant: Quant | tuple[None] = (None,)

   def setup(self):
     self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)
@@ -1019,7 +1019,7 @@ class FlaxFeedForward(nn.Module):
   dropout: float = 0.0
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None

   def setup(self):
     # The second linear layer needs to be called
@@ -1051,7 +1051,7 @@ class FlaxGEGLU(nn.Module):
   dropout: float = 0.0
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None

   def setup(self):
     inner_dim = self.dim * 4
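Every remaining hunk applies the same annotation fix: attributes whose default is `None` are declared as `X | None` instead of a bare `X`, since Flax `nn.Module` subclasses are dataclasses and static type checkers reject a `None` default on a non-optional field. Below is a minimal, self-contained sketch of that pattern using a hypothetical stand-in module (not one of the classes touched in this diff):

```python
# Hypothetical stand-in showing the `X | None = None` pattern from this diff;
# `SimpleAttention` is illustrative only and does not exist in attention_flax.py.
# The `X | None` union syntax assumes Python 3.10+, which the diff itself already relies on.
import flax.linen as nn
import jax
import jax.numpy as jnp


class SimpleAttention(nn.Module):
  dim_head: int
  scale: float
  # Optional fields: `jax.sharding.Mesh | None` makes the None default explicit;
  # a bare `mesh: jax.sharding.Mesh = None` is flagged by mypy/pyright.
  mesh: jax.sharding.Mesh | None = None  # unused here; kept only to mirror the annotation
  precision: jax.lax.Precision | None = None  # None lets einsum fall back to its default precision

  @nn.compact
  def __call__(self, q, k, v):
    logits = jnp.einsum("bqd,bkd->bqk", q, k, precision=self.precision) * self.scale
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bqk,bkd->bqd", weights, v, precision=self.precision)


# Usage: callers can simply omit the optional fields.
attn = SimpleAttention(dim_head=64, scale=64**-0.5)
x = jnp.ones((1, 16, 64))
out = attn.apply({}, x, x, x)  # no learned params, so an empty variable dict is fine
print(out.shape)  # (1, 16, 64)
```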