diff --git a/mmdet/models/dense_heads/detr_head.py b/mmdet/models/dense_heads/detr_head.py index 1046400f78c..4f4e554cf38 100644 --- a/mmdet/models/dense_heads/detr_head.py +++ b/mmdet/models/dense_heads/detr_head.py @@ -12,11 +12,13 @@ from mmdet.registry import MODELS, TASK_UTILS from mmdet.structures import SampleList -from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, bbox_overlaps +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcywh) from mmdet.utils import (ConfigType, InstanceList, OptInstanceList, OptMultiConfig, reduce_mean) -from ..utils import multi_apply from ..losses import QualityFocalLoss +from ..utils import multi_apply + @MODELS.register_module() class DETRHead(BaseModule): @@ -424,7 +426,8 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, gt_instances=gt_instances, img_meta=img_meta) - gt_bboxes = gt_instances.bboxes + # The type of `bboxes` should be consistent with the `cls_score` + gt_bboxes = gt_instances.bboxes.type_as(cls_score) gt_labels = gt_instances.labels pos_inds = torch.nonzero( assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() @@ -448,8 +451,14 @@ def _get_targets_single(self, cls_score: Tensor, bbox_pred: Tensor, # DETR regress the relative position of boxes (cxcywh) in the image. # Thus the learning target should be normalized by the image size, also # the box format should be converted from defaultly x1y1x2y2 to cxcywh. - pos_gt_bboxes_normalized = pos_gt_bboxes / factor + + # `pos_gt_bboxes / factor` will return a float tensor by default. + # Use `type_as` here to make sure the dtype of gt_bboxes is the same as + # the pred_bboxes. 
+ pos_gt_bboxes_normalized = (pos_gt_bboxes / factor).type_as( + bbox_targets) pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized) + # the converted targets now share the dtype of `bbox_targets` bbox_targets[pos_inds] = pos_gt_bboxes_targets return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds) diff --git a/mmdet/models/layers/transformer/dino_layers.py b/mmdet/models/layers/transformer/dino_layers.py index 64610d0a7c0..01c0d6afb10 100644 --- a/mmdet/models/layers/transformer/dino_layers.py +++ b/mmdet/models/layers/transformer/dino_layers.py @@ -493,12 +493,18 @@ def collate_dn_queries(self, input_label_query: Tensor, mapper = (batch_idx_expand, map_query_index) batched_label_query = torch.zeros( - batch_size, num_denoising_queries, self.embed_dims, device=device) + batch_size, num_denoising_queries, self.embed_dims, device=device, + dtype=input_label_query.dtype) + # `input_label_query` is extracted from `nn.Embedding`, whose dtype + # has already been converted to the target dtype. 
+ # However the dtype of `batched_label_query` is always `float32` batched_bbox_query = torch.zeros( - batch_size, num_denoising_queries, 4, device=device) + batch_size, num_denoising_queries, 4, device=device, + dtype=input_label_query.dtype) batched_label_query[mapper] = input_label_query - batched_bbox_query[mapper] = input_bbox_query + batched_bbox_query[mapper] = input_bbox_query.to( + dtype=input_label_query.dtype) return batched_label_query, batched_bbox_query def generate_dn_mask(self, max_num_target: int, num_groups: int, diff --git a/mmdet/models/losses/gfocal_loss.py b/mmdet/models/losses/gfocal_loss.py index b3a1172207e..816eda25b33 100644 --- a/mmdet/models/losses/gfocal_loss.py +++ b/mmdet/models/losses/gfocal_loss.py @@ -45,9 +45,9 @@ def quality_focal_loss(pred, target, beta=2.0): pos_label = label[pos].long() # positives are supervised by bbox quality (IoU) score scale_factor = score[pos] - pred_sigmoid[pos, pos_label] - loss[pos, pos_label] = F.binary_cross_entropy_with_logits( + loss[pos, pos_label] = (F.binary_cross_entropy_with_logits( pred[pos, pos_label], score[pos], - reduction='none') * scale_factor.abs().pow(beta) + reduction='none') * scale_factor.abs().pow(beta)).type_as(loss) loss = loss.sum(dim=1, keepdim=False) return loss diff --git a/projects/CO-DETR/codetr/co_dino_head.py b/projects/CO-DETR/codetr/co_dino_head.py index 90c5c06beee..bdeb4b53071 100644 --- a/projects/CO-DETR/codetr/co_dino_head.py +++ b/projects/CO-DETR/codetr/co_dino_head.py @@ -1,26 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import copy +from typing import Dict, List, Tuple + import torch import torch.nn as nn import torch.nn.functional as F -import copy -from typing import Dict, List, Tuple -from torch import Tensor +from mmcv.cnn import Linear +from mmcv.ops import batched_nms, interpolate from mmengine.structures import InstanceData -from mmcv.ops import batched_nms -from mmdet.utils import InstanceList, reduce_mean -from mmdet.structures import SampleList -from mmdet.registry import MODELS, TASK_UTILS -from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh, bbox_overlaps) -from mmdet.models.utils import multi_apply, unpack_gt_instances +from torch import Tensor + +from mmdet.models import DINOHead +from mmdet.models.layers import CdnQueryGenerator from mmdet.models.layers.transformer import inverse_sigmoid -from mmcv.ops import batched_nms from mmdet.models.task_modules.samplers import PseudoSampler -from mmcv.cnn import Linear - +from mmdet.models.utils import multi_apply, unpack_gt_instances +from mmdet.registry import MODELS, TASK_UTILS +from mmdet.structures import SampleList +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcywh) from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, OptInstanceList, reduce_mean) -from mmdet.models.layers import CdnQueryGenerator -from mmdet.models import DINOHead @MODELS.register_module() @@ -126,10 +126,13 @@ def forward(self, mlvl_positional_encodings = [] for feat in mlvl_feats: mlvl_masks.append( - F.interpolate(img_masks[None], - size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + interpolate(img_masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + # `positional_encoding` will return a float tensor by + # default. 
Convert it to the same dtype as `feat` for pure + # bf16/fp16 training mlvl_positional_encodings.append( - self.positional_encoding(mlvl_masks[-1])) + self.positional_encoding(mlvl_masks[-1]).to(dtype=feat.dtype)) query_embeds = None hs, inter_references, topk_score, topk_anchor, enc_outputs = \ @@ -397,10 +400,13 @@ def forward_aux(self, mlvl_feats, img_metas, aux_targets, head_idx): mlvl_positional_encodings = [] for feat in mlvl_feats: mlvl_masks.append( - F.interpolate(img_masks[None], - size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + interpolate(img_masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + # `positional_encoding` will return a float tensor by + # default. Convert it to the same dtype as `feat` for pure + # bf16/fp16 training mlvl_positional_encodings.append( - self.positional_encoding(mlvl_masks[-1])) + self.positional_encoding(mlvl_masks[-1]).to(dtype=feat.dtype)) query_embeds = None hs, inter_references = self.transformer.forward_aux( diff --git a/projects/CO-DETR/codetr/transformer.py b/projects/CO-DETR/codetr/transformer.py index f2b73858961..46fa194a38a 100644 --- a/projects/CO-DETR/codetr/transformer.py +++ b/projects/CO-DETR/codetr/transformer.py @@ -4,16 +4,16 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmengine.model import BaseModule -from mmengine.model.weight_init import xavier_init -from mmdet.registry import MODELS from mmcv.cnn.bricks.transformer import (BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence) -from mmdet.models.layers.transformer import inverse_sigmoid from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import BaseModule +from mmengine.model.weight_init import xavier_init from torch.nn.init import normal_ +from mmdet.models.layers.transformer import inverse_sigmoid +from mmdet.registry import MODELS try: from fairscale.nn.checkpoint import checkpoint_wrapper @@ -308,6 +308,7 @@ def gen_encoder_output_proposals(self, memory, 
memory_padding_mask, output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) output_memory = self.enc_output_norm(self.enc_output(output_memory)) + output_proposals = output_proposals.type_as(output_memory) return output_memory, output_proposals @staticmethod @@ -1034,8 +1035,11 @@ def forward(self, reference_points_input = \ reference_points[:, :, None] * valid_ratios[:, None] + # `query_sine_embed` will be float by default. Just convert it to + # the same type as `query` to avoid type mismatch when using pure + # bf16/fp16 training query_sine_embed = self.gen_sineembed_for_position( - reference_points_input[:, :, 0, :], self.embed_dims//2) + reference_points_input[:, :, 0, :], self.embed_dims//2).type_as(query) query_pos = self.ref_point_head(query_sine_embed) query_pos = query_pos.permute(1, 0, 2) @@ -1262,12 +1266,16 @@ def forward_aux(self, topk_coords_unact = inverse_sigmoid((pos_anchors)) reference_points = (pos_anchors) init_reference_out = reference_points + + # get_proposal_pos_embed will return a float tensor by default. + # convert it to the same type as `mlvl_feats` to avoid type mismatch + # during pure fp16/bf16 training if self.num_co_heads > 0: pos_trans_out = self.aux_pos_trans_norm[head_idx]( - self.aux_pos_trans[head_idx](self.get_proposal_pos_embed(topk_coords_unact))) + self.aux_pos_trans[head_idx](self.get_proposal_pos_embed(topk_coords_unact).type_as(mlvl_feats[0]))) query = pos_trans_out if self.with_coord_feat: - query = query + self.pos_feats_norm[head_idx](self.pos_feats_trans[head_idx](pos_feats)) + query = query + self.pos_feats_norm[head_idx](self.pos_feats_trans[head_idx](pos_feats).type_as(mlvl_feats[0])) # decoder query = query.permute(1, 0, 2) @@ -1292,6 +1300,7 @@ def forward_aux(self, from mmcv.cnn import build_norm_layer + @MODELS.register_module() class DetrTransformerEncoder(TransformerLayerSequence): """TransformerEncoder of DETR.