Shortcuts

Source code for mmfewshot.detection.models.detectors.attention_rpn_detector

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import numpy as np
import torch
from mmcv.runner import auto_fp16
from mmcv.utils import ConfigDict
from mmdet.core import bbox2roi
from mmdet.models.builder import DETECTORS
from torch import Tensor

from .query_support_detector import QuerySupportDetector


@DETECTORS.register_module()
class AttentionRPNDetector(QuerySupportDetector):
    """Implementation of `AttentionRPN <https://arxiv.org/abs/1908.01998>`_.

    Args:
        backbone (dict): Config of the backbone for query data.
        neck (dict | None): Config of the neck for query data and
            probably for support data. Default: None.
        support_backbone (dict | None): Config of the backbone for
            support data only. If None, support and query data will
            share same backbone. Default: None.
        support_neck (dict | None): Config of the neck for support
            data only. Default: None.
        rpn_head (dict | None): Config of rpn_head. Default: None.
        roi_head (dict | None): Config of roi_head. Default: None.
        train_cfg (dict | None): Training config. Default: None.
        test_cfg (dict | None): Testing config. Default: None.
        pretrained (str | None): model pretrained path. Default: None.
        init_cfg (dict | list[dict] | None): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 backbone: ConfigDict,
                 neck: Optional[ConfigDict] = None,
                 support_backbone: Optional[ConfigDict] = None,
                 support_neck: Optional[ConfigDict] = None,
                 rpn_head: Optional[ConfigDict] = None,
                 roi_head: Optional[ConfigDict] = None,
                 train_cfg: Optional[ConfigDict] = None,
                 test_cfg: Optional[ConfigDict] = None,
                 pretrained: Optional[str] = None,
                 init_cfg: Optional[ConfigDict] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            support_backbone=support_backbone,
            support_neck=support_neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        # True once `model_init` has turned the saved support features
        # into `inference_support_dict`; reset by `forward_model_init`.
        self.is_model_init = False
        # Support template features accumulated by
        # :func:`forward_model_init` and consumed by :func:`model_init`.
        self._forward_saved_support_dict = {
            'gt_labels': [],
            'res4_roi_feats': [],
            'res5_roi_feats': []
        }
        # Per-class processed support template features used at inference,
        # built in :func:`model_init` and keyed by integer class id.
        self.inference_support_dict = {}

    @auto_fp16(apply_to=('img', ))
    def extract_support_feat(self, img: Tensor) -> List[Tensor]:
        """Extract features of support data.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            list[Tensor]: Features of support images, each item with shape
                (N, C, H, W).
        """
        feats = self.support_backbone(img)
        if self.support_neck is not None:
            feats = self.support_neck(feats)
        return feats

    def forward_model_init(self,
                           img: Tensor,
                           img_metas: List[Dict],
                           gt_bboxes: Optional[List[Tensor]] = None,
                           gt_labels: Optional[List[Tensor]] = None,
                           **kwargs) -> Dict:
        """Extract and save support features for model initialization.

        The extracted features are appended to an internal buffer; call
        :func:`model_init` afterwards (or let :func:`simple_test` do it)
        to turn them into per-class inference templates.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: `img_shape`, `scale_factor`, `flip`, and may also contain
                `filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.

        Returns:
            dict: A dict contains following keys:

                - `gt_labels` (list[Tensor]): class indices corresponding
                  to each feature.
                - `res4_roi_feats` (Tensor): roi features of res4 layer.
                - `res5_roi_feats` (Tensor): roi features of res5 layer.
        """
        # New support features invalidate any previously processed
        # templates, so `model_init` has to run again before inference.
        self.is_model_init = False
        assert gt_bboxes is not None and gt_labels is not None, \
            'forward support template require gt_bboxes and gt_labels.'
        # NOTE(review): this only checks there is one label tensor per
        # image; it does not check how many labels each tensor holds.
        assert len(gt_labels) == img.size(0), \
            'Support instance have more than two labels'
        feats = self.extract_support_feat(img)
        rois = bbox2roi(list(gt_bboxes))
        # res4 features feed the attention RPN head, res5 features feed
        # the roi head.
        res4_roi_feat = self.rpn_head.extract_roi_feat(feats, rois)
        res5_roi_feat = self.roi_head.extract_roi_feat(feats, rois)
        saved = self._forward_saved_support_dict
        saved['gt_labels'].extend(gt_labels)
        saved['res4_roi_feats'].append(res4_roi_feat)
        saved['res5_roi_feats'].append(res5_roi_feat)
        return {
            'gt_labels': gt_labels,
            'res4_roi_feats': res4_roi_feat,
            'res5_roi_feats': res5_roi_feat
        }

    def model_init(self) -> None:
        """Process the saved support features for model initialization.

        Concatenates every support feature saved by
        :func:`forward_model_init`, averages them per class into
        ``inference_support_dict``, sets ``is_model_init`` and empties the
        saved-feature buffer.

        Raises:
            RuntimeError: If no support features have been saved. In that
                case the existing ``inference_support_dict`` is left intact.
        """
        saved = self._forward_saved_support_dict
        # Validate before touching `inference_support_dict`, so a misplaced
        # call cannot wipe templates that were already computed.
        if not saved['gt_labels']:
            raise RuntimeError(
                'No saved support features found; call '
                '`forward_model_init` before `model_init`.')
        gt_labels = torch.cat(saved['gt_labels'])
        # used for attention rpn head
        res4_roi_feats = torch.cat(saved['res4_roi_feats'])
        # used for multi relation head
        res5_roi_feats = torch.cat(saved['res5_roi_feats'])

        self.inference_support_dict.clear()
        for class_id in set(gt_labels.tolist()):
            class_mask = gt_labels == class_id
            self.inference_support_dict[class_id] = {
                # average over instances and spatial dims (0, 2, 3),
                # keeping dims so the result stays 4-D
                'res4_roi_feats':
                res4_roi_feats[class_mask].mean([0, 2, 3], True),
                # average over instances only, keeping dim 0
                'res5_roi_feats':
                res5_roi_feats[class_mask].mean([0], True)
            }
        # set the init flag
        self.is_model_init = True
        # the buffered raw features are no longer needed
        for saved_list in saved.values():
            saved_list.clear()

    def simple_test(self,
                    img: Tensor,
                    img_metas: List[Dict],
                    proposals: Optional[List[Tensor]] = None,
                    rescale: bool = False) -> List[List[np.ndarray]]:
        """Test without augmentation.

        Runs the RPN and roi head once per class template stored in
        ``inference_support_dict``, building the templates first via
        :func:`model_init` if needed.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: `img_shape`, `scale_factor`, `flip`, and may also contain
                `filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            proposals (list[Tensor] | None): override rpn proposals with
                custom proposals. Use when `with_rpn` is False. Default: None.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class, ordered by ascending class id.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert len(img_metas) == 1, 'Only support single image inference.'
        if not self.inference_support_dict or not self.is_model_init:
            # process the saved support features
            self.model_init()

        results_dict = {}
        query_feats = self.extract_feat(img)
        for class_id, support_feats in self.inference_support_dict.items():
            if proposals is None:
                proposal_list = self.rpn_head.simple_test(
                    query_feats, support_feats['res4_roi_feats'], img_metas)
            else:
                proposal_list = proposals
            results_dict[class_id] = self.roi_head.simple_test(
                query_feats,
                support_feats['res5_roi_feats'],
                proposal_list,
                img_metas,
                rescale=rescale)
        # NOTE(review): `[0][0]` takes the first image's first entry of each
        # per-class result; presumably the roi head predicts one class per
        # support template — TODO confirm. Classes with empty results are
        # skipped, so inner-list positions need not match class ids.
        results = [
            results_dict[class_id][0][0]
            for class_id in sorted(results_dict)
            if len(results_dict[class_id])
        ]
        return [results]
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.