Shortcuts

Source code for mmfewshot.detection.models.detectors.meta_rcnn

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional

import torch
from mmcv.runner import auto_fp16
from mmcv.utils import ConfigDict
from mmdet.models.builder import DETECTORS
from torch import Tensor

from .query_support_detector import QuerySupportDetector


@DETECTORS.register_module()
class MetaRCNN(QuerySupportDetector):
    """Implementation of `Meta R-CNN.  <https://arxiv.org/abs/1909.13032>`_.

    Support features are collected batch by batch through
    :func:`forward_model_init`, reduced to one averaged template per class
    by :func:`model_init`, and then handed to the roi head in
    :func:`simple_test`.

    Args:
        backbone (dict): Config of the backbone for query data.
        neck (dict | None): Config of the neck for query data and
            probably for support data. Default: None.
        support_backbone (dict | None): Config of the backbone for
            support data only. If None, support and query data will
            share same backbone. Default: None.
        support_neck (dict | None): Config of the neck for support
            data only. Default: None.
        rpn_head (dict | None): Config of rpn_head. Default: None.
        roi_head (dict | None): Config of roi_head. Default: None.
        train_cfg (dict | None): Training config, forwarded unchanged to
            the parent detector. Default: None.
        test_cfg (dict | None): Testing config, forwarded unchanged to
            the parent detector. Default: None.
        pretrained (str | None): model pretrained path. Default: None.
        init_cfg (dict | list[dict] | None): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 backbone: ConfigDict,
                 neck: Optional[ConfigDict] = None,
                 support_backbone: Optional[ConfigDict] = None,
                 support_neck: Optional[ConfigDict] = None,
                 rpn_head: Optional[ConfigDict] = None,
                 roi_head: Optional[ConfigDict] = None,
                 train_cfg: Optional[ConfigDict] = None,
                 test_cfg: Optional[ConfigDict] = None,
                 pretrained: Optional[str] = None,
                 init_cfg: Optional[ConfigDict] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            support_backbone=support_backbone,
            support_neck=support_neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)

        # True once `model_init` has turned the saved support features
        # into per-class templates; reset by `forward_model_init`.
        self.is_model_init = False
        # Raw support labels / roi features accumulated by
        # :func:`forward_model_init` until :func:`model_init` consumes them.
        self._forward_saved_support_dict = {
            'gt_labels': [],
            'roi_feats': [],
        }
        # Processed support template features used for inference; filled
        # in :func:`model_init`, keyed by class id.
        self.inference_support_dict = {}

    @auto_fp16(apply_to=('img', ))
    def extract_support_feat(self, img: Tensor):
        """Extracting features from support data.

        The support images go through ``self.backbone`` called with
        ``use_meta_conv=True`` and, when ``self.support_neck`` is set,
        through that neck afterwards.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            list[Tensor]: Features of input image, each item with shape
                (N, C, H, W).
        """
        feats = self.backbone(img, use_meta_conv=True)
        if self.support_neck is not None:
            feats = self.support_neck(feats)
        return feats

    def forward_model_init(self,
                           img: Tensor,
                           img_metas: List[Dict],
                           gt_bboxes: Optional[List[Tensor]] = None,
                           gt_labels: Optional[List[Tensor]] = None,
                           **kwargs) -> Dict:
        """Extract and save support features for model initialization.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: `img_shape`, `scale_factor`, `flip`, and may also contain
                `filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
                Accepted for interface compatibility; not read here.
            gt_labels (list[Tensor]): class indices corresponding to each box.
                Must contain exactly one entry per image in `img`.

        Returns:
            dict: A dict contains following keys:

                - `gt_labels` (list[Tensor]): the `gt_labels` argument,
                  returned as given.
                - `roi_feat`: support roi features produced by
                  ``self.roi_head.extract_support_feats``.

        Raises:
            AssertionError: If ``len(gt_labels)`` differs from the batch
                size ``img.size(0)``.
        """
        # `is_model_init` flag will be reset when forward new data.
        self.is_model_init = False
        assert len(gt_labels) == img.size(
            0), 'Support instance have more than two labels'

        support_feats = self.extract_support_feat(img)
        roi_feat = self.roi_head.extract_support_feats(support_feats)

        # Accumulate across calls; `model_init` concatenates and clears.
        saved = self._forward_saved_support_dict
        saved['gt_labels'].extend(gt_labels)
        saved['roi_feats'].extend(roi_feat)
        return {'gt_labels': gt_labels, 'roi_feat': roi_feat}

    def model_init(self) -> None:
        """Process the saved support features for model initialization.

        Concatenates everything stored by :func:`forward_model_init`,
        averages the roi features of each class into a single template
        (leading dimension kept), stores the templates in
        ``self.inference_support_dict``, sets ``self.is_model_init`` and
        finally empties the saved-support buffers.

        Raises:
            RuntimeError: If no support features have been saved, i.e.
                :func:`forward_model_init` has not been called since the
                buffers were last cleared.
        """
        saved = self._forward_saved_support_dict
        # Fail with an actionable message instead of the opaque error
        # `torch.cat` raises on an empty list.
        if not saved['gt_labels'] or not saved['roi_feats']:
            raise RuntimeError(
                'No support features have been saved; call '
                '`forward_model_init` with support data before '
                '`model_init`.')

        gt_labels = torch.cat(saved['gt_labels'])
        roi_feats = torch.cat(saved['roi_feats'])

        self.inference_support_dict.clear()
        for class_id in set(gt_labels.tolist()):
            # Class template: mean over all support instances of the class.
            class_feats = roi_feats[gt_labels == class_id]
            self.inference_support_dict[class_id] = class_feats.mean(
                dim=0, keepdim=True)

        # set the init flag
        self.is_model_init = True
        # reset support features buff
        for saved_items in saved.values():
            saved_items.clear()

    def simple_test(self,
                    img: Tensor,
                    img_metas: List[Dict],
                    proposals: Optional[List[Tensor]] = None,
                    rescale: bool = False):
        """Test without augmentation.

        If the support templates have not been built yet, they are built
        from the saved support features via :func:`model_init` first.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: `img_shape`, `scale_factor`, `flip`, and may also contain
                `filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            proposals (list[Tensor] | None): override rpn proposals with
                custom proposals. Use when `with_rpn` is False. Default: None.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.

        Raises:
            AssertionError: If the detector has no bbox head or more than
                one image is passed.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert len(img_metas) == 1, 'Only support single image inference.'
        if not self.is_model_init:
            # process the saved support features
            self.model_init()

        query_feats = self.extract_feat(img)
        if proposals is None:
            proposal_list = self.rpn_head.simple_test(query_feats, img_metas)
        else:
            proposal_list = proposals

        # The roi head receives its own copy of the class templates.
        return self.roi_head.simple_test(
            query_feats,
            copy.deepcopy(self.inference_support_dict),
            proposal_list,
            img_metas,
            rescale=rescale)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.