Skip to content

Commit 8ba2038

Browse files
YassineYousfirwightman
authored andcommitted
fast_vit: propagate act_layer argument
1 parent 95ba901 commit 8ba2038

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

timm/models/fastvit.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,8 @@ def _fuse_bn(
421421
def convolutional_stem(
422422
in_chs: int,
423423
out_chs: int,
424-
inference_mode: bool = False
424+
inference_mode: bool = False,
425+
act_layer: nn.Module = nn.GELU,
425426
) -> nn.Sequential:
426427
"""Build convolutional stem with MobileOne blocks.
427428
@@ -439,6 +440,7 @@ def convolutional_stem(
439440
out_chs=out_chs,
440441
kernel_size=3,
441442
stride=2,
443+
act_layer=act_layer,
442444
inference_mode=inference_mode,
443445
),
444446
MobileOneBlock(
@@ -447,13 +449,15 @@ def convolutional_stem(
447449
kernel_size=3,
448450
stride=2,
449451
group_size=1,
452+
act_layer=act_layer,
450453
inference_mode=inference_mode,
451454
),
452455
MobileOneBlock(
453456
in_chs=out_chs,
454457
out_chs=out_chs,
455458
kernel_size=1,
456-
stride=1,
459+
stride=1,
460+
act_layer=act_layer,
457461
inference_mode=inference_mode,
458462
),
459463
)
@@ -1121,6 +1125,7 @@ def __init__(
11211125
in_chans,
11221126
embed_dims[0],
11231127
inference_mode,
1128+
act_layer
11241129
)
11251130

11261131
# Build the main stages of the network architecture

0 commit comments

Comments
 (0)