Source code for mmfewshot.classification.datasets.dataset_wrappers

# Copyright (c) OpenMMLab. All rights reserved.
from __future__ import annotations
import os.path as osp
from typing import Dict, List, Mapping, Optional, Tuple

import numpy as np
from torch import Tensor
from torch.utils.data import Dataset

from mmfewshot.utils import local_numpy_seed
from .builder import DATASETS


@DATASETS.register_module()
class EpisodicDataset:
    """A wrapper of episodic dataset.

    It will generate a list of support and query image indices for each
    episode (support + query images). Every call of `__getitem__` will fetch
    and return (`num_ways` * `num_shots`) support images and (`num_ways` *
    `num_queries`) query images according to the generated image indices.
    Note that all episode indices are generated at once using a specific
    random seed to ensure reproducibility for the same dataset.

    Args:
        dataset (:obj:`Dataset`): The dataset to be wrapped.
        num_episodes (int): Number of episodes. Note that all episodes are
            generated at once and will not be changed afterwards. Make sure
            to set `num_episodes` larger than your needs.
        num_ways (int): Number of ways for each episode.
        num_shots (int): Number of support data of each way for each episode.
        num_queries (int): Number of query data of each way for each episode.
        episodes_seed (int | None): A random seed to reproduce episodic
            indices. If seed is None, it will use the runtime random seed.
            Default: None.
    """

    def __init__(self,
                 dataset: Dataset,
                 num_episodes: int,
                 num_ways: int,
                 num_shots: int,
                 num_queries: int,
                 episodes_seed: Optional[int] = None) -> None:
        self.dataset = dataset
        self.num_ways = num_ways
        self.num_shots = num_shots
        self.num_queries = num_queries
        self.num_episodes = num_episodes
        self._len = len(self.dataset)
        self.CLASSES = dataset.CLASSES
        # using the same episodes seed generates the same episodes for the
        # same dataset; it is designed for the reproducibility of meta train
        # or meta test
        self.episodes_seed = episodes_seed
        self.episode_idxes, self.episode_class_ids = \
            self.generate_episodic_idxes()
    def generate_episodic_idxes(
            self) -> Tuple[List[Mapping], List[List[int]]]:
        """Generate support and query indices for each episode."""
        episode_idxes, episode_class_ids = [], []
        class_ids = [i for i in range(len(self.CLASSES))]
        # using the same episodes seed generates the same episodes for the
        # same dataset; it is designed for the reproducibility of meta train
        # or meta test
        with local_numpy_seed(self.episodes_seed):
            for _ in range(self.num_episodes):
                np.random.shuffle(class_ids)
                # sample classes
                sampled_cls = class_ids[:self.num_ways]
                episode_class_ids.append(sampled_cls)
                episodic_support_idx = []
                episodic_query_idx = []
                # sample instances of each class
                for i in range(self.num_ways):
                    shots = self.dataset.sample_shots_by_class_id(
                        sampled_cls[i], self.num_shots + self.num_queries)
                    episodic_support_idx += shots[:self.num_shots]
                    episodic_query_idx += shots[self.num_shots:]
                episode_idxes.append({
                    'support': episodic_support_idx,
                    'query': episodic_query_idx
                })
        return episode_idxes, episode_class_ids
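    # For illustration (the concrete numbers are an assumption, not from the
    # source): with num_ways=2, num_shots=1, num_queries=2, one generated
    # entry of `episode_idxes` looks like
    #     {'support': [idx_a, idx_b], 'query': [idx_c, idx_d, idx_e, idx_f]}
    # and the matching `episode_class_ids` entry holds the two sampled class
    # ids, e.g. [3, 0].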
    def __getitem__(self, idx: int) -> Dict:
        """Return all data of one episode at once.

        For `EpisodicDataset`, this function returns `num_ways` * `num_shots`
        support images and `num_ways` * `num_queries` query images.
        """
        return {
            'support_data':
            [self.dataset[i] for i in self.episode_idxes[idx]['support']],
            'query_data':
            [self.dataset[i] for i in self.episode_idxes[idx]['query']]
        }

    def __len__(self) -> int:
        """The length of the dataset is the number of generated episodes."""
        return self.num_episodes
    def evaluate(self, *args, **kwargs) -> List:
        """Evaluate prediction."""
        return self.dataset.evaluate(*args, **kwargs)
    def get_episode_class_ids(self, idx: int) -> List[int]:
        """Return class ids in one episode."""
        return self.episode_class_ids[idx]
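A minimal usage sketch follows. The wrapped `base_dataset` below is an assumption: any mmfewshot classification dataset exposing `CLASSES` and a `sample_shots_by_class_id(class_id, num_shots)` method, as this wrapper requires.

    # Hedged sketch, not part of the library: `base_dataset` is a placeholder
    # for a dataset providing `CLASSES` and `sample_shots_by_class_id`.
    episodic = EpisodicDataset(
        dataset=base_dataset,
        num_episodes=1000,  # all episodes are generated up front
        num_ways=5,
        num_shots=1,
        num_queries=15,
        episodes_seed=0)  # fixed seed -> identical episodes across runs

    episode = episodic[0]
    assert len(episode['support_data']) == 5 * 1   # num_ways * num_shots
    assert len(episode['query_data']) == 5 * 15    # num_ways * num_queries
    assert len(episodic.get_episode_class_ids(0)) == 5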
@DATASETS.register_module()
class MetaTestDataset(EpisodicDataset):
    """A wrapper of the episodic dataset for meta testing.

    During meta test, the `MetaTestDataset` will be copied and converted
    into three modes: `test_set`, `support`, and `query`. Each mode of
    dataset will be used in a different dataloader, but they share the same
    episode and image information.

    - In `test_set` mode, the dataset will fetch all images from the whole
      test set to extract features from the fixed backbone, which can
      accelerate meta testing.
    - In `support` or `query` mode, the dataset will fetch images according
      to the `episode_idxes` with the same `task_id`. Therefore, the support
      and query dataset must be set to the same `task_id` in each test task.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._mode = 'test_set'
        self._task_id = 0
        self._with_cache_feats = False

    def with_cache_feats(self) -> bool:
        return self._with_cache_feats
    def set_task_id(self, task_id: int) -> None:
        """Query and support datasets use the same task id to make sure they
        fetch data from the same episode."""
        self._task_id = task_id
    def __getitem__(self, idx: int) -> Dict:
        """Return data according to mode.

        For mode `test_set`, this function returns a single image, like a
        regular dataset. For mode `support`, this function returns a single
        support image of the current episode. For mode `query`, this function
        returns a single query image of the current episode.

        If the dataset has cached the features extracted by the fixed
        backbone, the features will be returned instead of the image.
        """
        if self._mode == 'support':
            idx = self.episode_idxes[self._task_id]['support'][idx]
        elif self._mode == 'query':
            idx = self.episode_idxes[self._task_id]['query'][idx]
        # in `test_set` mode, `idx` indexes the whole test set directly
        if self._with_cache_feats:
            return {
                'feats': self.dataset.data_infos[idx]['feats'],
                'gt_label': self.dataset.data_infos[idx]['gt_label']
            }
        else:
            return self.dataset[idx]

    def get_task_class_ids(self) -> List[int]:
        return self.get_episode_class_ids(self._task_id)

    def test_set(self) -> MetaTestDataset:
        self._mode = 'test_set'
        return self

    def support(self) -> MetaTestDataset:
        self._mode = 'support'
        return self

    def query(self) -> MetaTestDataset:
        self._mode = 'query'
        return self

    def __len__(self) -> int:
        if self._mode == 'test_set':
            return len(self.dataset)
        elif self._mode == 'support':
            return self.num_ways * self.num_shots
        elif self._mode == 'query':
            return self.num_ways * self.num_queries
    def cache_feats(self, feats: Tensor, img_metas: Dict) -> None:
        """Cache extracted feats into dataset."""
        # use the full filename as a unique id for each image
        idx_map = {
            osp.join(data_info['img_prefix'],
                     data_info['img_info']['filename']): idx
            for idx, data_info in enumerate(self.dataset.data_infos)
        }
        for feat, img_meta in zip(feats, img_metas):
            idx = idx_map[img_meta['filename']]
            self.dataset.data_infos[idx]['feats'] = feat
        self._with_cache_feats = True
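For orientation, a hedged sketch of how the three modes fit together during meta testing, using only the public methods above. `meta_test_dataset` is assumed to be an already-built `MetaTestDataset`; the commented-out `extract_feats` stands in for whatever model code produces `feats` and `img_metas`, and is not this module's API.

    import copy

    # one pass over the whole test set, e.g. to extract backbone features
    test_set_dataset = copy.copy(meta_test_dataset).test_set()
    # feats, img_metas = extract_feats(test_set_dataset)  # hypothetical
    # meta_test_dataset.cache_feats(feats, img_metas)

    # copy after caching so the copies see `_with_cache_feats = True`
    support_dataset = copy.copy(meta_test_dataset).support()
    query_dataset = copy.copy(meta_test_dataset).query()
    for task_id in range(meta_test_dataset.num_episodes):
        # support and query must share the task id to read the same episode
        support_dataset.set_task_id(task_id)
        query_dataset.set_task_id(task_id)
        class_ids = support_dataset.get_task_class_ids()
        # len(support_dataset) == num_ways * num_shots
        # len(query_dataset) == num_ways * num_queries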