
Source code for mmfewshot.detection.models.roi_heads.multi_relation_roi_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from mmcv.utils import ConfigDict
from mmdet.core import bbox2result, bbox2roi
from mmdet.models.builder import HEADS
from mmdet.models.roi_heads import StandardRoIHead
from torch import Tensor


@HEADS.register_module()
class MultiRelationRoIHead(StandardRoIHead):
    """Roi head for `AttentionRPN <https://arxiv.org/abs/1908.01998>`_.

    Args:
        num_support_ways (int): Number of sampled classes (pos + neg).
        num_support_shots (int): Number of shots for each class.
        sample_fractions (Sequence[int | float]): Fractions of positive
            samples, negative samples from positive pair, and negative
            samples from negative pair. Default: (1, 2, 1).
    """

    def __init__(self,
                 num_support_ways: int = 2,
                 num_support_shots: int = 5,
                 sample_fractions: Sequence[Union[float, int]] = (1, 2, 1),
                 **kwargs):
        super().__init__(**kwargs)
        self.num_support_ways = num_support_ways
        self.num_support_shots = num_support_shots
        self.sample_fractions = sample_fractions
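    # Illustrative sketch (not part of the original source): how this head is
    # typically instantiated through the mmdet registry. The exact
    # `bbox_roi_extractor` and `bbox_head` settings are assumptions for
    # illustration; real values come from an AttentionRPN config file.
    #
    #   from mmdet.models.builder import HEADS
    #   roi_head = HEADS.build(
    #       dict(
    #           type='MultiRelationRoIHead',
    #           num_support_ways=2,        # one positive + one negative class
    #           num_support_shots=5,       # support boxes averaged per class
    #           sample_fractions=(1, 2, 1),
    #           bbox_roi_extractor=...,    # e.g. a SingleRoIExtractor config
    #           bbox_head=...))            # e.g. a MultiRelationBBoxHead config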
    def forward_train(self,
                      query_feats: List[Tensor],
                      support_feats: List[Tensor],
                      proposals: List[Tensor],
                      query_img_metas: List[Dict],
                      query_gt_bboxes: List[Tensor],
                      query_gt_labels: List[Tensor],
                      support_gt_bboxes: List[Tensor],
                      query_gt_bboxes_ignore: Optional[List[Tensor]] = None,
                      **kwargs) -> Dict:
        """All arguments except ``proposals`` are passed in tuples of
        (query, support).

        Args:
            query_feats (list[Tensor]): List of query features, each item
                with shape (N, C, H, W).
            support_feats (list[Tensor]): List of support features, each item
                with shape (N, C, H, W).
            proposals (list[Tensor]): List of region proposals with positive
                and negative query-support pairs.
            query_img_metas (list[dict]): List of query image info dict where
                each dict has: 'img_shape', 'scale_factor', 'flip', and may
                also contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            query_gt_bboxes (list[Tensor]): Ground truth bboxes for each
                query image with shape (num_gts, 4) in
                [tl_x, tl_y, br_x, br_y] format.
            query_gt_labels (list[Tensor]): Class indices corresponding to
                each bbox from query image.
            support_gt_bboxes (list[Tensor]): Ground truth bboxes for each
                support image with shape (num_gts, 4) in
                [tl_x, tl_y, br_x, br_y] format.
            query_gt_bboxes_ignore (None | list[Tensor]): Specify which
                bounding boxes from query image can be ignored when computing
                the loss. Default: None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # prepare contrastive training data
        batch_size = len(query_img_metas)
        repeat_query_feats = []
        for lvl in range(len(query_feats)):
            repeat_query_feats.append([query_feats[lvl]])
        for i in range(batch_size):
            query_gt_labels[i] = torch.zeros_like(query_gt_labels[i])
            query_gt_labels.extend([torch.zeros_like(query_gt_labels[i])] *
                                   (self.num_support_ways - 1))
            for lvl in range(len(query_feats)):
                repeat_query_feats[lvl].append(
                    query_feats[lvl][i].unsqueeze(0).repeat(
                        (self.num_support_ways - 1, 1, 1, 1)))
            query_img_metas.extend([query_img_metas[i]] *
                                   (self.num_support_ways - 1))
            query_gt_bboxes.extend([query_gt_bboxes[i]] *
                                   (self.num_support_ways - 1))
        for lvl in range(len(repeat_query_feats)):
            repeat_query_feats[lvl] = torch.cat(repeat_query_feats[lvl])

        sampling_results = []
        # assign gts and sample proposals
        if self.with_bbox:
            num_imgs = len(query_img_metas)
            assert num_imgs == len(proposals), \
                'proposals should keep same length with batch_size*(pos+neg).'
            if query_gt_bboxes_ignore is None:
                query_gt_bboxes_ignore = [None for _ in range(num_imgs)]
            for i in range(num_imgs):
                assign_result = self.bbox_assigner.assign(
                    proposals[i], query_gt_bboxes[i],
                    query_gt_bboxes_ignore[i], query_gt_labels[i])
                sampling_result = self.bbox_sampler.sample(
                    assign_result,
                    proposals[i],
                    query_gt_bboxes[i],
                    query_gt_labels[i],
                    feats=[
                        lvl_feat[i][None] for lvl_feat in repeat_query_feats
                    ])
                sampling_results.append(sampling_result)

        losses = dict()
        # bbox head forward and loss
        if self.with_bbox:
            bbox_results = self._bbox_forward_train(
                batch_size, repeat_query_feats, support_feats,
                sampling_results, query_gt_bboxes, query_gt_labels,
                support_gt_bboxes)
            if bbox_results is not None:
                losses.update(bbox_results['loss_bbox'])

        return losses
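    # Worked shape example for the pairing logic above (a sketch, assuming
    # batch_size=2 and num_support_ways=2): each query image is repeated
    # (num_support_ways - 1) times so that it also forms a negative pair,
    # giving 2 * 2 = 4 query-support pairs in total. The first `batch_size`
    # entries are positive pairs; the appended copies are negative pairs,
    # whose labels are flipped to background later in `_bbox_forward_train`.
    #
    #   query_feats[lvl]        : (2, C, H, W)
    #   repeat_query_feats[lvl] : (4, C, H, W)  # cat([both originals,
    #                                           #      one repeat per image])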
    def extract_roi_feat(self, feats: List[Tensor], rois: Tensor) -> Tensor:
        """Extract BBOX feature function used in both training and testing.

        Args:
            feats (list[Tensor]): Features from backbone, each item with
                shape (N, C, H, W).
            rois (Tensor): shape (num_proposals, 5).

        Returns:
            Tensor: Roi features with shape (num_proposals, C).
        """
        roi_feats = self.bbox_roi_extractor(
            feats[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            roi_feats = self.shared_head(roi_feats)
        return roi_feats
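    # Sketch of the expected `rois` layout (based on `mmdet.core.bbox2roi`):
    # each row is (batch_ind, x1, y1, x2, y2), e.g.
    #
    #   import torch
    #   from mmdet.core import bbox2roi
    #   rois = bbox2roi([torch.tensor([[10., 10., 50., 60.]])])
    #   # rois -> tensor([[ 0., 10., 10., 50., 60.]]), shape (1, 5)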
    def _bbox_forward(self, query_roi_feats: Tensor,
                      support_roi_feats: Tensor) -> Dict:
        """Box head forward function used in both training and testing.

        Args:
            query_roi_feats (Tensor): Roi features with shape (N*K, C).
            support_roi_feats (Tensor): Roi features with shape (N, C).

        Returns:
            dict: A dictionary of predicted results.
        """
        cls_score, bbox_pred = [], []
        batch_size = len(support_roi_feats)
        for query_rois_feat, support_rois_feat in zip(
                torch.chunk(query_roi_feats, batch_size), support_roi_feats):
            cls_score_single, bbox_pred_single = self.bbox_head(
                query_rois_feat, support_rois_feat)
            cls_score.append(cls_score_single)
            bbox_pred.append(bbox_pred_single)
        bbox_results = dict(
            cls_score=torch.cat(cls_score), bbox_pred=torch.cat(bbox_pred))
        return bbox_results

    def _bbox_forward_train(self, batch_size: int, query_feats: List[Tensor],
                            support_feats: List[Tensor],
                            sampling_results: object,
                            query_gt_bboxes: List[Tensor],
                            query_gt_labels: List[Tensor],
                            support_gt_bboxes: List[Tensor]) -> Dict:
        """Forward function and calculate loss for bbox head in training.

        Args:
            batch_size (int): Batch size.
            query_feats (list[Tensor]): List of query features, each item
                with shape (N, C, H, W).
            support_feats (list[Tensor]): List of support features, each item
                with shape (N, C, H, W).
            sampling_results (obj:`SamplingResult`): Sampling results.
            query_gt_bboxes (list[Tensor]): Ground truth bboxes for query
                image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y]
                format.
            query_gt_labels (list[Tensor]): Class indices corresponding to
                each bbox for query images.
            support_gt_bboxes (list[Tensor]): Ground truth bboxes for support
                image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y]
                format.

        Returns:
            dict: Predicted results and losses.
        """
        query_rois = bbox2roi([res.bboxes for res in sampling_results])
        query_roi_feats = self.extract_roi_feat(query_feats, query_rois)
        support_rois = bbox2roi([bboxes for bboxes in support_gt_bboxes])
        support_roi_feats = self.extract_roi_feat(support_feats, support_rois)
        avg_support_roi_feats = [
            support_roi_feats[i * self.num_support_shots:(i + 1) *
                              self.num_support_shots].mean(0, True)
            for i in range(
                support_roi_feats.size(0) // self.num_support_shots)
        ]
        if batch_size > 1:
            pos_avg_support_roi_feats = []
            neg_avg_support_roi_feats = []
            for b in range(batch_size):
                start = b * self.num_support_ways
                end = (b + 1) * self.num_support_ways
                pos_avg_support_roi_feats.extend(
                    avg_support_roi_feats[start:start + 1])
                neg_avg_support_roi_feats.extend(
                    avg_support_roi_feats[start + 1:end])
            avg_support_roi_feats = \
                pos_avg_support_roi_feats + neg_avg_support_roi_feats
        bbox_results = self._bbox_forward(query_roi_feats,
                                          avg_support_roi_feats)

        bbox_targets = self.bbox_head.get_targets(sampling_results,
                                                  query_gt_bboxes,
                                                  query_gt_labels,
                                                  self.train_cfg)
        (labels, label_weights, bbox_targets, bbox_weights) = bbox_targets
        # flip negative pair labels
        num_sampler_per_image = sampling_results[0].bboxes.size(0)
        num_pos_pair_samples = batch_size * num_sampler_per_image
        labels[num_pos_pair_samples:] = 1
        bbox_weights[num_pos_pair_samples:] = 0
        loss_bbox = self.bbox_head.loss(
            bbox_results['cls_score'],
            bbox_results['bbox_pred'],
            query_rois,
            labels,
            label_weights,
            bbox_targets,
            bbox_weights,
            num_pos_pair_samples,
            sample_fractions=self.sample_fractions)
        bbox_results.update(loss_bbox=loss_bbox)
        return bbox_results
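    # Minimal sketch of the per-class support averaging used above: with
    # num_support_shots=5, every consecutive block of 5 support RoI features
    # is mean-pooled into a single class prototype. When batch_size > 1 the
    # prototypes are then reordered so all positive-pair prototypes precede
    # the negative-pair ones, matching the query ordering built in
    # `forward_train`. The feature dimension 2048 below is an assumption.
    #
    #   import torch
    #   support_roi_feats = torch.randn(10, 2048)   # 2 classes x 5 shots
    #   prototypes = [
    #       support_roi_feats[i * 5:(i + 1) * 5].mean(0, True)
    #       for i in range(10 // 5)
    #   ]  # two (1, 2048) tensors, one per support class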
    def simple_test(self,
                    query_feats: List[Tensor],
                    support_feat: Tensor,
                    proposals: List[Tensor],
                    query_img_metas: List[Dict],
                    rescale: bool = False) -> List[List[np.ndarray]]:
        """Test without augmentation.

        Args:
            query_feats (list[Tensor]): List of query features, each item
                with shape (N, C, H, W).
            support_feat (Tensor): Support features with shape (N, C, H, W).
            proposals (list[Tensor]): List of region proposals.
            query_img_metas (list[dict]): List of query image info dict where
                each dict has: `img_shape`, `scale_factor`, `flip`, and may
                also contain `filename`, `ori_shape`, `pad_shape`, and
                `img_norm_cfg`. For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            rescale (bool): Whether to rescale the results. Default: False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        det_bboxes, det_labels = self.simple_test_bboxes(
            query_feats,
            support_feat,
            query_img_metas,
            proposals,
            self.test_cfg,
            rescale=rescale)
        bbox_results = [
            bbox2result(det_bboxes[i], det_labels[i],
                        self.bbox_head.num_classes)
            for i in range(len(det_bboxes))
        ]
        return bbox_results
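    # Output format sketch (based on `mmdet.core.bbox2result`): for a single
    # query image and one foreground class, `simple_test` returns something
    # like
    #
    #   [[np.array([[x1, y1, x2, y2, score], ...], dtype=np.float32)]]
    #
    # i.e. one list per image, one (n, 5) array per class.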
    def simple_test_bboxes(
            self,
            query_feats: List[Tensor],
            support_feat: Tensor,
            query_img_metas: List[Dict],
            proposals: List[Tensor],
            rcnn_test_cfg: ConfigDict,
            rescale: bool = False) -> Tuple[List[Tensor], List[Tensor]]:
        """Test only det bboxes without augmentation.

        Args:
            query_feats (list[Tensor]): List of query features, each item
                with shape (N, C, H, W).
            support_feat (Tensor): Support feature with shape (N, C, H, W).
            query_img_metas (list[dict]): List of image info dict where each
                dict has: `img_shape`, `scale_factor`, `flip`, and may also
                contain `filename`, `ori_shape`, `pad_shape`, and
                `img_norm_cfg`. For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            proposals (list[Tensor]): Region proposals.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple[list[Tensor], list[Tensor]]: Per-image bboxes of shape
                [num_bboxes, 5] and class labels of shape [num_bboxes].
        """
        rois = bbox2roi(proposals)
        query_roi_feats = self.extract_roi_feat(query_feats, rois)
        bbox_results = self._bbox_forward(query_roi_feats, [support_feat])

        img_shapes = tuple(meta['img_shape'] for meta in query_img_metas)
        scale_factors = tuple(meta['scale_factor']
                              for meta in query_img_metas)

        # split batch bbox prediction back to each image
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = rois.split(num_proposals_per_img, 0)
        cls_score = cls_score.split(num_proposals_per_img, 0)
        bbox_pred = bbox_pred.split(num_proposals_per_img, 0)

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(len(proposals)):
            det_bbox, det_label = self.bbox_head.get_bboxes(
                rois[i],
                cls_score[i],
                bbox_pred[i],
                img_shapes[i],
                scale_factors[i],
                rescale=rescale,
                cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        return det_bboxes, det_labels
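# End-to-end inference sketch (not part of the original source; names like
# `roi_head`, `rpn_head`, `query_feats` and `support_feat` are assumed to be
# prepared by an AttentionRPN-style detector):
#
#   proposals = rpn_head.simple_test_rpn(query_feats, query_img_metas)
#   results = roi_head.simple_test(
#       query_feats, support_feat, proposals, query_img_metas, rescale=True)
#   # results[i][c] is an (n, 5) array of [x1, y1, x2, y2, score]
#   # detections for class c in image i.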