• Docs >
  • Module code >
  • mmfewshot.detection.models.roi_heads.bbox_heads.cosine_sim_bbox_head
Shortcuts

Source code for mmfewshot.detection.models.roi_heads.bbox_heads.cosine_sim_bbox_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmdet.models.builder import HEADS
from mmdet.models.roi_heads import ConvFCBBoxHead
from torch import Tensor


@HEADS.register_module()
class CosineSimBBoxHead(ConvFCBBoxHead):
    """BBoxHead for `TFA <https://arxiv.org/abs/2003.06957>`_.

    The code is modified from the official implementation
    https://github.com/ucbdrive/few-shot-object-detection/

    The classification score is computed from the L2-normalized
    classification feature and the L2-normalized rows of ``fc_cls.weight``
    (a bias-free linear layer), multiplied by ``scale``.

    Args:
        scale (int): Scaling factor of `cls_score`. Default: 20.
        learnable_scale (bool): Learnable global scaling factor.
            Default: False.
        eps (float): Constant variable to avoid division by zero.
            Default: 1e-5.
    """

    def __init__(self,
                 scale: int = 20,
                 learnable_scale: bool = False,
                 eps: float = 1e-5,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # override the fc_cls in :obj:`ConvFCBBoxHead`; no bias so that the
        # output depends only on the angle between feature and weight
        if self.with_cls:
            self.fc_cls = nn.Linear(
                self.cls_last_dim, self.num_classes + 1, bias=False)

        # learnable global scaling factor
        if learnable_scale:
            self.scale = nn.Parameter(torch.ones(1) * scale)
        else:
            self.scale = scale
        self.eps = eps

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Note:
            Every call re-normalizes ``self.fc_cls.weight`` in place (under
            ``torch.no_grad()``), so the stored weights are modified as a
            side effect of the forward pass.

        Args:
            x (Tensor): Shape of (num_proposals, C, H, W).

        Returns:
            tuple:
                cls_score (Tensor | None): Cls scores, has shape
                    (num_proposals, num_classes + 1). None when
                    ``self.with_cls`` is False.
                bbox_pred (Tensor | None): Box energies / deltas produced
                    by ``self.fc_reg`` (presumably (num_proposals, 4) for
                    class-agnostic regression -- confirm against the
                    parent head). None when ``self.with_reg`` is False.
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)

            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        # separate branches
        x_cls = x
        x_reg = x

        for conv in self.cls_convs:
            x_cls = conv(x_cls)
        if x_cls.dim() > 2:
            if self.with_avg_pool:
                x_cls = self.avg_pool(x_cls)
            x_cls = x_cls.flatten(1)
        for fc in self.cls_fcs:
            x_cls = self.relu(fc(x_cls))

        for conv in self.reg_convs:
            x_reg = conv(x_reg)
        if x_reg.dim() > 2:
            if self.with_avg_pool:
                x_reg = self.avg_pool(x_reg)
            x_reg = x_reg.flatten(1)
        for fc in self.reg_fcs:
            x_reg = self.relu(fc(x_reg))

        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None

        # Without a classification branch there is no ``fc_cls`` override to
        # normalize, so skip the cosine-similarity path entirely.
        if not self.with_cls:
            return None, bbox_pred

        if x_cls.dim() > 2:
            x_cls = torch.flatten(x_cls, start_dim=1)

        # normalize the classification feature along the `input_size`
        # dimension; keepdim=True lets the division broadcast over x_cls
        # itself rather than relying on the shape of the shared feature `x`
        x_norm = torch.norm(x_cls, p=2, dim=1, keepdim=True)
        x_cls_normalized = x_cls.div(x_norm + self.eps)

        # normalize each weight row in place; done under no_grad so the
        # re-normalization itself is not tracked by autograd
        with torch.no_grad():
            weight_norm = torch.norm(
                self.fc_cls.weight, p=2, dim=1, keepdim=True)
            self.fc_cls.weight.div_(weight_norm + self.eps)

        # calculate and scale cls_score
        cls_score = self.scale * self.fc_cls(x_cls_normalized)
        return cls_score, bbox_pred
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.