Shortcuts

Source code for mmfewshot.classification.datasets.mini_imagenet

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from typing import List, Optional, Sequence, Union

import mmcv
import numpy as np
from mmcls.datasets.builder import DATASETS
from typing_extensions import Literal

from .base import BaseFewShotDataset

TRAIN_CLASSES = [
    'n02074367', 'n03047690', 'n03854065', 'n02089867', 'n02105505',
    'n01704323', 'n04604644', 'n03676483', 'n01558993', 'n07697537',
    'n04509417', 'n02101006', 'n02165456', 'n13133613', 'n02747177',
    'n02966193', 'n03924679', 'n04275548', 'n02113712', 'n03838899',
    'n02091831', 'n03220513', 'n07747607', 'n03998194', 'n02108089',
    'n09246464', 'n04251144', 'n02111277', 'n04435653', 'n03207743',
    'n04389033', 'n03337140', 'n03908618', 'n02606052', 'n01770081',
    'n01910747', 'n03062245', 'n02108551', 'n03017168', 'n04258138',
    'n03888605', 'n04596742', 'n07584110', 'n02687172', 'n03476684',
    'n04243546', 'n02795169', 'n02457408', 'n04515003', 'n06794110',
    'n01532829', 'n01843383', 'n13054560', 'n04067472', 'n03347037',
    'n04612504', 'n03400231', 'n01749939', 'n02823428', 'n04296562',
    'n03527444', 'n04443257', 'n02108915', 'n02120079'
]
VAL_CLASSES = [
    'n02138441', 'n02981792', 'n02174001', 'n03535780', 'n03770439',
    'n03773504', 'n02950826', 'n03980874', 'n02114548', 'n03584254',
    'n02091244', 'n03417042', 'n02971356', 'n01855672', 'n09256479',
    'n03075370'
]
TEST_CLASSES = [
    'n02110341', 'n01981276', 'n07613480', 'n02129165', 'n04418357',
    'n02443484', 'n03127925', 'n01930112', 'n03272010', 'n03146219',
    'n04146614', 'n03775546', 'n04522168', 'n02099601', 'n02871525',
    'n02110063', 'n02219486', 'n02116738', 'n04149813', 'n03544143'
]


[docs]@DATASETS.register_module() class MiniImageNetDataset(BaseFewShotDataset): """MiniImageNet dataset for few shot classification. Args: subset (str| list[str]): The classes of whole dataset are split into three disjoint subset: train, val and test. If subset is a string, only one subset data will be loaded. If subset is a list of string, then all data of subset in list will be loaded. Options: ['train', 'val', 'test']. Default: 'train'. file_format (str): The file format of the image. Default: 'JPEG' """ resource = 'https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet' # noqa TRAIN_CLASSES = TRAIN_CLASSES VAL_CLASSES = VAL_CLASSES TEST_CLASSES = TEST_CLASSES def __init__(self, subset: Literal['train', 'test', 'val'] = 'train', file_format: str = 'JPEG', *args, **kwargs): if isinstance(subset, str): subset = [subset] for subset_ in subset: assert subset_ in ['train', 'test', 'val'] self.subset = subset self.file_format = file_format super().__init__(*args, **kwargs)
[docs] def get_classes( self, classes: Optional[Union[Sequence[str], str]] = None) -> Sequence[str]: """Get class names of current dataset. Args: classes (Sequence[str] | str | None): Three types of input will correspond to different processing logics: - If `classes` is a tuple or list, it will override the CLASSES predefined in the dataset. - If `classes` is None, we directly use pre-defined CLASSES will be used by the dataset. - If `classes` is a string, it is the path of a classes file that contains the name of all classes. Each line of the file contains a single class name. Returns: tuple[str] or list[str]: Names of categories of the dataset. """ if classes is None: class_names = [] for subset_ in self.subset: if subset_ == 'train': class_names += self.TRAIN_CLASSES elif subset_ == 'val': class_names += self.VAL_CLASSES elif subset_ == 'test': class_names += self.TEST_CLASSES else: raise ValueError(f'invalid subset {subset_} only ' f'support train, val or test.') elif isinstance(classes, str): # take it as a file path class_names = mmcv.list_from_file(classes) elif isinstance(classes, (tuple, list)): class_names = classes else: raise ValueError(f'Unsupported type {type(classes)} of classes.') return class_names
[docs] def load_annotations(self) -> List: """Load annotation according to the classes subset.""" img_file_list = { class_name: sorted( os.listdir(osp.join(self.data_prefix, 'images', class_name)), key=lambda x: int(x.split('.')[0].split('_')[-1])) for class_name in self.CLASSES } data_infos = [] for subset_ in self.subset: ann_file = osp.join(self.data_prefix, f'{subset_}.csv') assert osp.exists(ann_file), \ f'Please download ann_file through {self.resource}.' with open(ann_file, 'r') as f: for i, line in enumerate(f): # skip file head if i == 0: continue filename, class_name = line.strip().split(',') filename = img_file_list[class_name][ int(filename.split('.')[0][-5:]) - 1] gt_label = self.class_to_idx[class_name] info = { 'img_prefix': osp.join(self.data_prefix, 'images', class_name), 'img_info': { 'filename': filename }, 'gt_label': np.array(gt_label, dtype=np.int64) } data_infos.append(info) return data_infos
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.