Shortcuts

Source code for mmfewshot.detection.datasets.dataset_wrappers

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import json
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from mmdet.datasets.builder import DATASETS

from .base import BaseFewShotDataset


[docs]@DATASETS.register_module() class QueryAwareDataset: """A wrapper of QueryAwareDataset. Building QueryAwareDataset requires query and support dataset. Every call of `__getitem__` will firstly sample a query image and its annotations. Then it will use the query annotations to sample a batch of positive and negative support images and annotations. The positive images share same classes with query, while the annotations of negative images don't have any category from query. Args: query_dataset (:obj:`BaseFewShotDataset`): Query dataset to be wrapped. support_dataset (:obj:`BaseFewShotDataset` | None): Support dataset to be wrapped. If support dataset is None, support dataset will copy from query dataset. num_support_ways (int): Number of classes for support in mini-batch, the first one always be the positive class. num_support_shots (int): Number of support shots for each class in mini-batch, the first K shots always from positive class. repeat_times (int): The length of repeated dataset will be `times` larger than the original dataset. Default: 1. """ def __init__(self, query_dataset: BaseFewShotDataset, support_dataset: Optional[BaseFewShotDataset], num_support_ways: int, num_support_shots: int, repeat_times: int = 1) -> None: self.query_dataset = query_dataset if support_dataset is None: self.support_dataset = self.query_dataset else: self.support_dataset = support_dataset self.num_support_ways = num_support_ways self.num_support_shots = num_support_shots self.CLASSES = self.query_dataset.CLASSES self.repeat_times = repeat_times assert self.num_support_ways <= len( self.CLASSES ), 'Please set `num_support_ways` smaller than the number of classes.' # build data index (idx, gt_idx) by class. self.data_infos_by_class = {i: [] for i in range(len(self.CLASSES))} # counting max number of annotations in one image for each class, # which will decide whether sample repeated instance or not. self.max_anns_num_one_image = [0 for _ in range(len(self.CLASSES))] # count image for each class annotation when novel class only # has one image, the positive support is allowed sampled from itself. self.num_image_by_class = [0 for _ in range(len(self.CLASSES))] for idx in range(len(self.support_dataset)): labels = self.support_dataset.get_ann_info(idx)['labels'] class_count = [0 for _ in range(len(self.CLASSES))] for gt_idx, gt in enumerate(labels): self.data_infos_by_class[gt].append((idx, gt_idx)) class_count[gt] += 1 for i in range(len(self.CLASSES)): # number of images for each class if class_count[i] > 0: self.num_image_by_class[i] += 1 # max number of one class annotations in one image if class_count[i] > self.max_anns_num_one_image[i]: self.max_anns_num_one_image[i] = class_count[i] for i in range(len(self.CLASSES)): assert len(self.data_infos_by_class[i] ) > 0, f'Class {self.CLASSES[i]} has zero annotation' if len( self.data_infos_by_class[i] ) <= self.num_support_shots - self.max_anns_num_one_image[i]: warnings.warn( f'During training, instances of class {self.CLASSES[i]} ' f'may smaller than the number of support shots which ' f'causes some instance will be sampled multiple times') if self.num_image_by_class[i] == 1: warnings.warn(f'Class {self.CLASSES[i]} only have one ' f'image, query and support will sample ' f'from instance of same image') # Disable the group sampler, because in few shot setting, # one group may only has two or three images. if hasattr(self.query_dataset, 'flag'): self.flag = np.zeros( len(self.query_dataset) * self.repeat_times, dtype=np.uint8) self._ori_len = len(self.query_dataset) def __getitem__(self, idx: int) -> Dict: """Return query image and support images at the same time. For query aware dataset, this function would return one query image and num_support_ways * num_support_shots support images. The support images are sampled according to the selected query image. There should be no intersection between the classes of instances in query data and in support data. Args: idx (int): the index of data. Returns: dict: A dict contains query data and support data, it usually contains two fields. - query_data: A dict of single query data information. - support_data: A list of dict, has num_support_ways * num_support_shots support images and corresponding annotations. """ idx %= self._ori_len # sample query data try_time = 0 while True: try_time += 1 cat_ids = self.query_dataset.get_cat_ids(idx) # query image have too many classes, can not find enough # negative support classes. if len(self.CLASSES) - len(cat_ids) >= self.num_support_ways - 1: break else: idx = self._rand_another(idx) % self._ori_len assert try_time < 100, \ 'Not enough negative support classes for ' \ 'query image, please try a smaller support way.' query_class = np.random.choice(cat_ids) query_gt_idx = [ i for i in range(len(cat_ids)) if cat_ids[i] == query_class ] query_data = self.query_dataset.prepare_train_img( idx, 'query', query_gt_idx) query_data['query_class'] = [query_class] # sample negative support classes, which not appear in query image support_class = [ i for i in range(len(self.CLASSES)) if i not in cat_ids ] support_class = np.random.choice( support_class, min(self.num_support_ways - 1, len(support_class)), replace=False) support_idxes = self.generate_support(idx, query_class, support_class) support_data = [ self.support_dataset.prepare_train_img(idx, 'support', [gt_idx]) for (idx, gt_idx) in support_idxes ] return {'query_data': query_data, 'support_data': support_data} def __len__(self) -> int: """Length after repetition.""" return len(self.query_dataset) * self.repeat_times def _rand_another(self, idx: int) -> int: """Get another random index from the same group as the given index.""" pool = np.where(self.flag == self.flag[idx])[0] return np.random.choice(pool)
[docs] def generate_support(self, idx: int, query_class: int, support_classes: List[int]) -> List[Tuple[int]]: """Generate support indices of query images. Args: idx (int): Index of query data. query_class (int): Query class. support_classes (list[int]): Classes of support data. Returns: list[tuple(int)]: A mini-batch (num_support_ways * num_support_shots) of support data (idx, gt_idx). """ support_idxes = [] if self.num_image_by_class[query_class] == 1: # only have one image, instance will sample from same image pos_support_idxes = self.sample_support_shots( idx, query_class, allow_same_image=True) else: # instance will sample from different image from query image pos_support_idxes = self.sample_support_shots(idx, query_class) support_idxes.extend(pos_support_idxes) for support_class in support_classes: neg_support_idxes = self.sample_support_shots(idx, support_class) support_idxes.extend(neg_support_idxes) return support_idxes
[docs] def sample_support_shots( self, idx: int, class_id: int, allow_same_image: bool = False) -> List[Tuple[int]]: """Generate support indices according to the class id. Args: idx (int): Index of query data. class_id (int): Support class. allow_same_image (bool): Allow instance sampled from same image as query image. Default: False. Returns: list[tuple[int]]: Support data (num_support_shots) of specific class. """ support_idxes = [] num_total_shots = len(self.data_infos_by_class[class_id]) # count number of support instance in query image cat_ids = self.support_dataset.get_cat_ids(idx % self._ori_len) num_ignore_shots = len([1 for cat_id in cat_ids if cat_id == class_id]) # set num_sample_shots for each time of sampling if num_total_shots - num_ignore_shots < self.num_support_shots: # if not have enough support data allow repeated data num_sample_shots = num_total_shots allow_repeat = True else: # if have enough support data not allow repeated data num_sample_shots = self.num_support_shots allow_repeat = False while len(support_idxes) < self.num_support_shots: selected_gt_idxes = np.random.choice( num_total_shots, num_sample_shots, replace=False) selected_gts = [ self.data_infos_by_class[class_id][selected_gt_idx] for selected_gt_idx in selected_gt_idxes ] for selected_gt in selected_gts: # filter out query annotations if selected_gt[0] == idx: if not allow_same_image: continue if allow_repeat: support_idxes.append(selected_gt) elif selected_gt not in support_idxes: support_idxes.append(selected_gt) if len(support_idxes) == self.num_support_shots: break # update the number of data for next time sample num_sample_shots = min(self.num_support_shots - len(support_idxes), num_sample_shots) return support_idxes
[docs] def save_data_infos(self, output_path: str) -> None: """Save data_infos into json.""" self.query_dataset.save_data_infos(output_path) # for query aware datasets support and query set use same data paths = output_path.split('.') self.support_dataset.save_data_infos( '.'.join(paths[:-1] + ['support_shot', paths[-1]]))
[docs] def get_support_data_infos(self) -> List[Dict]: """Return data_infos of support dataset.""" return copy.deepcopy(self.support_dataset.data_infos)
[docs]@DATASETS.register_module() class NWayKShotDataset(object): """A dataset wrapper of NWayKShotDataset. Building NWayKShotDataset requires query and support dataset, the behavior of NWayKShotDataset is determined by `mode`. When dataset in 'query' mode, dataset will return regular image and annotations. While dataset in 'support' mode, dataset will build batch indices firstly and each batch indices contain (num_support_ways * num_support_shots) samples. In other words, for support mode every call of `__getitem__` will return a batch of samples, therefore the outside dataloader should set batch_size to 1. The default `mode` of NWayKShotDataset is 'query' and by using convert function `convert_query_to_support` the `mode` will be converted into 'support'. Args: query_dataset (:obj:`BaseFewShotDataset`): Query dataset to be wrapped. support_dataset (:obj:`BaseFewShotDataset` | None): Support dataset to be wrapped. If support dataset is None, support dataset will copy from query dataset. num_support_ways (int): Number of classes for support in mini-batch. num_support_shots (int): Number of support shot for each class in mini-batch. one_support_shot_per_image (bool): If True only one annotation will be sampled from each image. Default: False. num_used_support_shots (int | None): The total number of support shots sampled and used for each class during training. If set to None, all shots in dataset will be used as support shot. Default: 200. shuffle_support (bool): If allow generate new batch indices for each epoch. Default: False. repeat_times (int): The length of repeated dataset will be `times` larger than the original dataset. Default: 1. """ def __init__(self, query_dataset: BaseFewShotDataset, support_dataset: Optional[BaseFewShotDataset], num_support_ways: int, num_support_shots: int, one_support_shot_per_image: bool = False, num_used_support_shots: int = 200, repeat_times: int = 1) -> None: self.query_dataset = query_dataset if support_dataset is None: self.support_dataset = self.query_dataset else: self.support_dataset = support_dataset self.CLASSES = self.query_dataset.CLASSES # The mode determinate the behavior of fetching data, # the default mode is 'query'. To convert the dataset # into 'support' dataset, simply call the function # convert_query_to_support(). self._mode = 'query' self.num_support_ways = num_support_ways self.one_support_shot_per_image = one_support_shot_per_image self.num_used_support_shots = num_used_support_shots assert num_support_ways <= len( self.CLASSES ), 'support way can not larger than the number of classes' self.num_support_shots = num_support_shots self.batch_indices = [] self.data_infos_by_class = {i: [] for i in range(len(self.CLASSES))} self.prepare_support_shots() self.repeat_times = repeat_times # Disable the group sampler, because in few shot setting, # one group may only has two or three images. if hasattr(query_dataset, 'flag'): self.flag = np.zeros( len(self.query_dataset) * self.repeat_times, dtype=np.uint8) self._ori_len = len(self.query_dataset) def __getitem__(self, idx: int) -> Union[Dict, List[Dict]]: if self._mode == 'query': idx %= self._ori_len # loads one data in query pipeline return self.query_dataset.prepare_train_img(idx, 'query') elif self._mode == 'support': # loads one batch of data in support pipeline b_idx = self.batch_indices[idx] batch_data = [ self.support_dataset.prepare_train_img(idx, 'support', [gt_idx]) for (idx, gt_idx) in b_idx ] return batch_data else: raise ValueError('not valid data type') def __len__(self) -> int: """Length of dataset.""" if self._mode == 'query': return len(self.query_dataset) * self.repeat_times elif self._mode == 'support': return len(self.batch_indices) else: raise ValueError(f'{self._mode}not a valid mode') def prepare_support_shots(self) -> None: # create lookup table for annotations in same class. for idx in range(len(self.support_dataset)): labels = self.support_dataset.get_ann_info(idx)['labels'] for gt_idx, gt in enumerate(labels): # When the number of support shots reaches # `num_used_support_shots`, the class will be skipped if len(self.data_infos_by_class[gt]) < \ self.num_used_support_shots: self.data_infos_by_class[gt].append((idx, gt_idx)) # When `one_support_shot_per_image` is true, only one # annotation will be sampled for each image. if self.one_support_shot_per_image: break # make sure all class index lists have enough # instances (length > num_support_shots) for i in range(len(self.CLASSES)): num_gts = len(self.data_infos_by_class[i]) if num_gts < self.num_support_shots: self.data_infos_by_class[i] = self.data_infos_by_class[i] * ( self.num_support_shots // num_gts + 1)
[docs] def convert_query_to_support(self, support_dataset_len: int) -> None: """Convert query dataset to support dataset. Args: support_dataset_len (int): Length of pre sample batch indices. """ self.batch_indices = \ self.generate_support_batch_indices(support_dataset_len) self._mode = 'support' if hasattr(self, 'flag'): self.flag = np.zeros(support_dataset_len, dtype=np.uint8)
[docs] def generate_support_batch_indices( self, dataset_len: int) -> List[List[Tuple[int]]]: """Generate batch indices from support dataset. Batch indices is in the shape of [length of datasets * [support way * support shots]]. And the `dataset_len` will be the length of support dataset. Args: dataset_len (int): Length of batch indices. Returns: list[list[(data_idx, gt_idx)]]: Pre-sample batch indices. """ total_indices = [] for _ in range(dataset_len): batch_indices = [] selected_classes = np.random.choice( len(self.CLASSES), self.num_support_ways, replace=False) for cls in selected_classes: num_gts = len(self.data_infos_by_class[cls]) selected_gts_idx = np.random.choice( num_gts, self.num_support_shots, replace=False) selected_gts = [ self.data_infos_by_class[cls][gt_idx] for gt_idx in selected_gts_idx ] batch_indices.extend(selected_gts) total_indices.append(batch_indices) return total_indices
[docs] def save_data_infos(self, output_path: str) -> None: """Save data infos of query and support data.""" self.query_dataset.save_data_infos(output_path) paths = output_path.split('.') self.save_support_data_infos('.'.join(paths[:-1] + ['support_shot', paths[-1]]))
[docs] def save_support_data_infos(self, support_output_path: str) -> None: """Save support data infos.""" support_data_infos = self.get_support_data_infos() meta_info = [{ 'CLASSES': self.CLASSES, 'img_prefix': self.support_dataset.img_prefix }] from .utils import NumpyEncoder with open(support_output_path, 'w', encoding='utf-8') as f: json.dump( meta_info + support_data_infos, f, ensure_ascii=False, indent=4, cls=NumpyEncoder)
[docs] def get_support_data_infos(self) -> List[Dict]: """Get support data infos from batch indices.""" return copy.deepcopy([ self._get_shot_data_info(idx, gt_idx) for class_name in self.data_infos_by_class.keys() for (idx, gt_idx) in self.data_infos_by_class[class_name] ])
def _get_shot_data_info(self, idx: int, gt_idx: [int]) -> Dict: """Get data info by idx and gt idx.""" data_info = copy.deepcopy(self.support_dataset.data_infos[idx]) data_info['ann']['labels'] = \ data_info['ann']['labels'][gt_idx:gt_idx + 1] data_info['ann']['bboxes'] = \ data_info['ann']['bboxes'][gt_idx:gt_idx + 1] return data_info
@DATASETS.register_module() class TwoBranchDataset: """A dataset wrapper of TwoBranchDataset. Wrapping main_dataset and auxiliary_dataset to a single dataset and thus building TwoBranchDataset requires two dataset. The behavior of TwoBranchDataset is determined by `mode`. Dataset will return images and annotations according to `mode`, e.g. fetching data from main_dataset if `mode` is 'main'. The default `mode` is 'main' and by using convert function `convert_main_to_auxiliary` the `mode` will be converted into 'auxiliary'. Args: main_dataset (:obj:`BaseFewShotDataset`): Main dataset to be wrapped. auxiliary_dataset (:obj:`BaseFewShotDataset` | None): Auxiliary dataset to be wrapped. If auxiliary dataset is None, auxiliary dataset will copy from main dataset. reweight_dataset (bool): Whether to change the sampling weights of VOC07 and VOC12 . Default: False. """ def __init__(self, main_dataset: BaseFewShotDataset = None, auxiliary_dataset: Optional[BaseFewShotDataset] = None, reweight_dataset: bool = False) -> None: assert main_dataset and auxiliary_dataset self._mode = 'main' self.main_dataset = main_dataset self.auxiliary_dataset = auxiliary_dataset self.CLASSES = self.main_dataset.CLASSES if reweight_dataset: # Reweight the VOC dataset to be consistent with the original # implementation of MPSR. For more details, please refer to # https://github.com/jiaxi-wu/MPSR/blob/master/maskrcnn_benchmark/data/datasets/voc.py#L137 self.main_idx_map = self.reweight_dataset( self.main_dataset, ['VOC2007', 'VOC2012'], ) self.auxiliary_idx_map = self.reweight_dataset( self.auxiliary_dataset, ['VOC']) else: self.main_idx_map = list(range(len(self.main_dataset))) self.auxiliary_idx_map = list(range(len(self.auxiliary_dataset))) self._main_len = len(self.main_idx_map) self._auxiliary_len = len(self.auxiliary_idx_map) self._set_group_flag() def __getitem__(self, idx: int) -> Dict: if self._mode == 'main': idx %= self._main_len idx = self.main_idx_map[idx] return self.main_dataset.prepare_train_img(idx, 'main') elif self._mode == 'auxiliary': idx %= self._auxiliary_len idx = self.auxiliary_idx_map[idx] return self.auxiliary_dataset.prepare_train_img(idx, 'auxiliary') else: raise ValueError('not valid data type') def __len__(self) -> int: """Length of dataset.""" if self._mode == 'main': return self._main_len elif self._mode == 'auxiliary': return self._auxiliary_len else: raise ValueError('not valid data type') def convert_main_to_auxiliary(self) -> None: """Convert main dataset to auxiliary dataset.""" self._mode = 'auxiliary' self._set_group_flag() def save_data_infos(self, output_path: str) -> None: """Save data infos of main and auxiliary data.""" self.main_dataset.save_data_infos(output_path) paths = output_path.split('.') self.auxiliary_dataset.save_data_infos( '.'.join(paths[:-1] + ['auxiliary', paths[-1]])) def _set_group_flag(self) -> None: # disable the group sampler, because in few shot setting, # one group may only has two or three images. self.flag = np.zeros(len(self), dtype=np.uint8) @staticmethod def reweight_dataset(dataset: BaseFewShotDataset, group_prefix: Sequence[str], repeat_length: int = 100) -> List: """Reweight the dataset.""" groups = [[] for _ in range(len(group_prefix))] for i in range(len(dataset)): filename = dataset.data_infos[i]['filename'] for j, prefix in enumerate(group_prefix): if prefix in filename: groups[j].append(i) break assert j < len(group_prefix) - 1 # Reweight the dataset to be consistent with the original # implementation of MPSR. For more details, please refer to # https://github.com/jiaxi-wu/MPSR/blob/master/maskrcnn_benchmark/data/datasets/voc.py#L137 reweight_idx_map = [] for g in groups: if len(g) < 50: reweight_idx_map += g * (int(repeat_length / len(g)) + 1) else: reweight_idx_map += g return reweight_idx_map
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.