Shortcuts

Source code for mmfewshot.detection.models.detectors.query_support_detector

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from abc import abstractmethod
from typing import Dict, List, Optional, Union

from mmcv.runner import auto_fp16
from mmcv.utils import ConfigDict
from mmdet.models.builder import (DETECTORS, build_backbone, build_head,
                                  build_neck)
from mmdet.models.detectors import BaseDetector
from torch import Tensor
from typing_extensions import Literal


@DETECTORS.register_module()
class QuerySupportDetector(BaseDetector):
    """Base class for two-stage detectors in query-support fashion.

    Query-support detectors typically consist of a region proposal network
    and a task-specific regression head. There are two pipelines for query
    and support data respectively.

    Args:
        backbone (dict): Config of the backbone for query data.
        neck (dict | None): Config of the neck for query data and
            probably for support data. Default: None.
        support_backbone (dict | None): Config of the backbone for
            support data only. If None, support and query data will
            share same backbone. Default: None.
        support_neck (dict | None): Config of the neck for support
            data only. Default: None.
        rpn_head (dict | None): Config of rpn_head. If it contains a
            non-None ``aggregation_layer`` entry, the RPN also consumes
            support features. Default: None.
        roi_head (dict): Config of roi_head. Required; passing None
            triggers an ``AssertionError``. Default: None.
        train_cfg (dict | None): Training config. Its ``rpn`` / ``rcnn``
            fields are forwarded to the rpn / roi heads. Default: None.
        test_cfg (dict | None): Testing config. Its ``rpn`` / ``rcnn``
            fields are forwarded to the rpn / roi heads. Default: None.
        pretrained (str | None): model pretrained path. Default: None.
        init_cfg (dict | list[dict] | None): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 backbone: ConfigDict,
                 neck: Optional[ConfigDict] = None,
                 support_backbone: Optional[ConfigDict] = None,
                 support_neck: Optional[ConfigDict] = None,
                 rpn_head: Optional[ConfigDict] = None,
                 roi_head: Optional[ConfigDict] = None,
                 train_cfg: Optional[ConfigDict] = None,
                 test_cfg: Optional[ConfigDict] = None,
                 pretrained: Optional[str] = None,
                 init_cfg: Optional[ConfigDict] = None) -> None:
        super().__init__(init_cfg)
        backbone.pretrained = pretrained
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck) if neck is not None else None
        # if `support_backbone` is None, then support and query pipeline will
        # share same backbone.
        self.support_backbone = build_backbone(
            support_backbone
        ) if support_backbone is not None else self.backbone
        # support neck only forward support data.
        self.support_neck = build_neck(
            support_neck) if support_neck is not None else None
        assert roi_head is not None, 'missing config of roi_head'
        # when rpn with aggregation neck, the input of rpn will consist of
        # query and support data. otherwise the input of rpn only
        # has query data.
        self.with_rpn = False
        self.rpn_with_support = False
        if rpn_head is not None:
            self.with_rpn = True
            if rpn_head.get('aggregation_layer', None) is not None:
                self.rpn_with_support = True
            # Guard both cfgs symmetrically so building without a
            # train_cfg or test_cfg does not fail on `None.rpn`.
            rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
            rpn_test_cfg = test_cfg.rpn if test_cfg is not None else None
            rpn_head_ = copy.deepcopy(rpn_head)
            rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=rpn_test_cfg)
            self.rpn_head = build_head(rpn_head_)

        # Work on a copy so the caller's roi_head config is not mutated,
        # matching how `rpn_head` is handled above.
        roi_head_ = copy.deepcopy(roi_head)
        rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
        rcnn_test_cfg = test_cfg.rcnn if test_cfg is not None else None
        roi_head_.update(train_cfg=rcnn_train_cfg)
        roi_head_.update(test_cfg=rcnn_test_cfg)
        roi_head_.pretrained = pretrained
        self.roi_head = build_head(roi_head_)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    @auto_fp16(apply_to=('img', ))
    def extract_query_feat(self, img: Tensor) -> List[Tensor]:
        """Extract features of query data.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            list[Tensor]: Features of query images, each item with shape
                (N, C, H, W).
        """
        feats = self.backbone(img)
        if self.with_neck:
            feats = self.neck(feats)
        return feats

    def extract_feat(self, img: Tensor) -> List[Tensor]:
        """Extract features of query data.

        Delegates to :func:`extract_query_feat`.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.

        Returns:
            list[Tensor]: Features of query images.
        """
        return self.extract_query_feat(img)

    @abstractmethod
    def extract_support_feat(self, img: Tensor):
        """Extract features of support data."""
        raise NotImplementedError

    @auto_fp16(apply_to=('img', ))
    def forward(self,
                query_data: Optional[Dict] = None,
                support_data: Optional[Dict] = None,
                img: Optional[List[Tensor]] = None,
                img_metas: Optional[List[Dict]] = None,
                mode: Literal['train', 'model_init', 'test'] = 'train',
                **kwargs) -> Dict:
        """Calls one of (:func:`forward_train`, :func:`forward_test` and
        :func:`forward_model_init`) according to the `mode`. The inputs
        of forward function would change with the `mode`.

        - When `mode` is 'train', the input will be query and support data
          for training.

        - When `mode` is 'model_init', the input will be support template
          data at least including (img, img_metas).

        - When `mode` is 'test', the input will be test data at least
          including (img, img_metas).

        Args:
            query_data (dict): Used for :func:`forward_train`. Dict of
                query data and data info where each dict has: `img`,
                `img_metas`, `gt_bboxes`, `gt_labels`, `gt_bboxes_ignore`.
                Default: None.
            support_data (dict): Used for :func:`forward_train`. Dict of
                support data and data info dict where each dict has: `img`,
                `img_metas`, `gt_bboxes`, `gt_labels`, `gt_bboxes_ignore`.
                Default: None.
            img (list[Tensor]): Used for :func:`forward_test` or
                :func:`forward_model_init`. List of tensors of shape
                (1, C, H, W). Typically these should be mean centered
                and std scaled. Default: None.
            img_metas (list[dict]): Used for :func:`forward_test` or
                :func:`forward_model_init`. List of image info dict
                where each dict has: `img_shape`, `scale_factor`, `flip`,
                and may also contain `filename`, `ori_shape`, `pad_shape`,
                and `img_norm_cfg`. For details on the values of these keys,
                see :class:`mmdet.datasets.pipelines.Collect`.
                Default: None.
            mode (str): Indicate which function to call. Options are 'train',
                'model_init' and 'test'. Default: 'train'.

        Raises:
            ValueError: If `mode` is not one of the supported options.
        """
        if mode == 'train':
            return self.forward_train(query_data, support_data, **kwargs)
        elif mode == 'model_init':
            return self.forward_model_init(img, img_metas, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, img_metas, **kwargs)
        else:
            raise ValueError(
                f'invalid forward mode {mode}, '
                f'only support `train`, `model_init` and `test` now')

    def train_step(self, data: Dict, optimizer: Union[object, Dict]) -> Dict:
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN. For most of query-support detectors, the
        batch size denote the batch size of query data.

        Args:
            data (dict): The output of dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                runner is passed to ``train_step()``. This argument is unused
                and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.

                - ``loss`` is a tensor for back propagation, which can be a
                  weighted sum of multiple losses.
                - ``log_vars`` contains all the variables to be sent to the
                  logger.
                - ``num_samples`` indicates the batch size (when the model is
                  DDP, it means the batch size on each GPU), which is used for
                  averaging the logs.
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)

        # For most of query-support detectors, the batch size denote the
        # batch size of query data.
        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=len(data['query_data']['img_metas']))

        return outputs

    def val_step(self,
                 data: Dict,
                 optimizer: Optional[Union[object, Dict]] = None) -> Dict:
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but used
        during val epochs. Note that the evaluation after training epochs is
        not implemented with this method, but an evaluation hook.
        """
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)

        # For most of query-support detectors, the batch size denote the
        # batch size of query data.
        outputs = dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=len(data['query_data']['img_metas']))

        return outputs

    def forward_train(self,
                      query_data: Dict,
                      support_data: Dict,
                      proposals: Optional[List] = None,
                      **kwargs) -> Dict:
        """Forward function for training.

        Args:
            query_data (dict): In most cases, dict of query data contains:
                `img`, `img_metas`, `gt_bboxes`, `gt_labels`,
                `gt_bboxes_ignore`.
            support_data (dict):  In most cases, dict of support data contains:
                `img`, `img_metas`, `gt_bboxes`, `gt_labels`,
                `gt_bboxes_ignore`.
            proposals (list): Override rpn proposals with custom proposals.
                Use when `with_rpn` is False. Default: None.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        query_img = query_data['img']
        support_img = support_data['img']
        query_feats = self.extract_query_feat(query_img)
        support_feats = self.extract_support_feat(support_img)

        losses = dict()

        # RPN forward and loss
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            if self.rpn_with_support:
                # RPN aggregates query and support features; query labels
                # are deliberately passed as None here.
                rpn_losses, proposal_list = self.rpn_head.forward_train(
                    query_feats,
                    support_feats,
                    query_img_metas=query_data['img_metas'],
                    query_gt_bboxes=query_data['gt_bboxes'],
                    query_gt_labels=None,
                    query_gt_bboxes_ignore=query_data.get(
                        'gt_bboxes_ignore', None),
                    support_img_metas=support_data['img_metas'],
                    support_gt_bboxes=support_data['gt_bboxes'],
                    support_gt_labels=support_data['gt_labels'],
                    support_gt_bboxes_ignore=support_data.get(
                        'gt_bboxes_ignore', None),
                    proposal_cfg=proposal_cfg)
            else:
                # Query-only RPN; inputs are deep-copied so the RPN cannot
                # alter the annotations later consumed by the roi head.
                rpn_losses, proposal_list = self.rpn_head.forward_train(
                    query_feats,
                    copy.deepcopy(query_data['img_metas']),
                    copy.deepcopy(query_data['gt_bboxes']),
                    gt_labels=None,
                    gt_bboxes_ignore=copy.deepcopy(
                        query_data.get('gt_bboxes_ignore', None)),
                    proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = proposals

        roi_losses = self.roi_head.forward_train(
            query_feats,
            support_feats,
            proposals=proposal_list,
            query_img_metas=query_data['img_metas'],
            query_gt_bboxes=query_data['gt_bboxes'],
            query_gt_labels=query_data['gt_labels'],
            query_gt_bboxes_ignore=query_data.get('gt_bboxes_ignore', None),
            support_img_metas=support_data['img_metas'],
            support_gt_bboxes=support_data['gt_bboxes'],
            support_gt_labels=support_data['gt_labels'],
            support_gt_bboxes_ignore=support_data.get('gt_bboxes_ignore',
                                                      None),
            **kwargs)
        losses.update(roi_losses)

        return losses

    def simple_test(self,
                    img: Tensor,
                    img_metas: List[Dict],
                    proposals: Optional[List[Tensor]] = None,
                    rescale: bool = False):
        """Test without augmentation."""
        raise NotImplementedError

    def aug_test(self, **kwargs):
        """Test with augmentation."""
        raise NotImplementedError

    @abstractmethod
    def forward_model_init(self,
                           img: Tensor,
                           img_metas: List[Dict],
                           gt_bboxes: Optional[List[Tensor]] = None,
                           gt_labels: Optional[List[Tensor]] = None,
                           **kwargs):
        """extract and save support features for model initialization."""
        raise NotImplementedError

    @abstractmethod
    def model_init(self, **kwargs):
        """process the saved support features for model initialization."""
        raise NotImplementedError
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.