88"""
99import types
1010import functools
11+ from typing import Optional
1112
1213from .evo_norm import *
1314from .filter_response_norm import FilterResponseNormAct2d , FilterResponseNormTlu2d
14- from .norm_act import BatchNormAct2d , GroupNormAct , LayerNormAct , LayerNormAct2d , RmsNormAct , RmsNormAct2d
15+ from .norm_act import (
16+ BatchNormAct2d ,
17+ GroupNormAct ,
18+ GroupNorm1Act ,
19+ LayerNormAct ,
20+ LayerNormActFp32 ,
21+ LayerNormAct2d ,
22+ LayerNormAct2dFp32 ,
23+ RmsNormAct ,
24+ RmsNormActFp32 ,
25+ RmsNormAct2d ,
26+ RmsNormAct2dFp32 ,
27+ )
1528from .inplace_abn import InplaceAbn
29+ from .typing import LayerType
1630
# Lookup table: lowercase layer-name string -> norm+act layer class.
# NOTE(review): this is a diff rendering — each line carries fused old/new
# line numbers and +/- markers. Entries between `evonormb2` (old line 26)
# and `iabn` (old line 36) are elided from this view, so the listing below
# is incomplete; left byte-identical rather than reconstructed.
1731_NORM_ACT_MAP = dict (
1832    batchnorm = BatchNormAct2d ,
1933    batchnorm2d = BatchNormAct2d ,
2034    groupnorm = GroupNormAct ,
21- groupnorm1 = functools . partial ( GroupNormAct , num_groups = 1 ) ,
35+ groupnorm1 = GroupNorm1Act ,
2236    layernorm = LayerNormAct ,
2337    layernorm2d = LayerNormAct2d ,
38+ layernormfp32 = LayerNormActFp32 ,
39+ layernorm2dfp32 = LayerNormAct2dFp32 ,
2440    evonormb0 = EvoNorm2dB0 ,
2541    evonormb1 = EvoNorm2dB1 ,
2642    evonormb2 = EvoNorm2dB2 ,
3652    iabn = InplaceAbn ,
3753    rmsnorm = RmsNormAct ,
3854    rmsnorm2d = RmsNormAct2d ,
55+ rmsnormfp32 = RmsNormActFp32 ,
56+ rmsnorm2dfp32 = RmsNormAct2dFp32 ,
3957)
# All known norm+act layer classes, derived from the name -> class map.
_NORM_ACT_TYPES = set(_NORM_ACT_MAP.values())

# Reverse map from base norm layer names (e.g. nn.LayerNorm -> 'layernorm')
# to the equivalent norm+act layer classes. Unlike _NORM_ACT_MAP, this only
# covers norm types that have a plain (non-act) counterpart.
_NORM_TO_NORM_ACT_MAP = dict(
    batchnorm=BatchNormAct2d,
    batchnorm2d=BatchNormAct2d,
    groupnorm=GroupNormAct,
    groupnorm1=GroupNorm1Act,
    layernorm=LayerNormAct,
    layernorm2d=LayerNormAct2d,
    layernormfp32=LayerNormActFp32,
    layernorm2dfp32=LayerNormAct2dFp32,
    rmsnorm=RmsNormAct,
    rmsnorm2d=RmsNormAct2d,
    rmsnormfp32=RmsNormActFp32,
    rmsnorm2dfp32=RmsNormAct2dFp32,
)
# Norm+act classes whose constructor accepts an `act_layer` argument to
# define the activation type (checked in get_norm_act_layer before binding
# `act_layer` via functools.partial).
_NORM_ACT_REQUIRES_ARG = {
    BatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d,
    LayerNormActFp32,
    LayerNormAct2dFp32,
    FilterResponseNormAct2d,
    InplaceAbn,
    RmsNormAct,
    RmsNormAct2d,
    RmsNormActFp32,
    RmsNormAct2dFp32,
}
5290
5391
def create_norm_act_layer(
        layer_name: LayerType,
        num_features: int,
        act_layer: Optional[LayerType] = None,
        apply_act: bool = True,
        jit: bool = False,
        **kwargs,
):
    """Create an instantiated norm+act layer.

    Args:
        layer_name: Norm+act layer spec (string name, class, partial, or
            factory fn) resolved via ``get_norm_act_layer``.
        num_features: Number of input features/channels for the norm layer.
        act_layer: Optional activation spec forwarded to the resolved layer
            (``None`` implies no activation for backwards compat).
        apply_act: Forwarded to the layer constructor; presumably ``False``
            disables the activation — see the norm+act classes to confirm.
        jit: If True, wrap the created layer with ``torch.jit.script``.
        **kwargs: Extra constructor args forwarded to the layer.

    Returns:
        The instantiated (and optionally TorchScript-scripted) layer.
    """
    layer = get_norm_act_layer(layer_name, act_layer=act_layer)
    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
    if jit:
        layer_instance = torch.jit.script(layer_instance)
    return layer_instance
60105
61106
# Resolve a norm layer spec (class, string name, functools.partial, or a
# factory function) to the matching norm+act layer class/partial.
# NOTE(review): diff rendering — the @@ hunk header below elides ~16 lines
# of this function (new-file lines ~114-129: the string-name / partial
# handling and `norm_act_kwargs` extraction), so the code is left
# byte-identical rather than reconstructed.
62- def get_norm_act_layer (norm_layer , act_layer = None ):
107+ def get_norm_act_layer (
108+ norm_layer : LayerType ,
109+ act_layer : Optional [LayerType ] = None ,
110+ ):
63111    if norm_layer is None :
64112        return None
65113    assert isinstance (norm_layer , (type , str , types .FunctionType , functools .partial ))
@@ -82,26 +130,16 @@ def get_norm_act_layer(norm_layer, act_layer=None):
82130    # if function type, must be a lambda/fn that creates a norm_act layer
83131    norm_act_layer = norm_layer
84132    else :
133+ # Use reverse map to find the corresponding norm+act layer
85134    type_name = norm_layer .__name__ .lower ()
86- if type_name .startswith ('batchnorm' ):
87- norm_act_layer = BatchNormAct2d
88- elif type_name .startswith ('groupnorm' ):
89- norm_act_layer = GroupNormAct
90- elif type_name .startswith ('groupnorm1' ):
91- norm_act_layer = functools .partial (GroupNormAct , num_groups = 1 )
92- elif type_name .startswith ('layernorm2d' ):
93- norm_act_layer = LayerNormAct2d
94- elif type_name .startswith ('layernorm' ):
95- norm_act_layer = LayerNormAct
96- elif type_name .startswith ('rmsnorm2d' ):
97- norm_act_layer = RmsNormAct2d
98- else :
99- assert False , f"No equivalent norm_act layer for { type_name } "
135+ norm_act_layer = _NORM_TO_NORM_ACT_MAP .get (type_name , None )
136+ assert norm_act_layer is not None , f"No equivalent norm_act layer for { type_name } "
100137
101138    if norm_act_layer in _NORM_ACT_REQUIRES_ARG :
102139    # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
103140    # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
104141    norm_act_kwargs .setdefault ('act_layer' , act_layer )
105142    if norm_act_kwargs :
106143    norm_act_layer = functools .partial (norm_act_layer , ** norm_act_kwargs ) # bind/rebind args
144+
107145    return norm_act_layer
0 commit comments