# Source code for mmfewshot.detection.models.roi_heads.bbox_heads.multi_relation_bbox_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet.models.builder import HEADS
from mmdet.models.losses import accuracy
from mmdet.models.roi_heads import BBoxHead
from torch import Tensor

@HEADS.register_module()
class MultiRelationBBoxHead(BBoxHead):
    """BBox head for `Attention RPN <https://arxiv.org/abs/1908.01998>`_.

    The head scores a (query RoI, support) pair with up to three relation
    branches and sums their classification logits. Only the patch relation
    branch owns a regression layer, so it is always enabled.

    Args:
        patch_relation (bool): Whether use patch_relation head for
            classification. Following the official implementation,
            `patch_relation` always be True, because only patch relation
            head contain regression head. The passed value is ignored.
            Default: True.
        local_correlation (bool): Whether use local_correlation head for
            classification. Default: True.
        global_relation (bool): Whether use global_relation head for
            classification. Default: True.
    """

    def __init__(self,
                 patch_relation: bool = True,
                 local_correlation: bool = True,
                 global_relation: bool = True,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # remove unused parameters inherited from BBoxHead
        if hasattr(self, 'fc_cls'):
            del self.fc_cls
        if hasattr(self, 'fc_reg'):
            del self.fc_reg

        # following the official implementation patch relation must be True,
        # because only patch relation head contain regression head
        self.patch_relation = True
        self.local_correlation = local_correlation
        self.global_relation = global_relation

        # NOTE: the order in which sub-modules are created below (and the
        # positions inside each nn.Sequential) determines the state_dict
        # keys, so it must not be changed.
        if self.patch_relation:
            mid_channels = int(self.in_channels / 4)
            self.patch_relation_branch = nn.Sequential(
                nn.Conv2d(
                    self.in_channels * 2,
                    mid_channels,
                    1,
                    padding=0,
                    bias=False),
                nn.ReLU(inplace=True),
                # 7x7 -> 5x5
                nn.AvgPool2d(kernel_size=3, stride=1),
                # 5x5 -> 3x3
                nn.Conv2d(
                    mid_channels, mid_channels, 3, padding=0, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    mid_channels, self.in_channels, 1, padding=0, bias=False),
                nn.ReLU(inplace=True),
                # 3x3 -> 1x1
                nn.AvgPool2d(kernel_size=3, stride=1))
            self.patch_relation_fc_reg = nn.Linear(self.in_channels, 4)
            self.patch_relation_fc_cls = nn.Linear(self.in_channels, 2)

        if self.local_correlation:
            self.local_correlation_branch = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    self.in_channels,
                    1,
                    padding=0,
                    bias=False))
            self.local_correlation_fc_cls = nn.Linear(self.in_channels, 2)

        if self.global_relation:
            self.global_relation_avgpool = nn.AvgPool2d(7)
            self.global_relation_branch = nn.Sequential(
                nn.Linear(self.in_channels * 2, self.in_channels),
                nn.ReLU(inplace=True),
                nn.Linear(self.in_channels, self.in_channels),
                nn.ReLU(inplace=True))
            self.global_relation_fc_cls = nn.Linear(self.in_channels, 2)

    def forward(self, query_feat: Tensor,
                support_feat: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            query_feat (Tensor): Shape of (num_proposals, C, H, W).
            support_feat (Tensor): Shape of (1, C, H, W).

        Returns:
            tuple:
                cls_score (Tensor): Cls scores, has shape
                    (num_proposals, num_classes).
                bbox_pred (Tensor): Box energies / deltas, has shape
                    (num_proposals, 4).
        """
        # global_relation: compare pooled query / support descriptors
        if self.global_relation:
            global_query_feat = self.global_relation_avgpool(
                query_feat).squeeze(3).squeeze(2)
            global_support_feat = self.global_relation_avgpool(
                support_feat).squeeze(3).squeeze(2).expand_as(
                    global_query_feat)
            global_feat = torch.cat((global_query_feat, global_support_feat),
                                    1)
            global_feat = self.global_relation_branch(global_feat)
            global_relation_cls_score = self.global_relation_fc_cls(
                global_feat)

        # local_correlation: depth-wise correlation, the support feature
        # acts as a per-channel convolution kernel over the query feature
        if self.local_correlation:
            local_query_feat = self.local_correlation_branch(query_feat)
            local_support_feat = self.local_correlation_branch(support_feat)
            # One group per channel. Derive the group count from the tensor
            # instead of hard-coding 2048 so that heads built with another
            # `in_channels` work as well.
            local_feat = F.conv2d(
                local_query_feat,
                local_support_feat.permute(1, 0, 2, 3),
                groups=local_query_feat.size(1))
            local_feat = F.relu(local_feat, inplace=True)
            local_feat = local_feat.squeeze(3).squeeze(2)
            local_correlation_cls_score = self.local_correlation_fc_cls(
                local_feat)

        # patch_relation: concatenate query and support along channels
        if self.patch_relation:
            patch_feat = torch.cat(
                (query_feat, support_feat.expand_as(query_feat)), 1)
            # 7x7 -> 1x1
            patch_feat = self.patch_relation_branch(patch_feat)
            patch_feat = patch_feat.squeeze(3).squeeze(2)
            patch_relation_cls_score = self.patch_relation_fc_cls(patch_feat)
            patch_relation_bbox_pred = self.patch_relation_fc_reg(patch_feat)

        # aggregate multi relation result
        # following the official implementation,
        # only patch relation head contain regression head
        bbox_pred_all = patch_relation_bbox_pred
        cls_score_all = patch_relation_cls_score
        # out-of-place adds: same values, but the branch output is not
        # mutated behind autograd's back
        if self.local_correlation:
            cls_score_all = cls_score_all + local_correlation_cls_score
        if self.global_relation:
            cls_score_all = cls_score_all + global_relation_cls_score
        return cls_score_all, bbox_pred_all

    @staticmethod
    def _resample_cls_inds(
            cls_scores: Tensor, labels: Tensor, num_pos_pair_samples: int,
            sample_fractions: Sequence[Union[int, float]]) -> Tensor:
        """Pick the proposal indices used for the classification loss.

        All proposals with label 0 are kept. Proposals with label 1 are
        ranked by their score in column 0 (descending) and the top ones are
        kept, separately for indices below ``num_pos_pair_samples`` and for
        the remaining indices.

        Args:
            cls_scores (Tensor): Scores with shape (num_proposals, 2).
            labels (Tensor): Labels with shape (num_proposals).
            num_pos_pair_samples (int): Number of samples from positive
                pairs; indices below this value come from positive pairs.
            sample_fractions (Sequence[int | float]): Fractions of positive
                samples, negative samples from positive pair, negative
                samples from negative pair.

        Returns:
            Tensor: 1-D index tensor of the selected proposals.
        """
        num_instances = labels.size(0)
        fg_samples_inds = torch.nonzero(
            labels == 0, as_tuple=False).squeeze(-1)
        bg_samples_inds = torch.nonzero(
            labels == 1, as_tuple=False).squeeze(-1)
        num_fg = fg_samples_inds.shape[0]

        # The counts are used as slice bounds below, so they must be ints
        # even when `sample_fractions` holds floats.
        num_pos_pair_bg_samples = max(
            1,
            min(
                int(num_fg * sample_fractions[1]),
                int(num_instances / sum(sample_fractions))))
        num_neg_pair_samples = max(
            1,
            min(int(num_fg * sample_fractions[2]), num_pos_pair_bg_samples))

        # rank background samples by their foreground score (column 0)
        bg_cls_scores = cls_scores[bg_samples_inds, :]
        _, sorted_inds = torch.sort(bg_cls_scores[:, 0], descending=True)
        sorted_bg_samples_inds = bg_samples_inds[sorted_inds]

        from_pos_pair = sorted_bg_samples_inds < num_pos_pair_samples
        pos_pair_bg_samples_inds = sorted_bg_samples_inds[
            from_pos_pair][:num_pos_pair_bg_samples]
        neg_pair_samples_inds = sorted_bg_samples_inds[
            ~from_pos_pair][:num_neg_pair_samples]

        return torch.cat(
            [fg_samples_inds, pos_pair_bg_samples_inds, neg_pair_samples_inds],
            dim=0)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(
        self,
        cls_scores: Tensor,
        bbox_preds: Tensor,
        rois: Tensor,
        labels: Tensor,
        label_weights: Tensor,
        bbox_targets: Tensor,
        bbox_weights: Tensor,
        num_pos_pair_samples: int,
        reduction_override: Optional[str] = None,
        sample_fractions: Sequence[Union[int, float]] = (1, 2, 1)
    ) -> Dict:
        """Compute losses of the head.

        Args:
            cls_scores (Tensor): Box scores with shape of
                (num_proposals, num_classes)
            bbox_preds (Tensor): Box energies / deltas with shape
                of (num_proposals, num_classes * 4)
            rois (Tensor): shape (N, 4) or (N, 5)
            labels (Tensor): Labels of proposals with shape (num_proposals).
            label_weights (Tensor): Label weights of proposals with shape
                (num_proposals).
            bbox_targets (Tensor): BBox regression targets of each proposal
                weight with shape (num_proposals, num_classes * 4).
            bbox_weights (Tensor): BBox regression loss weights of each
                proposal with shape (num_proposals, num_classes * 4).
            num_pos_pair_samples (int): Number of samples from positive
                pairs.
            reduction_override (str | None): The reduction method used to
                override the original reduction method of the loss. Options
                are "none", "mean" and "sum". Default: None.
            sample_fractions (Sequence[int | float]): Fractions of positive
                samples, negative samples from positive pair, negative
                samples from negative pair. Default: (1, 2, 1).

        Returns:
            dict[str, Tensor]: A dictionary of loss components. May contain
            `loss_cls` and `acc` (when `cls_scores` is given and non-empty)
            and `loss_bbox` (when `bbox_preds` is given).
        """
        losses = dict()
        num_instances = labels.size(0)

        # The fg/bg resampling reads `cls_scores`, so it only runs once we
        # know there are scores to look at.
        if cls_scores is not None and cls_scores.numel() > 0:
            # cls_inds resample the rois to get final classification loss
            topk_inds = self._resample_cls_inds(cls_scores, labels,
                                                num_pos_pair_samples,
                                                sample_fractions)
            losses['loss_cls'] = self.loss_cls(
                cls_scores[topk_inds],
                labels[topk_inds],
                label_weights[topk_inds],
                avg_factor=len(topk_inds),
                reduction_override=reduction_override)
            losses['acc'] = accuracy(cls_scores, labels)

        if bbox_preds is not None:
            bg_class_ind = self.num_classes
            # 0~self.num_classes-1 are FG, self.num_classes is BG
            pos_inds = (labels >= 0) & (labels < bg_class_ind)
            # do not perform bounding box regression for BG anymore.
            if pos_inds.any():
                if self.reg_decoded_bbox:
                    # When the regression loss (e.g. `IouLoss`,
                    # `GIouLoss`, `DIouLoss`) is applied directly on
                    # the decoded bounding boxes, it decodes the
                    # already encoded coordinates to absolute format.
                    bbox_preds = self.bbox_coder.decode(
                        rois[:, 1:], bbox_preds)
                pos_mask = pos_inds.type(torch.bool)
                if self.reg_class_agnostic:
                    pos_bbox_pred = bbox_preds.view(bbox_preds.size(0),
                                                    4)[pos_mask]
                else:
                    pos_bbox_pred = bbox_preds.view(
                        bbox_preds.size(0), -1, 4)[pos_mask,
                                                   labels[pos_mask]]
                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_mask],
                    bbox_weights[pos_mask],
                    avg_factor=num_instances,
                    reduction_override=reduction_override)
            else:
                # no positive sample: emit a zero that stays connected to
                # the graph
                losses['loss_bbox'] = bbox_preds[pos_inds].sum()
        return losses
