• Docs >
  • 模块代码 >
  • mmfewshot.detection.models.roi_heads.bbox_heads.contrastive_bbox_head

mmfewshot.detection.models.roi_heads.bbox_heads.contrastive_bbox_head 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.roi_heads import ConvFCBBoxHead
from torch import Tensor

@HEADS.register_module()
class ContrastiveBBoxHead(ConvFCBBoxHead):
    """BBoxHead for `FSCE <https://arxiv.org/abs/2103.05950>`_.

    On top of :obj:`ConvFCBBoxHead` this head

    - replaces ``fc_cls`` with a bias-free linear layer and scores classes
      with a scaled cosine similarity (both the input features and the
      classifier weights are L2-normalized before the product), and
    - adds an MLP ``contrastive_head`` whose L2-normalized output is fed to
      the contrast loss.

    Args:
        mlp_head_channels (int): Output channels of contrast branch
            mlp. Default: 128.
        with_weight_decay (bool): Whether to decay loss weight. When True,
            the decay rate stored by :meth:`set_decay_rate` is forwarded to
            the contrast loss; otherwise ``None`` is forwarded.
            Default: False.
        loss_contrast (dict): Config of contrast loss. It is deep-copied
            before being built, so the default dict is never mutated.
        scale (int): Scaling factor of `cls_score`. Default: 20.
        learnable_scale (bool): Learnable global scaling factor. When True,
            ``scale`` becomes an ``nn.Parameter`` initialized to ``scale``.
            Default: False.
        eps (float): Constant variable to avoid division by zero.
            Default: 1e-5.
    """

    def __init__(self,
                 mlp_head_channels: int = 128,
                 with_weight_decay: bool = False,
                 loss_contrast: Dict = dict(
                     type='SupervisedContrastiveLoss',
                     temperature=0.1,
                     iou_threshold=0.5,
                     loss_weight=1.0,
                     reweight_type='none'),
                 scale: int = 20,
                 learnable_scale: bool = False,
                 eps: float = 1e-5,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # override the fc_cls in :obj:`ConvFCBBoxHead`: no bias, because the
        # score is a cosine similarity between normalized vectors
        if self.with_cls:
            self.fc_cls = nn.Linear(
                self.cls_last_dim, self.num_classes + 1, bias=False)

        # learnable global scaling factor
        if learnable_scale:
            self.scale = nn.Parameter(torch.ones(1) * scale)
        else:
            self.scale = scale
        self.mlp_head_channels = mlp_head_channels
        self.with_weight_decay = with_weight_decay
        self.eps = eps
        # This will be updated by :class:`ContrastiveLossDecayHook`
        # in the training phase.
        self._decay_rate = 1.0
        # NOTE(review): `gamma` is not read anywhere in this class; it is
        # kept because code outside this file may rely on it.
        self.gamma = 1

        # Projection MLP for the contrast branch. It consumes the shared
        # features, so it assumes their last dimension equals
        # `fc_out_channels` (i.e. shared fcs are configured) -- TODO confirm
        # against the configs in use.
        self.contrastive_head = nn.Sequential(
            nn.Linear(self.fc_out_channels, self.fc_out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(self.fc_out_channels, mlp_head_channels))
        self.contrast_loss = build_loss(copy.deepcopy(loss_contrast))

    def forward(
            self, x: Tensor
    ) -> Tuple[Optional[Tensor], Optional[Tensor], Tensor]:
        """Forward function.

        Args:
            x (Tensor): Shape of (num_proposals, C, H, W).

        Returns:
            tuple:
                cls_score (Tensor | None): Scaled cosine-similarity cls
                    scores, has shape (num_proposals, num_classes + 1).
                    ``None`` when ``with_cls`` is False.
                bbox_pred (Tensor | None): Box energies / deltas produced
                    by the inherited ``fc_reg`` layer. ``None`` when
                    ``with_reg`` is False.
                contrast_feat (Tensor): L2-normalized box features for
                    contrast loss, has shape
                    (num_proposals, mlp_head_channels).
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        # separate branches, all starting from the shared features
        x_cls = x
        x_reg = x
        x_contra = x

        for conv in self.cls_convs:
            x_cls = conv(x_cls)
        if x_cls.dim() > 2:
            if self.with_avg_pool:
                x_cls = self.avg_pool(x_cls)
            x_cls = x_cls.flatten(1)
        for fc in self.cls_fcs:
            x_cls = self.relu(fc(x_cls))

        for conv in self.reg_convs:
            x_reg = conv(x_reg)
        if x_reg.dim() > 2:
            if self.with_avg_pool:
                x_reg = self.avg_pool(x_reg)
            x_reg = x_reg.flatten(1)
        for fc in self.reg_fcs:
            x_reg = self.relu(fc(x_reg))

        # reg branch
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None

        # cls branch: only touch `fc_cls` when it is actually configured
        cls_score = None
        if self.with_cls:
            if x_cls.dim() > 2:
                x_cls = torch.flatten(x_cls, start_dim=1)

            # normalize the cls features along the channel dimension; the
            # norm is broadcast against `x_cls` itself (not the shared `x`,
            # whose shape may differ after the cls-specific layers)
            x_norm = torch.norm(x_cls, p=2, dim=1, keepdim=True)
            x_cls_normalized = x_cls.div(x_norm + self.eps)

            # normalize the classifier weights in place, row by row; done
            # under no_grad so the re-normalization itself is not tracked
            with torch.no_grad():
                weight_norm = torch.norm(
                    self.fc_cls.weight, p=2, dim=1, keepdim=True)
                self.fc_cls.weight.div_(weight_norm + self.eps)

            # calculate and scale cls_score
            cls_score = self.scale * self.fc_cls(x_cls_normalized)

        # contrastive branch
        contrast_feat = self.contrastive_head(x_contra)
        contrast_feat = F.normalize(contrast_feat, dim=1)

        return cls_score, bbox_pred, contrast_feat

    def set_decay_rate(self, decay_rate: float) -> None:
        """Contrast loss weight decay hook will set the `decay_rate` according
        to iterations.

        The stored value is only forwarded to the contrast loss when
        ``with_weight_decay`` is True.

        Args:
            decay_rate (float): Decay rate for weight decay.
        """
        self._decay_rate = decay_rate

    # `apply_to` must be an iterable of argument names; a bare string would
    # turn the membership check into a substring match.
    @force_fp32(apply_to=('contrast_feat', ))
    def loss_contrast(self,
                      contrast_feat: Tensor,
                      proposal_ious: Tensor,
                      labels: Tensor,
                      reduction_override: Optional[str] = None) -> Dict:
        """Loss for contrast.

        Args:
            contrast_feat (tensor): BBox features with shape (N, C)
                used for contrast loss.
            proposal_ious (tensor): IoU between proposal and ground truth
                corresponding to each BBox features with shape (N).
            labels (tensor): Labels for each BBox features with shape (N).
            reduction_override (str | None): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum". Default: None.

        Returns:
            Dict: The calculated loss, stored under the key
            ``'loss_contrast'``.
        """
        # Only pass a decay rate to the loss when weight decay is enabled.
        decay_rate = self._decay_rate if self.with_weight_decay else None

        losses = dict()
        losses['loss_contrast'] = self.contrast_loss(
            contrast_feat,
            labels,
            proposal_ious,
            decay_rate=decay_rate,
            reduction_override=reduction_override)
        return losses
