diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py
index 37cd24512..91c0003b2 100644
--- a/segment_anything/build_sam.py
+++ b/segment_anything/build_sam.py
@@ -11,36 +11,39 @@
 from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
 
 
-def build_sam_vit_h(checkpoint=None):
+def build_sam_vit_h(checkpoint=None, **kwargs):
     return _build_sam(
         encoder_embed_dim=1280,
         encoder_depth=32,
         encoder_num_heads=16,
         encoder_global_attn_indexes=[7, 15, 23, 31],
         checkpoint=checkpoint,
+        **kwargs,
     )
 
 
 build_sam = build_sam_vit_h
 
 
-def build_sam_vit_l(checkpoint=None):
+def build_sam_vit_l(checkpoint=None, **kwargs):
     return _build_sam(
         encoder_embed_dim=1024,
         encoder_depth=24,
         encoder_num_heads=16,
         encoder_global_attn_indexes=[5, 11, 17, 23],
         checkpoint=checkpoint,
+        **kwargs,
     )
 
 
-def build_sam_vit_b(checkpoint=None):
+def build_sam_vit_b(checkpoint=None, **kwargs):
     return _build_sam(
         encoder_embed_dim=768,
         encoder_depth=12,
         encoder_num_heads=12,
         encoder_global_attn_indexes=[2, 5, 8, 11],
         checkpoint=checkpoint,
+        **kwargs,
     )
 
 
@@ -58,7 +61,12 @@ def _build_sam(
     encoder_num_heads,
     encoder_global_attn_indexes,
     checkpoint=None,
+    **kwargs,
 ):
+    window_size = kwargs.get("window_size", 14)
+    if window_size <= 0:
+        window_size = 14
+
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
@@ -75,7 +83,7 @@
             qkv_bias=True,
             use_rel_pos=True,
             global_attn_indexes=encoder_global_attn_indexes,
-            window_size=14,
+            window_size=window_size,
             out_chans=prompt_embed_dim,
         ),
         prompt_encoder=PromptEncoder(
@@ -102,6 +110,6 @@
     sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, weights_only=False)
         sam.load_state_dict(state_dict)
     return sam
diff --git a/segment_anything/modeling/image_encoder.py b/segment_anything/modeling/image_encoder.py
index 66351d9d7..e98992a96 100644
--- a/segment_anything/modeling/image_encoder.py
+++ b/segment_anything/modeling/image_encoder.py
@@ -239,6 +239,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return x
 
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        # Relative-position tables have length 2 * window_size - 1, so a
+        # checkpoint trained with a different window size will not match the
+        # current model. Resize mismatched tables by linear interpolation.
+        rest = dict()
+        for k, v in state_dict.items():
+            # Guard on the prefix: some torch versions pass the full state
+            # dict here rather than a per-module slice.
+            if k.startswith(prefix) and 'rel_pos' in k:
+                my_rel_pos = getattr(self, k[len(prefix):])
+                if my_rel_pos.shape[0] != v.shape[0]:
+                    # (L, C) -> (1, C, L) for F.interpolate, then back to (L', C).
+                    v = v.unsqueeze(0).permute(0, 2, 1)
+                    v = F.interpolate(v, size=my_rel_pos.shape[0], mode='linear', align_corners=True)
+                    v = v.squeeze(0).T
+                my_rel_pos.data.copy_(v)
+            else:
+                rest[k] = v
+
+        # The rel_pos entries were consumed above; load the remainder with
+        # strict=False so they are not reported as missing.
+        return super()._load_from_state_dict(rest, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs)
+
 
 def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
     """
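
Usage sketch (not part of the diff): how the new window_size keyword is meant to be exercised. The checkpoint path below is an assumption about where the official ViT-B weights were saved locally; everything else follows from the changes above. With window_size=8, each windowed block's rel_pos_h/rel_pos_w table no longer matches the pretrained length (2 * 8 - 1 = 15 vs. 2 * 14 - 1 = 27), so the _load_from_state_dict hook interpolates it on load, while the global-attention blocks keep their full-size tables untouched.

# Minimal sketch, assuming the official ViT-B checkpoint was downloaded to
# ./sam_vit_b_01ec64.pth (illustrative path).
from segment_anything import build_sam_vit_b

# window_size is forwarded through **kwargs into _build_sam and on into
# ImageEncoderViT; mismatched rel_pos tables are resized during load_state_dict.
sam = build_sam_vit_b(checkpoint="sam_vit_b_01ec64.pth", window_size=8)

Any positive window size should work here, since window_partition pads the feature map to a multiple of window_size before attention.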