@@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
9292 _assert (C % groups == 0 , '' )
9393 x_dtype = x .dtype
9494 x = x .reshape (B , groups , C // groups , H , W )
95- rms = x .float ().square ().mean (dim = (2 , 3 , 4 ), keepdim = True ).add (eps ).sqrt_ ().to (dtype = x_dtype )
95+ rms = x .float ().square ().mean (dim = (2 , 3 , 4 ), keepdim = True ).add (eps ).sqrt_ ().to (x_dtype )
9696 return rms .expand (x .shape ).reshape (B , C , H , W )
9797
9898
@@ -160,14 +160,14 @@ def forward(self, x):
160160 n = x .numel () / x .shape [1 ]
161161 self .running_var .copy_ (
162162 self .running_var * (1 - self .momentum ) +
163- var .detach ().to (dtype = self .running_var .dtype ) * self .momentum * (n / (n - 1 )))
163+ var .detach ().to (self .running_var .dtype ) * self .momentum * (n / (n - 1 )))
164164 else :
165165 var = self .running_var
166- var = var .to (dtype = x_dtype ).view (v_shape )
166+ var = var .to (x_dtype ).view (v_shape )
167167 left = var .add (self .eps ).sqrt_ ()
168168 right = (x + 1 ) * instance_rms (x , self .eps )
169169 x = x / left .max (right )
170- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
170+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
171171
172172
173173class EvoNorm2dB2 (nn .Module ):
@@ -195,14 +195,14 @@ def forward(self, x):
195195 n = x .numel () / x .shape [1 ]
196196 self .running_var .copy_ (
197197 self .running_var * (1 - self .momentum ) +
198- var .detach ().to (dtype = self .running_var .dtype ) * self .momentum * (n / (n - 1 )))
198+ var .detach ().to (self .running_var .dtype ) * self .momentum * (n / (n - 1 )))
199199 else :
200200 var = self .running_var
201- var = var .to (dtype = x_dtype ).view (v_shape )
201+ var = var .to (x_dtype ).view (v_shape )
202202 left = var .add (self .eps ).sqrt_ ()
203203 right = instance_rms (x , self .eps ) - x
204204 x = x / left .max (right )
205- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
205+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
206206
207207
208208class EvoNorm2dS0 (nn .Module ):
@@ -231,9 +231,9 @@ def forward(self, x):
231231 x_dtype = x .dtype
232232 v_shape = (1 , - 1 , 1 , 1 )
233233 if self .v is not None :
234- v = self .v .view (v_shape ).to (dtype = x_dtype )
234+ v = self .v .view (v_shape ).to (x_dtype )
235235 x = x * (x * v ).sigmoid () / group_std (x , self .groups , self .eps )
236- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
236+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
237237
238238
239239class EvoNorm2dS0a (EvoNorm2dS0 ):
@@ -247,10 +247,10 @@ def forward(self, x):
247247 v_shape = (1 , - 1 , 1 , 1 )
248248 d = group_std (x , self .groups , self .eps )
249249 if self .v is not None :
250- v = self .v .view (v_shape ).to (dtype = x_dtype )
250+ v = self .v .view (v_shape ).to (x_dtype )
251251 x = x * (x * v ).sigmoid ()
252252 x = x / d
253- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
253+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
254254
255255
256256class EvoNorm2dS1 (nn .Module ):
@@ -284,7 +284,7 @@ def forward(self, x):
284284 v_shape = (1 , - 1 , 1 , 1 )
285285 if self .apply_act :
286286 x = self .act (x ) / group_std (x , self .groups , self .eps )
287- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
287+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
288288
289289
290290class EvoNorm2dS1a (EvoNorm2dS1 ):
@@ -299,7 +299,7 @@ def forward(self, x):
299299 x_dtype = x .dtype
300300 v_shape = (1 , - 1 , 1 , 1 )
301301 x = self .act (x ) / group_std (x , self .groups , self .eps )
302- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
302+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
303303
304304
305305class EvoNorm2dS2 (nn .Module ):
@@ -332,7 +332,7 @@ def forward(self, x):
332332 v_shape = (1 , - 1 , 1 , 1 )
333333 if self .apply_act :
334334 x = self .act (x ) / group_rms (x , self .groups , self .eps )
335- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
335+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
336336
337337
338338class EvoNorm2dS2a (EvoNorm2dS2 ):
@@ -347,4 +347,4 @@ def forward(self, x):
347347 x_dtype = x .dtype
348348 v_shape = (1 , - 1 , 1 , 1 )
349349 x = self .act (x ) / group_rms (x , self .groups , self .eps )
350- return x * self .weight .view (v_shape ).to (dtype = x_dtype ) + self .bias .view (v_shape ).to (dtype = x_dtype )
350+ return x * self .weight .view (v_shape ).to (x_dtype ) + self .bias .view (v_shape ).to (x_dtype )
0 commit comments