Source code for mmfewshot.detection.models.detectors.meta_rcnn
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional
import torch
from mmcv.runner import auto_fp16
from mmcv.utils import ConfigDict
from mmdet.models.builder import DETECTORS
from torch import Tensor
from .query_support_detector import QuerySupportDetector
@DETECTORS.register_module()
class MetaRCNN(QuerySupportDetector):
"""Implementation of `Meta R-CNN. <https://arxiv.org/abs/1909.13032>`_.
Args:
backbone (dict): Config of the backbone for query data.
neck (dict | None): Config of the neck for query data and
probably for support data. Default: None.
support_backbone (dict | None): Config of the backbone for
support data only. If None, support and query data will
share the same backbone. Default: None.
support_neck (dict | None): Config of the neck for support
data only. Default: None.
rpn_head (dict | None): Config of rpn_head. Default: None.
roi_head (dict | None): Config of roi_head. Default: None.
train_cfg (dict | None): Training config of the detector. Default: None.
test_cfg (dict | None): Testing config of the detector. Default: None.
pretrained (str | None): Path to the pretrained model. Default: None.
init_cfg (dict | list[dict] | None): Initialization config dict.
Default: None.
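Example:
    A minimal sketch of how the detector is typically assembled from a
    config dict. The sub-module configs below are abbreviated
    placeholders (their required fields are omitted); a complete
    Meta R-CNN config from the repository should be used in practice::

        >>> import mmfewshot.detection.models  # registers MetaRCNN and its sub-modules
        >>> from mmdet.models import build_detector
        >>> cfg = dict(
        ...     type='MetaRCNN',
        ...     backbone=dict(type='ResNetWithMetaConv', depth=50),
        ...     rpn_head=dict(type='RPNHead'),
        ...     roi_head=dict(type='MetaRCNNRoIHead'))
        >>> detector = build_detector(cfg)  # works once the configs are complete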
"""
def __init__(self,
backbone: ConfigDict,
neck: Optional[ConfigDict] = None,
support_backbone: Optional[ConfigDict] = None,
support_neck: Optional[ConfigDict] = None,
rpn_head: Optional[ConfigDict] = None,
roi_head: Optional[ConfigDict] = None,
train_cfg: Optional[ConfigDict] = None,
test_cfg: Optional[ConfigDict] = None,
pretrained: Optional[str] = None,
init_cfg: Optional[ConfigDict] = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
support_backbone=support_backbone,
support_neck=support_neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
self.is_model_init = False
# save support template features for model initialization,
# `_forward_saved_support_dict` is used in :func:`forward_model_init`.
self._forward_saved_support_dict = {
'gt_labels': [],
'roi_feats': [],
}
# save processed support template features for inference;
# the processed support template features are generated
# in :func:`model_init`
self.inference_support_dict = {}
@auto_fp16(apply_to=('img', ))
def extract_support_feat(self, img: Tensor) -> List[Tensor]:
"""Extracting features from support data.
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
Returns:
list[Tensor]: Features of input image, each item with shape
(N, C, H, W).
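Example:
    A usage sketch, assuming ``detector`` is an already built MetaRCNN
    and ``support_imgs`` is a preprocessed support batch of shape
    (N, C, H, W)::

        >>> feats = detector.extract_support_feat(support_imgs)
        >>> # feats holds one feature map per output level, each (N, C, H, W)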
"""
feats = self.backbone(img, use_meta_conv=True)
if self.support_neck is not None:
feats = self.support_neck(feats)
return feats
def forward_model_init(self,
img: Tensor,
img_metas: List[Dict],
gt_bboxes: Optional[List[Tensor]] = None,
gt_labels: Optional[List[Tensor]] = None,
**kwargs):
"""extract and save support features for model initialization.
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): list of image info dict where each dict
has: `img_shape`, `scale_factor`, `flip`, and may also contain
`filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box.
Returns:
dict: A dict containing the following keys:
- `gt_labels` (list[Tensor]): class indices corresponding to each
feature.
- `roi_feat` (list[Tensor]): RoI features of the res5 layer.
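Example:
    A sketch of the support-initialization pass, assuming ``detector``
    is a built MetaRCNN and ``support_data`` is one batch from a
    support dataloader carrying the keys listed above::

        >>> out = detector.forward_model_init(
        ...     img=support_data['img'],
        ...     img_metas=support_data['img_metas'],
        ...     gt_bboxes=support_data['gt_bboxes'],
        ...     gt_labels=support_data['gt_labels'])
        >>> sorted(out.keys())
        ['gt_labels', 'roi_feat']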
"""
# the `is_model_init` flag will be reset whenever new support data are forwarded.
self.is_model_init = False
assert len(gt_labels) == img.size(0), \
    'Number of support labels does not match the number of support images.'
feats = self.extract_support_feat(img)
roi_feat = self.roi_head.extract_support_feats(feats)
self._forward_saved_support_dict['gt_labels'].extend(gt_labels)
self._forward_saved_support_dict['roi_feats'].extend(roi_feat)
return {'gt_labels': gt_labels, 'roi_feat': roi_feat}
def model_init(self):
"""process the saved support features for model initialization."""
gt_labels = torch.cat(self._forward_saved_support_dict['gt_labels'])
roi_feats = torch.cat(self._forward_saved_support_dict['roi_feats'])
class_ids = set(gt_labels.data.tolist())
self.inference_support_dict.clear()
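# average the saved RoI features of each class into a single
# class-wise support prototype that is reused at inference time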
for class_id in class_ids:
self.inference_support_dict[class_id] = roi_feats[
gt_labels == class_id].mean([0], True)
# set the init flag
self.is_model_init = True
# reset the support features buffer
for k in self._forward_saved_support_dict.keys():
self._forward_saved_support_dict[k].clear()
def simple_test(self,
img: Tensor,
img_metas: List[Dict],
proposals: Optional[List[Tensor]] = None,
rescale: bool = False):
"""Test without augmentation.
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): list of image info dict where each dict
has: `img_shape`, `scale_factor`, `flip`, and may also contain
`filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
proposals (list[Tensor] | None): override rpn proposals with
custom proposals. Use when `with_rpn` is False. Default: None.
rescale (bool): If True, return boxes in original image space.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
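Example:
    A sketch of the full test-time protocol, assuming ``detector`` is a
    built MetaRCNN, ``support_loader`` iterates over the support set
    and ``query_data`` holds one preprocessed query image; the names
    are illustrative only::

        >>> for support_data in support_loader:
        ...     detector.forward_model_init(**support_data)
        >>> # `model_init` is triggered automatically on the first query
        >>> # batch after new support features have been saved.
        >>> results = detector.simple_test(
        ...     query_data['img'], query_data['img_metas'])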
"""
assert self.with_bbox, 'Bbox head must be implemented.'
assert len(img_metas) == 1, 'Only support single image inference.'
if not self.is_model_init:
# process the saved support features
self.model_init()
query_feats = self.extract_feat(img)
if proposals is None:
proposal_list = self.rpn_head.simple_test(query_feats, img_metas)
else:
proposal_list = proposals
return self.roi_head.simple_test(
query_feats,
copy.deepcopy(self.inference_support_dict),
proposal_list,
img_metas,
rescale=rescale)