Source code for mmfewshot.detection.models.roi_heads.bbox_heads.cosine_sim_bbox_head
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn as nn
from mmdet.models.builder import HEADS
from mmdet.models.roi_heads import ConvFCBBoxHead
from torch import Tensor
@HEADS.register_module()
class CosineSimBBoxHead(ConvFCBBoxHead):
    """BBoxHead for `TFA <https://arxiv.org/abs/2003.06957>`_.

    The classification branch scores each RoI by the cosine similarity
    between its L2-normalized feature vector and the L2-normalized rows of
    ``fc_cls``. The result is multiplied by a global scaling factor, which
    can optionally be learned.

    The code is modified from the official implementation
    https://github.com/ucbdrive/few-shot-object-detection/

    Args:
        scale (int): Scaling factor of `cls_score`. Default: 20.
        learnable_scale (bool): Learnable global scaling factor.
            Default: False.
        eps (float): Constant variable to avoid division by zero.
            Default: 1e-5.
    """

    def __init__(self,
                 scale: int = 20,
                 learnable_scale: bool = False,
                 eps: float = 1e-5,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Override the fc_cls built by :obj:`ConvFCBBoxHead`. Only the
        # normalized weight rows take part in the cosine similarity, so the
        # layer has no bias.
        if self.with_cls:
            self.fc_cls = nn.Linear(
                self.cls_last_dim, self.num_classes + 1, bias=False)
        # global scaling factor, optionally learnable
        if learnable_scale:
            self.scale = nn.Parameter(torch.ones(1) * scale)
        else:
            self.scale = scale
        self.eps = eps

    def _forward_branch(self, x: Tensor, convs, fcs) -> Tensor:
        """Run one separate (cls or reg) branch on the shared feature.

        Args:
            x (Tensor): Output of the shared part of the head.
            convs: Conv layers of the branch, applied in order.
            fcs: FC layers of the branch, each followed by ``self.relu``.

        Returns:
            Tensor: Branch feature, flattened to 2-D if it was still
            spatial after the conv layers.
        """
        for conv in convs:
            x = conv(x)
        if x.dim() > 2:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
        for fc in fcs:
            x = self.relu(fc(x))
        return x

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            x (Tensor): Shape of (num_proposals, C, H, W).

        Returns:
            tuple:
                cls_score (Tensor | None): Scaled cosine-similarity scores,
                    with shape (num_proposals, num_classes + 1). It is None
                    when the head has no classification branch.
                bbox_pred (Tensor | None): Box energies / deltas produced by
                    ``fc_reg`` (built by the parent class). The last
                    dimension is whatever ``fc_reg`` outputs (presumably 4
                    or 4 * num_classes; confirm against the config). It is
                    None when the head has no regression branch.
        """
        # shared part
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)
        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)
            x = x.flatten(1)
            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        # separate branches; both come out at most 2-D
        x_cls = self._forward_branch(x, self.cls_convs, self.cls_fcs)
        x_reg = self._forward_branch(x, self.reg_convs, self.reg_fcs)

        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None

        # Without a cls branch, fc_cls is not (re)defined in __init__, so the
        # weight normalization below must be skipped.
        if not self.with_cls:
            return None, bbox_pred

        # L2-normalize each feature vector. keepdim=True yields a
        # (num_proposals, 1) norm that broadcasts over x_cls itself. The
        # shared feature `x` may have a different shape once the cls branch
        # layers have been applied.
        x_norm = torch.norm(x_cls, p=2, dim=1, keepdim=True)
        x_cls_normalized = x_cls.div(x_norm + self.eps)

        # Normalize each row of the classifier weight in place. In-place
        # modification of a leaf Parameter must happen outside autograd.
        with torch.no_grad():
            weight_norm = torch.norm(
                self.fc_cls.weight, p=2, dim=1, keepdim=True)
            self.fc_cls.weight.div_(weight_norm + self.eps)

        # cosine similarity, scaled by the global factor
        cls_score = self.scale * self.fc_cls(x_cls_normalized)
        return cls_score, bbox_pred