@@ -166,7 +166,7 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.
166166 def forward (self , x ):
167167 B , N , C = x .shape
168168 qkv = self .qkv (x ).reshape (B , N , 3 , self .num_heads , C // self .num_heads ).permute (2 , 0 , 3 , 1 , 4 )
169- q , k , v = qkv [0 ], qkv [1 ], qkv [2 ] # make torchscript happy (cannot use tensor as tuple)
169+ q , k , v = qkv [0 ], qkv [1 ], qkv [2 ] # make torchscript happy (cannot use tensor as tuple)
170170
171171 attn = (q @ k .transpose (- 2 , - 1 )) * self .scale
172172 attn = attn .softmax (dim = - 1 )
@@ -663,7 +663,7 @@ def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
663663 """
664664 model_kwargs = dict (patch_size = 16 , embed_dim = 192 , depth = 12 , num_heads = 3 , ** kwargs )
665665 model = _create_vision_transformer (
666- 'vit_deit_tiny_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
666+ 'vit_deit_tiny_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
667667 return model
668668
669669
@@ -674,7 +674,7 @@ def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
674674 """
675675 model_kwargs = dict (patch_size = 16 , embed_dim = 384 , depth = 12 , num_heads = 6 , ** kwargs )
676676 model = _create_vision_transformer (
677- 'vit_deit_small_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
677+ 'vit_deit_small_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
678678 return model
679679
680680
@@ -685,7 +685,7 @@ def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
685685 """
686686 model_kwargs = dict (patch_size = 16 , embed_dim = 768 , depth = 12 , num_heads = 12 , ** kwargs )
687687 model = _create_vision_transformer (
688- 'vit_deit_base_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
688+ 'vit_deit_base_distilled_patch16_224' , pretrained = pretrained , distilled = True , ** model_kwargs )
689689 return model
690690
691691
0 commit comments