
mmfewshot.classification

classification.apis

mmfewshot.classification.apis.inference_classifier(model: torch.nn.modules.module.Module, query_img: str) → Dict

Run inference on a single image with the classifier.

Parameters
  • model (nn.Module) – The loaded classifier.

  • query_img (str) – The image filename.

Returns

The classification results, containing the pred_score of each class.

Return type

dict

mmfewshot.classification.apis.init_classifier(config: Union[str, mmcv.utils.config.Config], checkpoint: Optional[str] = None, device: str = 'cuda:0', options: Optional[Dict] = None) → torch.nn.modules.module.Module

Prepare a few-shot classifier from a config file.

Parameters
  • config (str or mmcv.Config) – Config file path or the config object.

  • checkpoint (str | None) – Checkpoint path. If left as None, the model will not load any weights. Default: None.

  • device (str) – Runtime device. Default: ‘cuda:0’.

  • options (dict | None) – Options to override some settings in the used config. Default: None.

Returns

The constructed classifier.

Return type

nn.Module

mmfewshot.classification.apis.multi_gpu_meta_test(model: mmcv.parallel.distributed.MMDistributedDataParallel, num_test_tasks: int, support_dataloader: torch.utils.data.dataloader.DataLoader, query_dataloader: torch.utils.data.dataloader.DataLoader, test_set_dataloader: Optional[torch.utils.data.dataloader.DataLoader] = None, meta_test_cfg: Optional[Dict] = None, eval_kwargs: Optional[Dict] = None, logger: Optional[object] = None, confidence_interval: float = 0.95, show_task_results: bool = False) → Dict

Distributed meta testing on multiple GPUs.

During meta testing, the model may be further fine-tuned or extended with extra parameters. However, the tested model must be restored afterwards, because meta testing can also serve as validation in the middle of training. To detach the model from the previous phase, it is copied and wrapped with MetaTestParallel. The copy is fully independent of the training model and is discarded after meta testing.

In the distributed setting, the MetaTestParallel on each GPU is also independent. Test tasks in few-shot learning are usually very small and hardly benefit from distributed acceleration. Therefore, in distributed meta testing, each task runs on a single GPU and each GPU is assigned a number of tasks: ceil(num_test_tasks / world_size) per GPU. After all GPUs finish their tasks, the results are aggregated into the final result.
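The task-splitting rule above can be sketched in a few lines. This is an illustration, not mmfewshot's code: the helper `tasks_for_rank` is hypothetical, and it assumes each rank takes a contiguous block of ceil(num_test_tasks / world_size) task ids, clipped so the last rank may get fewer. How the library handles the remainder may differ.

```python
import math
from typing import List


def tasks_for_rank(num_test_tasks: int, world_size: int, rank: int) -> List[int]:
    """Return the task ids handled by one GPU (hypothetical helper).

    Each rank gets at most ceil(num_test_tasks / world_size) tasks as a
    contiguous block; the last rank may receive fewer, or none.
    """
    tasks_per_rank = math.ceil(num_test_tasks / world_size)
    start = rank * tasks_per_rank
    end = min(start + tasks_per_rank, num_test_tasks)
    return list(range(start, end))


# 10 tasks on 4 GPUs: ceil(10 / 4) = 3 tasks per GPU, the last GPU gets 1.
for rank in range(4):
    print(rank, tasks_for_rank(10, 4, rank))
```

Because every task lives entirely on one rank, no gradient or feature synchronization is needed until the per-task results are gathered.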

Parameters
  • model (MMDistributedDataParallel) – Model to be meta tested.

  • num_test_tasks (int) – Number of meta testing tasks.

  • support_dataloader (DataLoader) – A PyTorch dataloader of support data.

  • query_dataloader (DataLoader) – A PyTorch dataloader of query data.

  • test_set_dataloader (DataLoader) – A PyTorch dataloader of all test data. Default: None.

  • meta_test_cfg (dict) – Config for meta testing. Default: None.

  • eval_kwargs (dict) – Any keyword argument to be used for evaluation. Default: None.

  • logger (logging.Logger | None) – Logger used for printing related information during evaluation. Default: None.

  • confidence_interval (float) – Confidence interval. Default: 0.95.

  • show_task_results (bool) – Whether to record the eval result of each task. Default: False.

Returns

Dict of meta evaluation results, containing accuracy_mean and accuracy_std over all test tasks.

Return type

dict | None
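A minimal sketch of how per-task accuracies could be reduced to a mean, a standard deviation and a confidence-interval half-width. It assumes a normal approximation (about 1.96 for confidence_interval=0.95), a common convention in few-shot papers; the exact formula mmfewshot uses for accuracy_std is not given on this page, and the key `ci_half_width` is a hypothetical name.

```python
import math
import statistics
from typing import Dict, List


def summarize_accuracies(accuracies: List[float],
                         confidence_interval: float = 0.95) -> Dict[str, float]:
    """Reduce per-task accuracies to a mean, a std and a CI half-width.

    Assumes a normal approximation: half-width = z * std / sqrt(n).
    """
    n = len(accuracies)
    mean = statistics.fmean(accuracies)
    std = statistics.pstdev(accuracies)  # population std over tasks
    # Two-sided critical value, about 1.96 for a 95% interval.
    z = statistics.NormalDist().inv_cdf((1 + confidence_interval) / 2)
    return {
        'accuracy_mean': mean,
        'accuracy_std': std,
        # Hypothetical key name for the interval half-width.
        'ci_half_width': z * std / math.sqrt(n),
    }


print(summarize_accuracies([60.0, 62.0, 58.0, 64.0]))
```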

mmfewshot.classification.apis.process_support_images(model: torch.nn.modules.module.Module, support_imgs: List[str], support_labels: List[str]) → None

Process support images.

Parameters
  • model (nn.Module) – Classifier model.

  • support_imgs (list[str]) – The image filenames.

  • support_labels (list[str]) – The class names of support images.

mmfewshot.classification.apis.show_result_pyplot(img: str, result: Dict, fig_size: Tuple[int] = (15, 10), wait_time: int = 0, out_file: Optional[str] = None) → numpy.ndarray

Visualize the classification results on the image.

Parameters
  • img (str) – Image filename.

  • result (dict) – The classification result.

  • fig_size (tuple) – Figure size of the pyplot figure. Default: (15, 10).

  • wait_time (int) – How many seconds to display the image. Default: 0.

  • out_file (str | None) – The filename to write the visualized image to. Default: None.

Returns

pyplot figure.

Return type

np.ndarray

mmfewshot.classification.apis.single_gpu_meta_test(model: Union[mmcv.parallel.data_parallel.MMDataParallel, torch.nn.modules.module.Module], num_test_tasks: int, support_dataloader: torch.utils.data.dataloader.DataLoader, query_dataloader: torch.utils.data.dataloader.DataLoader, test_set_dataloader: Optional[torch.utils.data.dataloader.DataLoader] = None, meta_test_cfg: Optional[Dict] = None, eval_kwargs: Optional[Dict] = None, logger: Optional[object] = None, confidence_interval: float = 0.95, show_task_results: bool = False) → Dict

Meta testing on a single GPU.

During meta testing, the model may be further fine-tuned or extended with extra parameters. However, the tested model must be restored afterwards, because meta testing can also serve as validation in the middle of training. To detach the model from the previous phase, it is copied and wrapped with MetaTestParallel. The copy is fully independent of the training model and is discarded after meta testing.

Parameters
  • model (MMDataParallel | nn.Module) – Model to be meta tested.

  • num_test_tasks (int) – Number of meta testing tasks.

  • support_dataloader (DataLoader) – A PyTorch dataloader of support data and it is used to fetch support data for each task.

  • query_dataloader (DataLoader) – A PyTorch dataloader of query data and it is used to fetch query data for each task.

  • test_set_dataloader (DataLoader) – A PyTorch dataloader of all test data and it is used for feature extraction from whole dataset to accelerate the testing. Default: None.

  • meta_test_cfg (dict) – Config for meta testing. Default: None.

  • eval_kwargs (dict) – Any keyword argument to be used for evaluation. Default: None.

  • logger (logging.Logger | None) – Logger used for printing related information during evaluation. Default: None.

  • confidence_interval (float) – Confidence interval. Default: 0.95.

  • show_task_results (bool) – Whether to record the eval result of each task. Default: False.

Returns

Dict of meta evaluation results, containing accuracy_mean and accuracy_std over all test tasks.

Return type

dict

mmfewshot.classification.apis.test_single_task(model: mmfewshot.classification.utils.meta_test_parallel.MetaTestParallel, support_dataloader: torch.utils.data.dataloader.DataLoader, query_dataloader: torch.utils.data.dataloader.DataLoader, meta_test_cfg: Dict)

Test a single task.

A task has two stages: handling the support set and predicting the query set. In stage one, fine-tune based and metric based methods are currently supported. In stage two, the query set is simply forwarded and all the results are gathered.

Parameters
  • model (MetaTestParallel) – Model to be meta tested.

  • support_dataloader (DataLoader) – A PyTorch dataloader of support data.

  • query_dataloader (DataLoader) – A PyTorch dataloader of query data.

  • meta_test_cfg (dict) – Config for meta testing.

Returns

  • results_list (list[np.ndarray]): Predicted results.

  • gt_labels (np.ndarray): Ground truth labels.

Return type

tuple
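The two stages can be illustrated with a toy, framework-free sketch. `ToyNearestMeanHead` and `run_single_task` are hypothetical; they only borrow the hook names (before_forward_support, forward_support, before_forward_query, forward_query) documented for the heads further down, and plain lists stand in for tensors and dataloaders. A real metric-based head works on batched tensors.

```python
from typing import Dict, List, Sequence, Tuple


class ToyNearestMeanHead:
    """Hypothetical metric-based head: predicts the nearest class mean."""

    def before_forward_support(self) -> None:
        # Reset per-task state so tasks do not leak into each other.
        self.feats: Dict[int, List[Sequence[float]]] = {}

    def forward_support(self, x: Sequence[float], gt_label: int) -> None:
        self.feats.setdefault(gt_label, []).append(x)

    def before_forward_query(self) -> None:
        # Summarize the support set into one mean vector per class.
        self.means = {
            label: [sum(col) / len(col) for col in zip(*vecs)]
            for label, vecs in self.feats.items()
        }

    def forward_query(self, x: Sequence[float]) -> int:
        def sq_dist(m: Sequence[float]) -> float:
            return sum((a - b) ** 2 for a, b in zip(x, m))
        return min(self.means, key=lambda label: sq_dist(self.means[label]))


def run_single_task(head, support, query) -> Tuple[List[int], List[int]]:
    # Stage one: handle the support set.
    head.before_forward_support()
    for feat, label in support:
        head.forward_support(feat, label)
    # Stage two: forward the query set and gather the results.
    head.before_forward_query()
    preds = [head.forward_query(feat) for feat, _ in query]
    gt_labels = [label for _, label in query]
    return preds, gt_labels


support = [([0.0, 0.0], 0), ([0.2, 0.0], 0), ([1.0, 1.0], 1), ([0.8, 1.0], 1)]
query = [([0.1, 0.1], 0), ([0.9, 0.9], 1)]
print(run_single_task(ToyNearestMeanHead(), support, query))
```

A fine-tune based method would replace the support loop with a few optimization steps on the support data; the query stage stays the same.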

classification.core

evaluation

class mmfewshot.classification.core.evaluation.DistMetaTestEvalHook(support_dataloader: torch.utils.data.dataloader.DataLoader, query_dataloader: torch.utils.data.dataloader.DataLoader, test_set_dataloader: torch.utils.data.dataloader.DataLoader, num_test_tasks: int, interval: int = 1, by_epoch: bool = True, meta_test_cfg: Optional[Dict] = None, confidence_interval: float = 0.95, save_best: bool = True, key_indicator: str = 'accuracy_mean', **eval_kwargs)

Distributed evaluation hook.

class mmfewshot.classification.core.evaluation.MetaTestEvalHook(support_dataloader: torch.utils.data.dataloader.DataLoader, query_dataloader: torch.utils.data.dataloader.DataLoader, test_set_dataloader: torch.utils.data.dataloader.DataLoader, num_test_tasks: int, interval: int = 1, by_epoch: bool = True, meta_test_cfg: Optional[Dict] = None, confidence_interval: float = 0.95, save_best: bool = True, key_indicator: str = 'accuracy_mean', **eval_kwargs)

Evaluation hook for meta testing.

Parameters
  • support_dataloader (DataLoader) – A PyTorch dataloader of support data.

  • query_dataloader (DataLoader) – A PyTorch dataloader of query data.

  • test_set_dataloader (DataLoader) – A PyTorch dataloader of all test data.

  • num_test_tasks (int) – Number of tasks for meta testing.

  • interval (int) – Evaluation interval (in epochs or iterations). Default: 1.

  • by_epoch (bool) – Whether the runner is epoch based. Default: True.

  • meta_test_cfg (dict) – Config for meta testing. Default: None.

  • confidence_interval (float) – Confidence interval. Default: 0.95.

  • save_best (bool) – Whether to save best validated model. Default: True.

  • key_indicator (str) – The validation metric for selecting the best model. Default: ‘accuracy_mean’.

  • eval_kwargs – Any keyword argument to be used for evaluation.

classification.datasets

class mmfewshot.classification.datasets.BaseFewShotDataset(data_prefix: str, pipeline: List[Dict], classes: Optional[Union[str, List[str]]] = None, ann_file: Optional[str] = None)

Base few-shot dataset.

Parameters
  • data_prefix (str) – The prefix of data path.

  • pipeline (list) – A list of dicts, where each element represents an operation defined in mmcls.datasets.pipelines.

  • classes (str | Sequence[str] | None) – Classes used for model training; each class is given a fixed label. Default: None.

  • ann_file (str | None) – The annotation file. When ann_file is str, the subclass is expected to read from the ann_file. When ann_file is None, the subclass is expected to read according to data_prefix. Default: None.

property class_to_idx: Mapping

Map class names to class indices.

Returns

Mapping from class name to class index.

Return type

dict

static evaluate(results: List, gt_labels: numpy.array, metric: Union[str, List[str]] = 'accuracy', metric_options: Optional[dict] = None, logger: Optional[object] = None) → Dict

Evaluate the dataset.

Parameters
  • results (list) – Testing results of the dataset.

  • gt_labels (np.ndarray) – Ground truth labels.

  • metric (str | list[str]) – Metrics to be evaluated. Default value is accuracy.

  • metric_options (dict | None) – Options for calculating metrics. Allowed keys are ‘topk’, ‘thrs’ and ‘average_mode’. Default: None.

  • logger (logging.Logger | None) – Logger used for printing related information during evaluation. Default: None.

Returns

Evaluation results.

Return type

dict

classmethod get_classes(classes: Optional[Union[Sequence[str], str]] = None) → Sequence[str]

Get the class names of the current dataset.

Parameters

classes (Sequence[str] | str | None) –

Three types of input are handled differently:

  • If classes is a tuple or list, it overrides the CLASSES predefined in the dataset.

  • If classes is None, the CLASSES predefined in the dataset are used directly.

  • If classes is a string, it is the path of a classes file that contains the names of all classes, one class name per line.

Returns

Names of categories of the dataset.

Return type

tuple[str] or list[str]
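The three cases of get_classes, plus the class_to_idx mapping described above, can be sketched as standalone functions. `resolve_classes`, `class_to_idx` and `DEFAULT_CLASSES` are hypothetical stand-ins for the method, the property and the dataset's predefined CLASSES; skipping blank lines in the classes file is an assumption of this sketch.

```python
import os
import tempfile
from typing import Dict, Optional, Sequence, Union

# Hypothetical stand-in for a dataset's predefined CLASSES.
DEFAULT_CLASSES = ('cat', 'dog', 'bird')


def resolve_classes(
        classes: Optional[Union[Sequence[str], str]] = None) -> Sequence[str]:
    if classes is None:
        # Use the predefined classes unchanged.
        return DEFAULT_CLASSES
    if isinstance(classes, str):
        # A path to a file with one class name per line.
        with open(classes) as f:
            return [line.strip() for line in f if line.strip()]
    if isinstance(classes, (tuple, list)):
        # Override the predefined classes.
        return classes
    raise ValueError(f'Unsupported type {type(classes)} of classes.')


def class_to_idx(classes: Sequence[str]) -> Dict[str, int]:
    # Each class name is mapped to its position in the class list.
    return {name: i for i, name in enumerate(classes)}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'classes.txt')
    with open(path, 'w') as f:
        f.write('fox\nwolf\n')
    print(resolve_classes(path))        # ['fox', 'wolf']
print(class_to_idx(resolve_classes()))  # {'cat': 0, 'dog': 1, 'bird': 2}
```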

sample_shots_by_class_id(class_id: int, num_shots: int) → List[int]

Randomly sample shots of the given class id.
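A minimal sketch of per-class shot sampling, assuming the dataset keeps an index from class id to sample indices and returns sampled indices (the List[int] return type suggests indices). `build_class_index` and `sample_shots` are hypothetical helpers; how the real method handles classes with too few samples is not documented here.

```python
import random
from collections import defaultdict
from typing import Dict, List, Optional


def build_class_index(labels: List[int]) -> Dict[int, List[int]]:
    """Group sample indices by their class id."""
    index: Dict[int, List[int]] = defaultdict(list)
    for i, label in enumerate(labels):
        index[label].append(i)
    return dict(index)


def sample_shots(class_index: Dict[int, List[int]], class_id: int,
                 num_shots: int,
                 rng: Optional[random.Random] = None) -> List[int]:
    """Randomly pick num_shots distinct sample indices of one class."""
    rng = rng or random.Random()
    return rng.sample(class_index[class_id], num_shots)


labels = [0, 1, 0, 2, 1, 0, 2, 2]
class_index = build_class_index(labels)
shots = sample_shots(class_index, class_id=0, num_shots=2, rng=random.Random(0))
print(shots)  # two distinct indices drawn from [0, 2, 5]
```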

class mmfewshot.classification.datasets.CUBDataset(classes_id_seed: Optional[int] = None, subset: typing_extensions.Literal[train, test, val] = 'train', *args, **kwargs)

CUB dataset for few-shot classification.

Parameters
  • classes_id_seed (int | None) – A random seed to shuffle the order of classes. If seed is None, the classes are arranged in alphabetical order. Default: None.

  • subset (str | list[str]) – The classes of the whole dataset are split into three disjoint subsets: train, val and test. If subset is a string, only that subset is loaded. If subset is a list of strings, data from all subsets in the list are loaded. Options: [‘train’, ‘val’, ‘test’]. Default: ‘train’.
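The classes_id_seed rule can be sketched as follows. `order_classes` is a hypothetical helper, and it assumes the shuffle uses Python's `random` module; the library may use a different RNG, so the concrete shuffled order here will not match mmfewshot's. The point is only that a fixed seed gives a fixed order.

```python
import random
from typing import List, Optional, Sequence


def order_classes(class_names: Sequence[str],
                  classes_id_seed: Optional[int] = None) -> List[str]:
    ordered = sorted(class_names)  # alphabetical order by default
    if classes_id_seed is not None:
        # A fixed seed gives the same shuffled order on every run.
        random.Random(classes_id_seed).shuffle(ordered)
    return ordered


names = ['sparrow', 'albatross', 'wren', 'gull']
print(order_classes(names))  # ['albatross', 'gull', 'sparrow', 'wren']
print(order_classes(names, 42) == order_classes(names, 42))  # True
```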

get_classes(classes: Optional[Union[Sequence[str], str]] = None) → Sequence[str]

Get the class names of the current dataset.

Parameters

classes (Sequence[str] | str | None) –

Three types of input are handled differently:

  • If classes is a tuple or list, it overrides the CLASSES predefined in the dataset.

  • If classes is None, the CLASSES predefined in the dataset are used directly.

  • If classes is a string, it is the path of a classes file that contains the names of all classes, one class name per line.

Returns

Names of categories of the dataset.

Return type

tuple[str] or list[str]

load_annotations() → List[Dict]

Load annotations according to the class subset.

class mmfewshot.classification.datasets.EpisodicDataset(dataset: Dataset, num_episodes: int, num_ways: int, num_shots: int, num_queries: int, episodes_seed: int | None = None)

A wrapper that turns a dataset into an episodic dataset.

It generates a list of support and query image indices for each episode (support + query images). Each call of __getitem__ fetches and returns (num_ways * num_shots) support images and (num_ways * num_queries) query images according to the generated image indices. Note that all episode indices are generated at once using a specific random seed, to ensure reproducibility on the same dataset.

Parameters
  • dataset (Dataset) – The dataset to be wrapped.

  • num_episodes (int) – Number of episodes. Note that all episodes are generated at once and are not changed afterwards, so make sure num_episodes is set larger than you need.

  • num_ways (int) – Number of ways for each episode.

  • num_shots (int) – Number of support data of each way for each episode.

  • num_queries (int) – Number of query data of each way for each episode.

  • episodes_seed (int | None) – A random seed to reproduce episodic indices. If seed is None, it will use runtime random seed. Default: None.
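The episode generation described above can be sketched with the standard library. `generate_episodes` is hypothetical and returns plain dicts of indices rather than mmfewshot's internal structures; it assumes every class has at least num_shots + num_queries samples and draws the support and query images of a class together so they never overlap.

```python
import random
from collections import defaultdict
from typing import Dict, List, Optional


def generate_episodes(labels: List[int], num_episodes: int, num_ways: int,
                      num_shots: int, num_queries: int,
                      episodes_seed: Optional[int] = None
                      ) -> List[Dict[str, List[int]]]:
    rng = random.Random(episodes_seed)  # None falls back to a runtime seed
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    class_ids = sorted(by_class)

    episodes = []
    # All episodes are generated up front, as the text above describes.
    for _ in range(num_episodes):
        support, query = [], []
        for class_id in rng.sample(class_ids, num_ways):
            # Draw shots and queries together so they never overlap.
            picked = rng.sample(by_class[class_id], num_shots + num_queries)
            support += picked[:num_shots]
            query += picked[num_shots:]
        episodes.append({'support': support, 'query': query})
    return episodes


labels = [i % 5 for i in range(50)]  # 5 classes, 10 images each
eps = generate_episodes(labels, num_episodes=3, num_ways=2, num_shots=1,
                        num_queries=3, episodes_seed=0)
print(len(eps[0]['support']), len(eps[0]['query']))  # 2 6
```

Re-running with the same episodes_seed reproduces exactly the same episodes, which is what makes the meta-test results comparable across runs.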

evaluate(*args, **kwargs) → list

Evaluate predictions.

generate_episodic_idxes() → tuple[list[Mapping], list[list[int]]]

Generate batch indices for each episode.

get_episode_class_ids(idx: int) → list[int]

Return the class ids in one episode.

class mmfewshot.classification.datasets.LoadImageFromBytes(to_float32=False, color_type='color', file_client_args={'backend': 'disk'})

Load an image from bytes.

class mmfewshot.classification.datasets.MetaTestDataset(*args, **kwargs)

A wrapper of the episodic dataset for meta testing.

During meta testing, the MetaTestDataset is copied and converted into three modes: test_set, support, and query. Each mode of the dataset is used by a different dataloader, but they share the same episode and image information.

  • In test_set mode, the dataset fetches all images from the whole test set to extract features with the fixed backbone, which can accelerate meta testing.

  • In support or query mode, the dataset fetches images according to the episode_idxes with the same task_id. Therefore, the support and query datasets must be set to the same task_id in each test task.

cache_feats(feats: torch.Tensor, img_metas: dict) → None

Cache extracted features in the dataset.

set_task_id(task_id: int) → None

Set the task id. The query and support datasets use the same task id to make sure they fetch data from the same episode.

class mmfewshot.classification.datasets.MiniImageNetDataset(subset: typing_extensions.Literal[train, test, val] = 'train', file_format: str = 'JPEG', *args, **kwargs)

MiniImageNet dataset for few-shot classification.

Parameters
  • subset (str | list[str]) – The classes of the whole dataset are split into three disjoint subsets: train, val and test. If subset is a string, only that subset is loaded. If subset is a list of strings, data from all subsets in the list are loaded. Options: [‘train’, ‘val’, ‘test’]. Default: ‘train’.

  • file_format (str) – The file format of the image. Default: ‘JPEG’.

get_classes(classes: Optional[Union[Sequence[str], str]] = None) → Sequence[str]

Get the class names of the current dataset.

Parameters

classes (Sequence[str] | str | None) –

Three types of input are handled differently:

  • If classes is a tuple or list, it overrides the CLASSES predefined in the dataset.

  • If classes is None, the CLASSES predefined in the dataset are used directly.

  • If classes is a string, it is the path of a classes file that contains the names of all classes, one class name per line.

Returns

Names of categories of the dataset.

Return type

tuple[str] or list[str]

load_annotations() → List

Load annotations according to the class subset.

class mmfewshot.classification.datasets.TieredImageNetDataset(subset: typing_extensions.Literal[train, test, val] = 'train', *args, **kwargs)

TieredImageNet dataset for few-shot classification.

Parameters

subset (str | list[str]) – The classes of the whole dataset are split into three disjoint subsets: train, val and test. If subset is a string, only that subset is loaded. If subset is a list of strings, data from all subsets in the list are loaded. Options: [‘train’, ‘val’, ‘test’]. Default: ‘train’.

get_classes(classes: Optional[Union[Sequence[str], str]] = None) → Sequence[str]

Get the class names of the current dataset.

Parameters

classes (Sequence[str] | str | None) –

Three types of input are handled differently:

  • If classes is a tuple or list, it overrides the CLASSES predefined in the dataset.

  • If classes is None, the CLASSES predefined in the dataset are used directly.

  • If classes is a string, it is the path of a classes file that contains the names of all classes, one class name per line.

Returns

Names of categories of the dataset.

Return type

tuple[str] or list[str]

get_general_classes() → List[str]

Get the general class of each class.

load_annotations() → List[Dict]

Load annotations according to the class subset.

mmfewshot.classification.datasets.build_dataloader(dataset: torch.utils.data.dataset.Dataset, samples_per_gpu: int, workers_per_gpu: int, num_gpus: int = 1, dist: bool = True, shuffle: bool = True, round_up: bool = True, seed: Optional[int] = None, pin_memory: bool = False, use_infinite_sampler: bool = False, **kwargs) → torch.utils.data.dataloader.DataLoader

Build a PyTorch DataLoader.

In distributed training, each GPU/process has a dataloader. In non-distributed training, there is only one dataloader for all GPUs.

Parameters
  • dataset (Dataset) – A PyTorch dataset.

  • samples_per_gpu (int) – Number of training samples on each GPU, i.e., batch size of each GPU.

  • workers_per_gpu (int) – How many subprocesses to use for data loading for each GPU.

  • num_gpus (int) – Number of GPUs. Only used in non-distributed training.

  • dist (bool) – Distributed training/test or not. Default: True.

  • shuffle (bool) – Whether to shuffle the data at every epoch. Default: True.

  • round_up (bool) – Whether to round up the length of dataset by adding extra samples to make it evenly divisible. Default: True.

  • seed (int | None) – Random seed. Default: None.

  • pin_memory (bool) – Whether to use pin_memory for the dataloader. Default: False.

  • use_infinite_sampler (bool) – Whether to use an infinite sampler. Note that an infinite sampler keeps the dataloader iterator running forever, which avoids the overhead of worker initialization between epochs. Default: False.

  • kwargs – Any keyword arguments used to initialize the DataLoader.

Returns

A PyTorch dataloader.

Return type

DataLoader
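A minimal sketch of what round_up means for distributed sampling. It assumes the padding strategy of PyTorch's `DistributedSampler`: repeat indices from the start until the length divides evenly by the number of replicas, then give each rank every num_replicas-th index. `split_indices` is hypothetical, shuffling is omitted, and mmfewshot's samplers may differ in detail.

```python
import math
from typing import List


def split_indices(dataset_len: int, num_replicas: int, rank: int,
                  round_up: bool = True) -> List[int]:
    indices = list(range(dataset_len))
    if round_up:
        total_size = math.ceil(dataset_len / num_replicas) * num_replicas
        # Repeat indices from the start until the list divides evenly.
        while len(indices) < total_size:
            indices += indices[:total_size - len(indices)]
    # Each rank takes every num_replicas-th index, starting at its rank.
    return indices[rank::num_replicas]


print([split_indices(10, 4, r) for r in range(4)])
# [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```

With round_up=False the ranks receive uneven shares (here 3, 3, 2 and 2 samples), which can stall collective operations that expect equal batch counts.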

mmfewshot.classification.datasets.build_meta_test_dataloader(dataset: torch.utils.data.dataset.Dataset, meta_test_cfg: Dict, **kwargs) → torch.utils.data.dataloader.DataLoader

Build PyTorch DataLoaders for meta testing.

In distributed training, each GPU/process has a dataloader. In non-distributed training, there is only one dataloader for all GPUs.

Parameters
  • dataset (Dataset) – A PyTorch dataset.

  • meta_test_cfg (dict) – Config of meta testing.

  • kwargs – Any keyword arguments used to initialize the DataLoader.

Returns

support_data_loader, query_data_loader and test_set_data_loader.

Return type

tuple[DataLoader]

mmfewshot.classification.datasets.label_wrapper(labels: Union[torch.Tensor, numpy.ndarray, List], class_ids: List[int]) → Union[torch.Tensor, numpy.ndarray, list]

Map input labels into the range 0 to num_classes - 1.

It is usually used in the meta testing phase, in which the class ids are randomly sampled and discontinuous.

Parameters
  • labels (Tensor | np.ndarray | list) – The labels to be wrapped.

  • class_ids (list[int]) – All class ids of labels.

Returns

The wrapped labels, of the same type as the input labels.

Return type

(Tensor | np.ndarray | list)
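A list-only sketch of the mapping, assuming each label is replaced by the position of its class id in class_ids. `wrap_labels` is hypothetical; the real function also accepts tensors and arrays and returns the same type as its input.

```python
from typing import List


def wrap_labels(labels: List[int], class_ids: List[int]) -> List[int]:
    # The position of each sampled class id in class_ids is its new label.
    id_to_new = {class_id: i for i, class_id in enumerate(class_ids)}
    return [id_to_new[label] for label in labels]


# A 3-way episode that sampled the discontinuous class ids 7, 42 and 13.
print(wrap_labels([42, 7, 13, 42], class_ids=[7, 42, 13]))  # [1, 0, 2, 1]
```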

classification.models

backbones

class mmfewshot.classification.models.backbones.Conv4(depth: int = 4, pooling_blocks: Sequence[int] = (0, 1, 2, 3), padding_blocks: Sequence[int] = (0, 1, 2, 3), flatten: bool = True)

class mmfewshot.classification.models.backbones.ConvNet(depth: int, pooling_blocks: Sequence[int], padding_blocks: Sequence[int], flatten: bool = True)

Simple ConvNet.

Parameters
  • depth (int) – The number of ConvBlocks.

  • pooling_blocks (Sequence[int]) – Indices of the blocks that use 2x2 max pooling.

  • padding_blocks (Sequence[int]) – Indices of the blocks whose conv layer uses padding.

  • flatten (bool) – Whether to flatten features from (N, C, H, W) to (N, C*H*W). Default: True.
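How pooling_blocks, padding_blocks and flatten affect the per-sample output shape can be worked out with a small helper. This sketch assumes each ConvBlock uses a 3x3 convolution with stride 1 (padding 1 when padded, none otherwise) and 64 output channels, and that 2x2 max pooling floors odd sizes; these block internals are assumptions, not taken from this page, and `convnet_output_shape` is hypothetical.

```python
from typing import Sequence, Tuple


def convnet_output_shape(in_hw: Tuple[int, int], depth: int,
                         pooling_blocks: Sequence[int],
                         padding_blocks: Sequence[int],
                         channels: int = 64,
                         flatten: bool = True) -> Tuple[int, ...]:
    h, w = in_hw
    for i in range(depth):
        if i not in padding_blocks:
            # An unpadded 3x3 conv shrinks each side by 2.
            h, w = h - 2, w - 2
        if i in pooling_blocks:
            # 2x2 max pooling with stride 2 halves each side (floor).
            h, w = h // 2, w // 2
    if flatten:
        return (channels * h * w,)
    return (channels, h, w)


# Conv4 defaults on an 84x84 image: 84 -> 42 -> 21 -> 10 -> 5.
print(convnet_output_shape((84, 84), 4, (0, 1, 2, 3), (0, 1, 2, 3)))  # (1600,)
```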

forward(x: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class mmfewshot.classification.models.backbones.ResNet12(block: torch.nn.modules.module.Module = <class 'mmfewshot.classification.models.backbones.resnet12.BasicBlock'>, with_avgpool: bool = True, pool_size: Tuple[int, int] = (1, 1), flatten: bool = True, drop_rate: float = 0.0, drop_block_size: int = 5)

ResNet12.

Parameters
  • block (nn.Module) – Block to build layers. Default: BasicBlock.

  • with_avgpool (bool) – Whether to average pool the features. Default: True.

  • pool_size (tuple(int,int)) – The output shape of average pooling layer. Default: (1, 1).

  • flatten (bool) – Whether to flatten features from (N, C, H, W) to (N, C*H*W). Default: True.

  • drop_rate (float) – Dropout rate. Default: 0.0.

  • drop_block_size (int) – Size of drop block. Default: 5.

forward(x: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class mmfewshot.classification.models.backbones.WRN28x10(depth: int = 28, widen_factor: int = 10, stride: int = 1, drop_rate: float = 0.5, flatten: bool = True, with_avgpool: bool = True, pool_size: Tuple[int, int] = (1, 1))

class mmfewshot.classification.models.backbones.WideResNet(depth: int, widen_factor: int = 1, stride: int = 1, drop_rate: float = 0.0, flatten: bool = True, with_avgpool: bool = True, pool_size: Tuple[int, int] = (1, 1))

WideResNet.

Parameters
  • depth (int) – The number of layers.

  • widen_factor (int) – The widening factor of the channels. Default: 1.

  • stride (int) – Stride of first layer. Default: 1.

  • drop_rate (float) – Dropout rate. Default: 0.0.

  • with_avgpool (bool) – Whether to average pool the features. Default: True.

  • flatten (bool) – Whether to flatten features from (N, C, H, W) to (N, C*H*W). Default: True.

  • pool_size (tuple(int,int)) – The output shape of average pooling layer. Default: (1, 1).
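The depth and widen_factor arguments can be sanity-checked with a little arithmetic. This sketch assumes the standard layout of the original Wide ResNet paper (a 16-channel stem, three groups of basic blocks with (depth - 4) / 6 blocks each, and widths 16k, 32k and 64k for widen factor k); mmfewshot's exact layer layout is not spelled out on this page, and `wrn_layout` is hypothetical.

```python
from typing import Dict, List


def wrn_layout(depth: int, widen_factor: int = 1) -> Dict[str, object]:
    if (depth - 4) % 6 != 0:
        raise ValueError('depth should be 6n + 4')
    blocks_per_group = (depth - 4) // 6
    widths: List[int] = [16, 16 * widen_factor, 32 * widen_factor,
                         64 * widen_factor]
    # With average pooling to (1, 1) and flatten=True, the feature
    # dimension equals the width of the last group.
    return {'blocks_per_group': blocks_per_group, 'widths': widths,
            'feat_dim': widths[-1]}


print(wrn_layout(28, 10))
# {'blocks_per_group': 4, 'widths': [16, 160, 320, 640], 'feat_dim': 640}
```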

forward(x: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

classifier

heads

class mmfewshot.classification.models.heads.CosineDistanceHead(num_classes: int, in_channels: int, temperature: Optional[float] = None, eps: float = 1e-05, *args, **kwargs)

Classification head for Baseline++ (https://arxiv.org/abs/1904.04232).

Parameters
  • num_classes (int) – Number of categories.

  • in_channels (int) – Number of channels in the input feature map.

  • temperature (float | None) – Scaling factor of cls_score. Default: None.

  • eps (float) – A small constant to avoid division by zero. Default: 1e-05.
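Baseline++ scores a feature by its cosine similarity to a weight vector per class. This list-based sketch assumes that both the feature and each class weight are L2-normalized, with eps added to the norm, and that the similarities are multiplied by temperature; where exactly eps enters in the real head is an assumption, and `cosine_scores` is hypothetical.

```python
import math
from typing import List, Sequence


def cosine_scores(x: Sequence[float], weights: Sequence[Sequence[float]],
                  temperature: float, eps: float = 1e-5) -> List[float]:
    def normalize(v: Sequence[float]) -> List[float]:
        norm = math.sqrt(sum(a * a for a in v))
        return [a / (norm + eps) for a in v]

    x_n = normalize(x)
    # One score per class: scaled cosine similarity to that class's weight.
    return [temperature * sum(a * b for a, b in zip(x_n, normalize(w)))
            for w in weights]


scores = cosine_scores([3.0, 4.0], [[1.0, 0.0], [0.0, 2.0]], temperature=10.0)
print([round(s, 3) for s in scores])  # [6.0, 8.0]
```

Because the weights are normalized, the magnitude of a class weight no longer biases the score, which is the main difference from a plain linear head.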

before_forward_query() → None

Used in meta testing.

This function is called before the model forwards query data during meta testing.

before_forward_support() → None

Used in meta testing.

This function is called before the model forwards support data during meta testing.

forward_query(x: torch.Tensor, **kwargs) → List

Forward query data in meta testing.

forward_support(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → Dict

Forward support data in meta testing.

forward_train(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → Dict

Forward training data.

class mmfewshot.classification.models.heads.LinearHead(num_classes: int, in_channels: int, *args, **kwargs)

Classification head for Baseline.

Parameters
  • num_classes (int) – Number of categories.

  • in_channels (int) – Number of channels in the input feature map.

before_forward_query() → None

Used in meta testing.

This function is called before the model forwards query data during meta testing.

before_forward_support() → None

Used in meta testing.

This function is called before the model forwards support data during meta testing.

forward_query(x: torch.Tensor, **kwargs) → List

Forward query data in meta testing.

forward_support(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → Dict

Forward support data in meta testing.

forward_train(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → Dict

Forward training data.

class mmfewshot.classification.models.heads.MatchingHead(temperature: float = 100, loss: Dict = {'loss_weight': 1.0, 'type': 'NLLLoss'}, *args, **kwargs)

Classification head for MatchingNet (https://arxiv.org/abs/1606.04080).

Note that this implementation does not include FCE (Full Context Embeddings).

Parameters
  • temperature (float) – The scale factor of cls_score. Default: 100.

  • loss (dict) – Config of the training loss. Default: NLLLoss with loss_weight 1.0.
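A rough sketch of matching-network prediction without FCE: the query is compared with every support feature by cosine similarity, the similarities are scaled by temperature and turned into attention weights with a softmax, and the attention mass falling on each class is that class's probability (an NLLLoss would then be applied to the log of these probabilities). `matching_probs` is hypothetical and the exact normalization in the head may differ.

```python
import math
from typing import List, Sequence


def matching_probs(query: Sequence[float], support: Sequence[Sequence[float]],
                   support_labels: Sequence[int], num_classes: int,
                   temperature: float = 100.0) -> List[float]:
    def cos(a: Sequence[float], b: Sequence[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) *
                      math.sqrt(sum(y * y for y in b)))

    logits = [temperature * cos(query, s) for s in support]
    # Numerically stable softmax over the support samples.
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    # Sum the attention weights of the support samples of each class.
    probs = [0.0] * num_classes
    for weight, label in zip(exps, support_labels):
        probs[label] += weight / total
    return probs


probs = matching_probs([1.0, 0.1], [[1.0, 0.0], [0.0, 1.0], [0.9, 0.2]],
                       [0, 1, 0], num_classes=2)
print(probs.index(max(probs)))  # 0
```

The large default temperature (100) makes the softmax sharp, so the prediction is dominated by the most similar support samples.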

before_forward_query() → None

Used in meta testing.

This function is called before the model forwards query data during meta testing.

before_forward_support() → None

Used in meta testing.

This function is called before the model forwards support data during meta testing.

forward_query(x: torch.Tensor, **kwargs) → List

Forward query data in meta testing.

forward_support(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → None

Forward support data in meta testing.

forward_train(support_feats: torch.Tensor, support_labels: torch.Tensor, query_feats: torch.Tensor, query_labels: torch.Tensor, **kwargs) → Dict

Forward training data.

Parameters
  • support_feats (Tensor) – Features of support data with shape (N, C).

  • support_labels (Tensor) – Labels of support data with shape (N).

  • query_feats (Tensor) – Features of query data with shape (N, C).

  • query_labels (Tensor) – Labels of query data with shape (N).

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

class mmfewshot.classification.models.heads.MetaBaselineHead(temperature: float = 10.0, learnable_temperature: bool = True, *args, **kwargs)

Classification head for MetaBaseline (https://arxiv.org/abs/2003.04390).

Parameters
  • temperature (float) – Scaling factor of cls_score. Default: 10.0.

  • learnable_temperature (bool) – Whether to use learnable scale factor or not. Default: True.
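Meta-Baseline classifies a query by the cosine similarity between its feature and each class prototype (the mean support feature of that class), multiplied by temperature. A list-based sketch; `class_means` and `meta_baseline_logits` are hypothetical, and with learnable_temperature=True the constant here would instead be a trained scalar.

```python
import math
from typing import Dict, List, Sequence


def class_means(feats: Sequence[Sequence[float]],
                labels: Sequence[int]) -> Dict[int, List[float]]:
    groups: Dict[int, List[Sequence[float]]] = {}
    for f, y in zip(feats, labels):
        groups.setdefault(y, []).append(f)
    return {y: [sum(col) / len(col) for col in zip(*fs)]
            for y, fs in groups.items()}


def meta_baseline_logits(query: Sequence[float],
                         prototypes: Dict[int, List[float]],
                         temperature: float = 10.0) -> Dict[int, float]:
    def unit(v: Sequence[float]) -> List[float]:
        n = math.sqrt(sum(a * a for a in v))
        return [a / n for a in v]

    q = unit(query)
    # Scaled cosine similarity between the query and each prototype.
    return {y: temperature * sum(a * b for a, b in zip(q, unit(p)))
            for y, p in prototypes.items()}


protos = class_means([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]], [0, 0, 1])
logits = meta_baseline_logits([2.0, 1.0], protos)
print({y: round(v, 3) for y, v in logits.items()})  # {0: 8.944, 1: 4.472}
```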

before_forward_query() → None

Used in meta testing.

This function is called before the model forwards query data during meta testing.

before_forward_support() → None

Used in meta testing.

This function is called before the model forwards support data during meta testing.

forward_query(x: torch.Tensor, **kwargs) → List

Forward query data in meta testing.

forward_support(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → None

Forward support data in meta testing.

forward_train(support_feats: torch.Tensor, support_labels: torch.Tensor, query_feats: torch.Tensor, query_labels: torch.Tensor, **kwargs) → Dict

Forward training data.

Parameters
  • support_feats (Tensor) – Features of support data with shape (N, C).

  • support_labels (Tensor) – Labels of support data with shape (N).

  • query_feats (Tensor) – Features of query data with shape (N, C).

  • query_labels (Tensor) – Labels of query data with shape (N).

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

class mmfewshot.classification.models.heads.NegMarginHead(num_classes: int, in_channels: int, temperature: float = 30.0, margin: float = 0.0, metric_type: str = 'cosine', *args, **kwargs)

Classification head for NegMargin.

Parameters
  • num_classes (int) – Number of categories.

  • in_channels (int) – Number of channels in the input feature map.

  • temperature (float) – Scaling factor of cls_score. Default: 30.0.

  • margin (float) – Margin of cls_score. Default: 0.0.

  • metric_type (str) – The way to calculate similarity. Options: [‘cosine’, ‘softmax’]. Default: ‘cosine’.
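In a margin-based softmax, training subtracts margin from the ground-truth class similarity before scaling by temperature, so a negative margin raises the target logit instead of lowering it. This sketch assumes the cosine metric type and applies the margin only to the ground-truth class during training; `margin_logits` is hypothetical and not the head's code.

```python
from typing import List, Sequence


def margin_logits(cosines: Sequence[float], gt_label: int,
                  temperature: float = 30.0,
                  margin: float = 0.0) -> List[float]:
    # Shift only the ground-truth class similarity, then scale every logit.
    return [temperature * (c - margin if i == gt_label else c)
            for i, c in enumerate(cosines)]


print(margin_logits([0.5, 0.2], gt_label=0, margin=-0.1))
```

With margin=0.0 (the default) this reduces to a plain scaled cosine classifier.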

before_forward_query() → None

Used in meta testing.

This function is called before the model forwards query data during meta testing.

before_forward_support() → None

Used in meta testing.

This function is called before the model forwards support data during meta testing.

forward_query(x: torch.Tensor, **kwargs) → List

Forward query data in meta testing.

forward_support(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → Dict

Forward support data in meta testing.

forward_train(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → Dict

Forward training data.

class mmfewshot.classification.models.heads.PrototypeHead(*args, **kwargs)

Classification head for ProtoNet (https://arxiv.org/abs/1703.05175).
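ProtoNet represents each class by the mean of its support features and scores a query by the negative squared Euclidean distance to each prototype. A list-based sketch of that scoring; `proto_logits` is hypothetical, and the head's batched tensor code is not shown on this page.

```python
from typing import Dict, List, Sequence


def proto_logits(support: Sequence[Sequence[float]], labels: Sequence[int],
                 query: Sequence[float]) -> Dict[int, float]:
    groups: Dict[int, List[Sequence[float]]] = {}
    for feat, y in zip(support, labels):
        groups.setdefault(y, []).append(feat)
    logits = {}
    for y, feats in groups.items():
        proto = [sum(col) / len(col) for col in zip(*feats)]
        # Closer prototypes give larger (less negative) logits.
        logits[y] = -sum((q - p) ** 2 for q, p in zip(query, proto))
    return logits


print(proto_logits([[0.0, 0.0], [2.0, 0.0], [4.0, 4.0]], [0, 0, 1], [1.0, 1.0]))
# {0: -1.0, 1: -18.0}
```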

before_forward_query() → None

Used in meta testing.

This function is called before the model forwards query data during meta testing.

before_forward_support() → None

Used in meta testing.

This function is called before the model forwards support data during meta testing.

forward_query(x: torch.Tensor, **kwargs) → List

Forward query data in meta testing.

forward_support(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → None

Forward support data in meta testing.

forward_train(support_feats: torch.Tensor, support_labels: torch.Tensor, query_feats: torch.Tensor, query_labels: torch.Tensor, **kwargs) → Dict

Forward training data.

Parameters
  • support_feats (Tensor) – Features of support data with shape (N, C).

  • support_labels (Tensor) – Labels of support data with shape (N).

  • query_feats (Tensor) – Features of query data with shape (N, C).

  • query_labels (Tensor) – Labels of query data with shape (N).

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

class mmfewshot.classification.models.heads.RelationHead(in_channels: int, feature_size: Tuple[int] = (7, 7), hidden_channels: int = 8, loss: Dict = {'loss_weight': 1.0, 'type': 'CrossEntropyLoss'}, *args, **kwargs)

Classification head for RelationNet (https://arxiv.org/abs/1711.06025).

Parameters
  • in_channels (int) – Number of channels in the input feature map.

  • feature_size (tuple(int, int)) – Size of the input feature map. Default: (7, 7).

  • hidden_channels (int) – Number of channels for the hidden fc layer. Default: 8.

  • loss (dict) – Training loss. Options are CrossEntropyLoss and MSELoss. Default: dict(type='CrossEntropyLoss', loss_weight=1.0).

before_forward_query() → None [source]

Used in meta testing.

This function is called before the model forwards query data during meta testing.

before_forward_support() → None [source]

Used in meta testing.

This function is called before the model forwards support data during meta testing.

forward_query(x: torch.Tensor, **kwargs) → List [source]

Forward query data in meta testing.

forward_relation_module(x: torch.Tensor) → torch.Tensor [source]

Forward function for relation module.

forward_support(x: torch.Tensor, gt_label: torch.Tensor, **kwargs) → None [source]

Forward support data in meta testing.

forward_train(support_feats: torch.Tensor, support_labels: torch.Tensor, query_feats: torch.Tensor, query_labels: torch.Tensor, **kwargs) → Dict [source]

Forward training data.

Parameters
  • support_feats (Tensor) – Features of support data with shape (N, C, H, W).

  • support_labels (Tensor) – Labels of support data with shape (N).

  • query_feats (Tensor) – Features of query data with shape (N, C, H, W).

  • query_labels (Tensor) – Labels of query data with shape (N).

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

init_weights() → None [source]

Initialize the weights.

losses

class mmfewshot.classification.models.losses.MSELoss(reduction: typing_extensions.Literal['none', 'mean', 'sum'] = 'mean', loss_weight: float = 1.0) [source]

Mean squared error loss.

Parameters
  • reduction (str) – The method that reduces the loss to a scalar. Options are “none”, “mean” and “sum”. Default: ‘mean’.

  • loss_weight (float) – The weight of the loss. Default: 1.0.

forward(pred: torch.Tensor, target: torch.Tensor, weight: Optional[torch.Tensor] = None, avg_factor: Optional[Union[float, int]] = None, reduction_override: Optional[str] = None) → torch.Tensor [source]

Forward function of loss.

Parameters
  • pred (Tensor) – The prediction with shape (N, *), where * means any number of additional dimensions.

  • target (Tensor) – The learning target of the prediction with shape (N, *) same as the input.

  • weight (Tensor | None) – Weight of the loss for each prediction. Default: None.

  • avg_factor (float | int | None) – Average factor that is used to average the loss. Default: None.

  • reduction_override (str | None) – The reduction method used to override the original reduction method of the loss. Options are “none”, “mean” and “sum”. Default: None.

Returns

The calculated loss.

Return type

Tensor
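The way weight, avg_factor, reduction and loss_weight interact is easiest to see on flat lists. The sketch below assumes that MSELoss follows the convention of mmdet's weight_reduce_loss: weights are applied first, avg_factor replaces the element count as the denominator for 'mean', and avg_factor is rejected for 'sum'. toy_mse_loss is a hypothetical helper and is not part of mmfewshot.

```python
from typing import List, Optional, Sequence, Union


def toy_mse_loss(
    pred: Sequence[float],
    target: Sequence[float],
    weight: Optional[Sequence[float]] = None,
    reduction: str = "mean",
    avg_factor: Optional[float] = None,
    loss_weight: float = 1.0,
) -> Union[float, List[float]]:
    """Hypothetical flat-list version of a weighted, reduced MSE loss."""
    # Element-wise squared error.
    loss = [(p - t) ** 2 for p, t in zip(pred, target)]
    # Optional per-element weights are applied before reduction.
    if weight is not None:
        loss = [v * w for v, w in zip(loss, weight)]
    if avg_factor is None:
        if reduction == "mean":
            reduced = sum(loss) / len(loss)
        elif reduction == "sum":
            reduced = sum(loss)
        elif reduction == "none":
            reduced = loss
        else:
            raise ValueError(f"unknown reduction: {reduction}")
    elif reduction == "mean":
        # avg_factor replaces the element count as the denominator.
        reduced = sum(loss) / avg_factor
    elif reduction == "none":
        reduced = loss
    else:
        raise ValueError('avg_factor cannot be used with reduction="sum"')
    # loss_weight scales the final (possibly unreduced) loss.
    if isinstance(reduced, list):
        return [loss_weight * v for v in reduced]
    return loss_weight * reduced


print(toy_mse_loss([1.0, 2.0, 3.0], [1.0, 0.0, 0.0], reduction="sum"))  # 13.0
```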

class mmfewshot.classification.models.losses.NLLLoss(reduction: typing_extensions.Literal['none', 'mean', 'sum'] = 'mean', loss_weight: float = 1.0) [source]

Negative log-likelihood loss.

Parameters
  • reduction (str) – The method that reduces the loss to a scalar. Options are “none”, “mean” and “sum”. Default: ‘mean’.

  • loss_weight (float) – The weight of the loss. Default: 1.0.

forward(pred: torch.Tensor, target: torch.Tensor, weight: Optional[torch.Tensor] = None, avg_factor: Optional[Union[float, int]] = None, reduction_override: Optional[str] = None) → torch.Tensor [source]

Forward function of loss.

Parameters
  • pred (Tensor) – The prediction with shape (N, C).

  • target (Tensor) – The learning target of the prediction with shape (N, 1).

  • weight (Tensor | None) – Weight of the loss for each prediction. Default: None.

  • avg_factor (float | int | None) – Average factor that is used to average the loss. Default: None.

  • reduction_override (str | None) – The reduction method used to override the original reduction method of the loss. Options are “none”, “mean” and “sum”. Default: None.

Returns

The calculated loss.

Return type

Tensor
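The sketch below shows what the unweighted, mean-reduced NLL computes. It assumes that pred holds log-probabilities, as with PyTorch's nll_loss, so the loss is the negated log-probability of each sample's ground-truth class, averaged over samples. For simplicity it takes the target as a flat list of class indices. Both toy_nll_loss and log_softmax are hypothetical helpers written for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax: x - (max + log(sum(exp(x - max)))).
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def toy_nll_loss(log_probs: Sequence[Sequence[float]], target: Sequence[int]) -> float:
    """Mean negative log-likelihood over N samples (hypothetical helper)."""
    # Pick the log-probability of the ground-truth class for each sample and negate it.
    per_sample = [-row[t] for row, t in zip(log_probs, target)]
    return sum(per_sample) / len(per_sample)


pred = [log_softmax([2.0, 0.0]), log_softmax([0.0, 0.0])]
loss = toy_nll_loss(pred, [0, 1])
```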

utils

mmfewshot.classification.models.utils.convert_maml_module(module: torch.nn.modules.module.Module) → torch.nn.modules.module.Module [source]

Convert a normal model to a MAML model.

It replaces nn.Linear with LinearWithFastWeight, nn.Conv2d with Conv2dWithFastWeight and nn.BatchNorm2d with BatchNorm2dWithFastWeight.

Parameters

module (nn.Module) – The module to be converted.

Returns

A MAML module.

Return type

nn.Module
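The "fast weight" idea behind these replacement layers is as follows. An inner-loop update writes adapted weights into a separate slot, and forward uses them when present, so the base weights stay untouched. Below is a minimal 1-D sketch under that assumption. ToyLinearWithFastWeight and inner_step are hypothetical names. No computation graph is kept, so the sketch only shows the weight bookkeeping, not second-order MAML gradients.

```python
from typing import Optional


class ToyLinearWithFastWeight:
    """Hypothetical 1-D linear layer y = w * x + b with optional fast weights."""

    def __init__(self, w: float, b: float) -> None:
        self.w, self.b = w, b
        # Fast weights are filled in during the inner loop and cleared afterwards.
        self.w_fast: Optional[float] = None
        self.b_fast: Optional[float] = None

    def current(self):
        # Prefer the adapted (fast) weights when they exist.
        w = self.w if self.w_fast is None else self.w_fast
        b = self.b if self.b_fast is None else self.b_fast
        return w, b

    def forward(self, x: float) -> float:
        w, b = self.current()
        return w * x + b


def inner_step(layer: ToyLinearWithFastWeight, x: float, y: float, lr: float) -> None:
    # One SGD step on (w * x + b - y) ** 2, written into the fast weights only.
    err = layer.forward(x) - y
    grad_w, grad_b = 2 * err * x, 2 * err
    w, b = layer.current()
    layer.w_fast, layer.b_fast = w - lr * grad_w, b - lr * grad_b


layer = ToyLinearWithFastWeight(w=1.0, b=0.0)
inner_step(layer, x=1.0, y=3.0, lr=0.25)
print(layer.forward(1.0), layer.w)  # 3.0 1.0
```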

classification.utils

class mmfewshot.classification.utils.MetaTestParallel(module: torch.nn.modules.module.Module, dim: int = 0) [source]

The MetaTestParallel module that supports DataContainer.

Note that each task is tested on a single GPU, so the data and model on different GPUs should be independent. MMDistributedDataParallel always synchronizes gradients across GPUs during the loss backward pass, which does not meet this requirement. Therefore the module is simply copied and wrapped with a MetaTestParallel, which sends the data to the device of the model.

MetaTestParallel has two main differences from PyTorch DataParallel:

  • It supports a custom type DataContainer which allows more flexible control of input data during both GPU and CPU inference.

  • It implements three more APIs: before_meta_test(), before_forward_support() and before_forward_query().

Parameters
  • module (nn.Module) – Module to be encapsulated.

  • dim (int) – Dimension used to scatter the data. Defaults to 0.

forward(*inputs, **kwargs) [source]

Override the original forward function.

The main difference lies in the CPU inference where the data in DataContainers will still be gathered.

mmfewshot.detection

detection.apis

mmfewshot.detection.apis.inference_detector(model: torch.nn.modules.module.Module, imgs: Union[List[str], str]) → List [source]

Run inference on images with the detector.

Parameters
  • model (nn.Module) – Detector.

  • imgs (list[str] | str) – A batch of image files or a single image file.

Returns

If imgs is a list or tuple, a list of results of the same length is returned; otherwise the detection results are returned directly.

Return type

list

mmfewshot.detection.apis.init_detector(config: Union[str, mmcv.utils.config.Config], checkpoint: Optional[str] = None, device: str = 'cuda:0', cfg_options: Optional[Dict] = None, classes: Optional[List[str]] = None) → torch.nn.modules.module.Module [source]

Prepare a detector from a config file.

Parameters
  • config (str | mmcv.Config) – Config file path or the config object.

  • checkpoint (str | None) – Checkpoint path. If left as None, the model will not load any weights. Default: None.

  • device (str) – Runtime device. Default: ‘cuda:0’.

  • cfg_options (dict | None) – Options to override some settings in the used config. Default: None.

  • classes (list[str] | None) – Options to override the class names of the model. Default: None.

Returns

The constructed detector.

Return type

nn.Module

mmfewshot.detection.apis.multi_gpu_model_init(model: torch.nn.modules.module.Module, data_loader: torch.utils.data.dataloader.DataLoader) → List [source]

Forward support images for meta-learning based detector initialization.

This function is usually called before single_gpu_test in QuerySupportEvalHook. It first forwards support images with mode=model_init, and the features are saved in the model. Then it calls model_init() to process the extracted support features and finish the model initialization.

Note that data_loader should NOT use a distributed sampler: the models on all GPUs must be initialized with the same images.

Parameters
  • model (nn.Module) – Model used for extracting support template features.

  • data_loader (DataLoader) – PyTorch data loader.

Returns

Extracted support template features.

Return type

list[Tensor]

mmfewshot.detection.apis.multi_gpu_test(model: torch.nn.modules.module.Module, data_loader: torch.utils.data.dataloader.DataLoader, tmpdir: Optional[str] = None, gpu_collect: bool = False) → List [source]

Test the model with multiple GPUs for meta-learning based detectors.

The model forward function requires mode, whereas in mmdet it requires return_loss; encode_mask_results is also removed. This method tests the model with multiple GPUs and collects the results in one of two modes: gpu or cpu. With gpu_collect=True, results are encoded as GPU tensors and collected via GPU communication. In cpu mode, the results from each GPU are saved to tmpdir and collected by the rank 0 worker.

Parameters
  • model (nn.Module) – Model to be tested.

  • data_loader (DataLoader) – PyTorch data loader.

  • tmpdir (str) – Path of directory to save the temporary results from different gpus under cpu mode. Default: None.

  • gpu_collect (bool) – Option to use either gpu or cpu to collect results. Default: False.

Returns

The prediction results.

Return type

list
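In cpu mode, rank 0 has to reassemble the per-GPU partial results in dataset order. The sketch below simulates this in a single process. It is modeled on the tmpdir-based collection pattern used in mmdet/mmcv test utilities, and assumes a non-shuffling distributed sampler that hands sample i to rank i % world_size and pads the last round. dump_part and collect_parts are hypothetical helpers. The real code also synchronizes ranks with barriers, which is omitted here.

```python
import os
import pickle
import tempfile
from typing import Any, List


def dump_part(tmpdir: str, rank: int, part: List[Any]) -> None:
    # Each rank writes its partial results to its own file.
    with open(os.path.join(tmpdir, f"part_{rank}.pkl"), "wb") as f:
        pickle.dump(part, f)


def collect_parts(tmpdir: str, world_size: int, size: int) -> List[Any]:
    # Run on rank 0 only: load every part and interleave them so that
    # sample i (which went to rank i % world_size) lands back at position i.
    parts = []
    for rank in range(world_size):
        with open(os.path.join(tmpdir, f"part_{rank}.pkl"), "rb") as f:
            parts.append(pickle.load(f))
    ordered = [res for group in zip(*parts) for res in group]
    # Drop the padding samples the sampler added to even out the ranks.
    return ordered[:size]


with tempfile.TemporaryDirectory() as tmpdir:
    num_samples, world_size = 5, 2
    padded = list(range(num_samples)) + [0]  # 6 indices, padded for 2 ranks
    for rank in range(world_size):
        dump_part(tmpdir, rank, [f"r{i}" for i in padded[rank::world_size]])
    results = collect_parts(tmpdir, world_size, num_samples)
print(results)  # ['r0', 'r1', 'r2', 'r3', 'r4']
```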

mmfewshot.detection.apis.process_support_images(model: torch.nn.modules.module.Module, support_imgs: List[str], support_labels: List[List[str]], support_bboxes: Optional[List[List[float]]] = None, classes: Optional[List[str]] = None) → None [source]

Process support images for a query-support detector.

Parameters
  • model (nn.Module) – Detector.

  • support_imgs (list[str]) – Support image filenames.

  • support_labels (list[list[str]]) – Support labels of each bbox.

  • support_bboxes (list[list[list[float]]] | None) – Bboxes in the support images. If set to None, [0, 0, image width, image height] is used as the bbox. Default: None.

  • classes (list[str] | None) – Options to override the class names of the model. Default: None.
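The support_bboxes fallback can be sketched as a small helper. It assumes that, when no boxes are given, every label of an image gets a whole-image box [0, 0, width, height], and that image sizes are known as (width, height) pairs. default_support_bboxes is a hypothetical name used only for illustration.

```python
from typing import List, Optional, Sequence, Tuple


def default_support_bboxes(
    image_sizes: Sequence[Tuple[int, int]],
    support_labels: Sequence[Sequence[str]],
    support_bboxes: Optional[List[List[List[float]]]] = None,
) -> List[List[List[float]]]:
    """Hypothetical helper: fall back to whole-image boxes when bboxes are None."""
    if support_bboxes is not None:
        return support_bboxes
    # One [0, 0, width, height] box per label of each image (assumption).
    return [
        [[0.0, 0.0, float(w), float(h)] for _ in labels]
        for (w, h), labels in zip(image_sizes, support_labels)
    ]


boxes = default_support_bboxes([(640, 480), (320, 240)], [["dog"], ["cat", "cat"]])
```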

mmfewshot.detection.apis.single_gpu_model_init(model: torch.nn.modules.module.Module, data_loader: torch.utils.data.dataloader.DataLoader) → List [source]

Forward support images for meta-learning based detector initialization.

This function is usually called before single_gpu_test in QuerySupportEvalHook. It first forwards support images with mode=model_init, and the features are saved in the model. Then it calls model_init() to process the extracted support features and finish the model initialization.

Parameters
  • model (nn.Module) – Model used for extracting support template features.

  • data_loader (DataLoader) – PyTorch data loader.

Returns

Extracted support template features.

Return type

list[Tensor]

mmfewshot.detection.apis.single_gpu_test(model: torch.nn.modules.module.Module, data_loader: torch.utils.data.dataloader.DataLoader, show: bool = False, out_dir: Optional[str] = None, show_score_thr: float = 0.3) → List [source]

Test the model with a single GPU for meta-learning based detectors.

The model forward function requires mode, whereas in mmdet it requires return_loss; encode_mask_results is also removed.

Parameters
  • model (nn.Module) – Model to be tested.

  • data_loader (DataLoader) – Pytorch data loader.

  • show (bool) – Whether to show the image. Default: False.

  • out_dir (str | None) – The directory to write the image. Default: None.

  • show_score_thr (float) – Minimum score of bboxes to be shown. Default: 0.3.

Returns

The prediction results.

Return type

list

detection.core

evaluation

class mmfewshot.detection.core.evaluation.QuerySupportDistEvalHook(model_init_dataloader: torch.utils.data.dataloader.DataLoader, val_dataloader: torch.utils.data.dataloader.DataLoader, **eval_kwargs) [source]

Distributed evaluation hook for the query-support data pipeline.

This hook first traverses model_init_dataloader to extract support features for model initialization, and then evaluates the data from val_dataloader.

Note that model_init_dataloader should NOT use a distributed sampler, so that the models on all GPUs receive the same data and end up with the same initialization.

Parameters
  • model_init_dataloader (DataLoader) – A PyTorch dataloader of model_init dataset.

  • val_dataloader (DataLoader) – A PyTorch dataloader of dataset to be evaluated.

  • **eval_kwargs – Evaluation arguments fed into the evaluate function of the dataset.

class mmfewshot.detection.core.evaluation.QuerySupportEvalHook(model_init_dataloader: torch.utils.data.dataloader.DataLoader, val_dataloader: torch.utils.data.dataloader.DataLoader, **eval_kwargs) [source]

Evaluation hook for the query-support data pipeline.

This hook first traverses model_init_dataloader to extract support features for model initialization, and then evaluates the data from val_dataloader.

Parameters
  • model_init_dataloader (DataLoader) – A PyTorch dataloader of model_init dataset.

  • val_dataloader (DataLoader) – A PyTorch dataloader of dataset to be evaluated.

  • **eval_kwargs – Evaluation arguments fed into the evaluate function of the dataset.

mmfewshot.detection.core.evaluation.eval_map(det_results: List[List[numpy.ndarray]], annotations: List[Dict], classes: List[str], scale_ranges: Optional[List[Tuple]] = None, iou_thr: float = 0.5, dataset: Optional[Union[str, List[str]]] = None, logger: Optional[object] = None, tpfp_fn: Optional[callable] = None, nproc: int = 4, use_legacy_coordinate: bool = False) → Tuple[List, List[Dict]] [source]

Evaluate mAP of a dataset.

eval_map() in mmdet predefines the class names and therefore does not support reporting mAP results for arbitrary class splits.

Parameters
  • det_results (list[list[np.ndarray]] | list[tuple[np.ndarray]]) – The outer list indicates images, and the inner list indicates per-class detected bboxes.

  • annotations (list[dict]) –

    Ground truth annotations where each item of the list indicates an image. Keys of annotations are:

    • bboxes: numpy array of shape (n, 4)

    • labels: numpy array of shape (n, )

    • bboxes_ignore (optional): numpy array of shape (k, 4)

    • labels_ignore (optional): numpy array of shape (k, )

  • classes (list[str]) – Class names.

  • scale_ranges (list[tuple] | None) – Range of scales to be evaluated, in the format [(min1, max1), (min2, max2), …]. A range of (32, 64) means the area range between (32**2, 64**2). Default: None.

  • iou_thr (float) – IoU threshold to be considered as matched. Default: 0.5.

  • dataset (list[str] | str | None) – Dataset name or dataset classes. There are minor differences in metrics for different datasets, e.g. “voc07”, “imagenet_det”, etc. Default: None.

  • logger (logging.Logger | None) – The way to print the mAP summary. See mmcv.utils.print_log() for details. Default: None.

  • tpfp_fn (callable | None) – The function used to determine true positives and false positives. If None, tpfp_default() is used, unless dataset is ‘det’ or ‘vid’ (in which case tpfp_imagenet() is used). If a function is given, it is used to evaluate TP and FP. Default: None.

  • nproc (int) – Processes used for computing TP and FP. Default: 4.

  • use_legacy_coordinate (bool) – Whether to use the coordinate system of mmdet v1.x, in which width and height are calculated as ‘x2 - x1 + 1’ and ‘y2 - y1 + 1’ respectively. Default: False.

Returns

(list, [dict, dict, …])

Return type

tuple
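Here is how scale_ranges and use_legacy_coordinate affect which boxes count toward a scale bucket. The sketch assumes mmdet's convention: a (min, max) range selects boxes with min**2 <= area < max**2, and None means unbounded. box_area and scale_bucket are hypothetical helpers.

```python
from typing import List, Optional, Sequence, Tuple


def box_area(box: Sequence[float], use_legacy_coordinate: bool = False) -> float:
    # The mmdet v1.x convention counts both end pixels, hence the extra +1.
    extra = 1.0 if use_legacy_coordinate else 0.0
    x1, y1, x2, y2 = box
    return (x2 - x1 + extra) * (y2 - y1 + extra)


def scale_bucket(
    box: Sequence[float],
    scale_ranges: List[Tuple[Optional[float], Optional[float]]],
    use_legacy_coordinate: bool = False,
) -> List[bool]:
    """For each (min, max) side-length range, test min**2 <= area < max**2."""
    area = box_area(box, use_legacy_coordinate)
    flags = []
    for lo, hi in scale_ranges:
        lo_ok = lo is None or area >= lo ** 2
        hi_ok = hi is None or area < hi ** 2
        flags.append(lo_ok and hi_ok)
    return flags


# A 31 x 31 box is "small" normally but 32 x 32 under the legacy convention.
print(scale_bucket([0, 0, 31, 31], [(0, 32), (32, 64)]))  # [True, False]
print(scale_bucket([0, 0, 31, 31], [(0, 32), (32, 64)], use_legacy_coordinate=True))  # [False, True]
```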

utils

class mmfewshot.detection.core.utils.ContrastiveLossDecayHook(decay_steps: Sequence[int], decay_rate: float = 0.5) [source]

Hook for contrastive loss weight decay used in FSCE.

Parameters
  • decay_steps (list[int] | tuple[int]) – Each item is a step at which the loss weight is decayed.

  • decay_rate (float) – Decay rate. Default: 0.5.
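The resulting schedule is a step decay: the weight is multiplied by decay_rate once for every decay step already reached. The sketch below is hypothetical. Whether the hook compares with >= or >, and whether it counts iterations from 0 or 1, are assumptions, and the base weight of 0.5 is only an example value.

```python
from typing import Sequence


def decayed_loss_weight(base_weight: float, cur_iter: int,
                        decay_steps: Sequence[int], decay_rate: float = 0.5) -> float:
    """Hypothetical step schedule: one multiplication by decay_rate per step reached."""
    num_decays = sum(1 for step in decay_steps if cur_iter >= step)
    return base_weight * decay_rate ** num_decays


for it in (0, 6000, 12000):
    print(it, decayed_loss_weight(0.5, it, decay_steps=[6000, 10000]))
# 0 0.5 / 6000 0.25 / 12000 0.125
```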

detection.datasets

class mmfewshot.detection.datasets.BaseFewShotDataset(ann_cfg: List[Dict], classes: Optional[Union[str, Sequence[str]]], pipeline: Optional[List[Dict]] = None, multi_pipelines: Optional[Dict[str, List[Dict]]] = None, data_root: Optional[str] = None, img_prefix: str = '', seg_prefix: Optional[str] = None, proposal_file: Optional[str] = None, test_mode: bool = False, filter_empty_gt: bool = True, min_bbox_size: Optional[Union[float, int]] = None, ann_shot_filter: Optional[Dict] = None, instance_wise: bool = False, dataset_name: Optional[str] = None) [source]

Base dataset for few shot detection.

The main differences from a normal detection dataset are:

  • It allows specifying a single pipeline (used in normal datasets) or multiple pipelines (used in query-support datasets) for data processing.

  • It supports controlling the maximum number of instances of each class when loading the annotation file.

The annotation format is shown as follows. The ann field is optional for testing.

[
    {
        'id': '0000001',
        'filename': 'a.jpg',
        'width': 1280,
        'height': 720,
        'ann': {
            'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
            'labels': <np.ndarray> (n, ),
            'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
            'labels_ignore': <np.ndarray> (k, ) (optional field)
        }
    },
    ...
]
Parameters
  • ann_cfg (list[dict]) –

    Annotation config supports two types of config.

    • loading annotation from the common ann_file of a dataset, with or without specific classes. Example: dict(type='ann_file', ann_file='path/to/ann_file', ann_classes=['dog', 'cat'])

    • loading annotation from a json file saved by the dataset. Example: dict(type='saved_dataset', ann_file='path/to/ann_file')

  • classes (str | Sequence[str] | None) – Classes used for model training; each class is given a fixed label.

  • pipeline (list[dict] | None) – Config to specify processing pipeline. Used in normal dataset. Default: None.

  • multi_pipelines (dict[list[dict]] | None) –

    Config to specify data pipelines for the corresponding data flows. For example, query and support data can be processed with two different pipelines; in that case the dict should contain two keys:

    • query (list[dict]): Config for query-data process pipeline.

    • support (list[dict]): Config for support-data process pipeline.

  • data_root (str | None) – Data root for ann_cfg, img_prefix, seg_prefix and proposal_file, if specified. Default: None.

  • test_mode (bool) – If set True, annotation will not be loaded. Default: False.

  • filter_empty_gt (bool) – If set to True, images without bounding boxes of the dataset’s classes are filtered out. This option only works when test_mode=False, i.e., images are never filtered during tests. Default: True.

  • min_bbox_size (int | float | None) – The minimum size of bounding boxes in the images. If the size of a bounding box is less than min_bbox_size, it is added to the ignored field. Default: None.

  • ann_shot_filter (dict | None) – Used to specify classes and the corresponding maximum number of instances when loading the annotation file. For example: {'dog': 10, 'person': 5}. If set to None, all annotations from the annotation file are loaded. Default: None.

  • instance_wise (bool) – If set to True, self.data_infos becomes instance-wise: if the annotation of a single image has more than one instance, it is split into num_instances items. Often used in support datasets. Default: False.

  • dataset_name (str | None) – Name of dataset to display. For example: ‘train_dataset’ or ‘query_dataset’. Default: None.
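The two instance-level options, ann_shot_filter and instance_wise, can be pictured on plain dicts in the annotation format shown above, using lists instead of arrays. This is a hypothetical sketch, not the library's filtering code. It scans instances in file order, whereas the real code may sample. It also assumes that classes missing from the filter are unlimited and that images left with no instances are dropped; both choices are assumptions.

```python
from collections import Counter
from typing import Dict, List, Optional


def apply_shot_filter(data_infos: List[Dict], classes: List[str],
                      ann_shot_filter: Optional[Dict[str, int]] = None) -> List[Dict]:
    """Keep at most ann_shot_filter[name] instances per class, scanning in file order."""
    if ann_shot_filter is None:
        return data_infos  # no filter: every annotation is kept
    kept = Counter()
    filtered = []
    for info in data_infos:
        keep_idx = []
        for i, label in enumerate(info["ann"]["labels"]):
            name = classes[label]
            if name not in ann_shot_filter:
                keep_idx.append(i)  # assumption: unlisted classes are not limited
            elif kept[name] < ann_shot_filter[name]:
                kept[name] += 1
                keep_idx.append(i)
        if keep_idx:  # assumption: images left without instances are dropped
            ann = {key: [info["ann"][key][i] for i in keep_idx] for key in ("bboxes", "labels")}
            filtered.append(dict(info, ann=ann))
    return filtered


def to_instance_wise(data_infos: List[Dict]) -> List[Dict]:
    # Split every image with k instances into k single-instance items.
    return [
        dict(info, ann={"bboxes": [box], "labels": [label]})
        for info in data_infos
        for box, label in zip(info["ann"]["bboxes"], info["ann"]["labels"])
    ]


classes = ["dog", "person"]
infos = [
    {"filename": "a.jpg", "ann": {"bboxes": [[0, 0, 10, 10], [5, 5, 20, 20]], "labels": [0, 1]}},
    {"filename": "b.jpg", "ann": {"bboxes": [[1, 1, 8, 8]], "labels": [0]}},
]
filtered = apply_shot_filter(infos, classes, {"dog": 1})
support_items = to_instance_wise(filtered)
```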

ann_cfg_parser(ann_cfg: List[Dict]) → List[Dict] [source]

Parse annotation config to annotation information.

Parameters

ann_cfg (list[dict]) –

Annotation config supports two types of config.

  • 'ann_file': loading annotation from the common ann_file of a dataset. Example: dict(type='ann_file', ann_file='path/to/ann_file', ann_classes=['dog', 'cat'])

  • 'saved_dataset': loading annotation from a saved dataset. Example: dict(type='saved_dataset', ann_file='path/to/ann_file')

Returns

Annotation information.

Return type

list[dict]

get_ann_info(idx: int) → Dict [source]

Get annotation by index.

When overriding this function, make sure the same annotations are used throughout training.

Parameters

idx (int) – Index of data.

Returns

Annotation info of the specified index.

Return type

dict

load_annotations_saved(ann_file: str) → List[Dict] [source]

Load data_infos from saved json.

prepare_train_img(idx: int, pipeline_key: Optional[str] = None, gt_idx: Optional[List[int]] = None) → Dict [source]

Get training data and annotations after the pipeline.

Parameters
  • idx (int) – Index of data.

  • pipeline_key (str | None) – Name of the pipeline. Default: None.

  • gt_idx (list[int] | None) – Indices of the annotations to use. Default: None.

Returns

Training data and annotations after the pipeline, with new keys introduced by the pipeline.

Return type

dict

save_data_infos(output_path: str) → None [source]

Save data_infos into json.

class mmfewshot.detection.datasets.CropResizeInstance(num_context_pixels: int = 16, target_size: Tuple[int] = (320, 320)) [source]

Crop and resize an instance from the image according to its bbox.

Parameters
  • num_context_pixels (int) – Number of context pixels padded around the instance. Default: 16.

  • target_size (tuple[int, int]) – Resize cropped instance to target size. Default: (320, 320).
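A minimal sketch of the geometry: expand the bbox by num_context_pixels on each side, clip it to the image, then compute the scale that maps the crop onto target_size. Whether the real transform preserves aspect ratio or pads is not stated above, so this sketch scales each axis independently and treats target_size as (width, height). That is an assumption, and it only matters for non-square sizes. context_crop_box and resize_scale are hypothetical helpers.

```python
from typing import Tuple


def context_crop_box(bbox: Tuple[float, float, float, float], img_w: int, img_h: int,
                     num_context_pixels: int = 16) -> Tuple[int, int, int, int]:
    """Expand an (x1, y1, x2, y2) box by a context margin and clip it to the image."""
    x1, y1, x2, y2 = bbox
    return (
        max(0, int(x1) - num_context_pixels),
        max(0, int(y1) - num_context_pixels),
        min(img_w, int(x2) + num_context_pixels),
        min(img_h, int(y2) + num_context_pixels),
    )


def resize_scale(crop: Tuple[int, int, int, int],
                 target_size: Tuple[int, int] = (320, 320)) -> Tuple[float, float]:
    # Per-axis scale that maps the cropped patch onto target_size (width, height assumed).
    x1, y1, x2, y2 = crop
    return target_size[0] / (x2 - x1), target_size[1] / (y2 - y1)


crop = context_crop_box((10, 20, 110, 60), img_w=200, img_h=100)
print(crop)  # (0, 4, 126, 76)
```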

class mmfewshot.detection.datasets.FewShotCocoDataset(classes: Optional[Union[Sequence[str], str]] = None, num_novel_shots: Optional[int] = None, num_base_shots: Optional[int] = None, ann_shot_filter: Optional[Dict[str, int]] = None, min_bbox_area: Optional[Union[float, int]] = None, dataset_name: Optional[str] = None, test_mode: bool = False, **kwargs) [source]

COCO dataset for few shot detection.

Parameters
  • classes (str | Sequence[str] | None) – Classes used for model training; each class is given a fixed label. When classes is a string, the pre-defined classes in FewShotCocoDataset are loaded. For example: ‘BASE_CLASSES’, ‘NOVEL_CLASSES’ or ‘ALL_CLASSES’.

  • num_novel_shots (int | None) – Max number of instances used for each novel class. If None, all annotations are used. Default: None.

  • num_base_shots (int | None) – Max number of instances used for each base class. If None, all annotations are used. Default: None.

  • ann_shot_filter (dict | None) – Used to specify classes and the corresponding maximum number of instances when loading the annotation file. For example: {'dog': 10, 'person': 5}. If set to None, ann_shot_filter is created according to num_novel_shots and num_base_shots. Default: None.

  • min_bbox_area (int | float | None) – Filter images with bboxes whose area is smaller than min_bbox_area. If set to None, this filter is skipped. Default: None.

  • dataset_name (str | None) – Name of dataset to display. For example: ‘train dataset’ or ‘query dataset’. Default: None.

  • test_mode (bool) – If set True, annotation will not be loaded. Default: False.
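When ann_shot_filter is None, it is created from num_novel_shots and num_base_shots. A plausible sketch of that construction maps every novel class to num_novel_shots and every base class to num_base_shots. A split whose shot count is None is simply left out, meaning unlimited, and None is returned when neither count is set. build_ann_shot_filter and this handling of None are hypothetical, and the class lists are placeholders.

```python
from typing import Dict, List, Optional


def build_ann_shot_filter(novel_classes: List[str], base_classes: List[str],
                          num_novel_shots: Optional[int] = None,
                          num_base_shots: Optional[int] = None) -> Optional[Dict[str, int]]:
    """Hypothetical helper: turn per-split shot counts into a per-class filter."""
    shot_filter: Dict[str, int] = {}
    if num_novel_shots is not None:
        shot_filter.update({name: num_novel_shots for name in novel_classes})
    if num_base_shots is not None:
        shot_filter.update({name: num_base_shots for name in base_classes})
    # None stands for "no limit": all annotations are loaded.
    return shot_filter or None


print(build_ann_shot_filter(["bird"], ["dog", "cat"], num_novel_shots=10, num_base_shots=10))
# {'bird': 10, 'dog': 10, 'cat': 10}
```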

evaluate(results: List[Sequence], metric: Union[str, List[str]] = 'bbox', logger: Optional[object] = None, jsonfile_prefix: Optional[str] = None, classwise: bool = False, proposal_nums: Sequence[int] = (100, 300, 1000), iou_thrs: Optional[Union[float, Sequence[float]]] = None, metric_items: Optional[Union[str, List[str]]] = None, class_splits: Optional[List[str]] = None) → Dict [source]

Evaluate with the COCO protocol and summarize the results for different class splits.

Parameters
  • results (list[list | tuple]) – Testing results of the dataset.

  • metric (str | list[str]) – Metrics to be evaluated. Options are ‘bbox’, ‘proposal’, ‘proposal_fast’. Default: ‘bbox’.

  • logger (logging.Logger | None) – Logger used for printing related information during evaluation. Default: None.

  • jsonfile_prefix (str | None) – The prefix of json files. It includes the file path and the prefix of filename, e.g., “a/b/prefix”. If not specified, a temp file will be created. Default: None.

  • classwise (bool) – Whether to evaluate the AP for each class. Default: False.

  • proposal_nums (Sequence[int]) – Proposal number used for evaluating recalls, such as recall@100, recall@1000. Default: (100, 300, 1000).

  • iou_thrs (Sequence[float] | float | None) – IoU threshold used for evaluating recalls/mAPs. If set to a list, the average of all IoUs will also be computed. If not specified, [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used. Default: None.

  • metric_items (list[str] | str | None) – Metric items that will be returned. If not specified, ['AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ] will be used when metric=='proposal', ['mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'] will be used when metric=='bbox'.

  • class_splits (list[str] | None) – Calculate metrics for the class splits defined in COCO_SPLIT. For example: [‘BASE_CLASSES’, ‘NOVEL_CLASSES’]. Default: None.

Returns

COCO style evaluation metric.

Return type

dict[str, float]

get_cat_ids(idx: int) → List[int] [source]

Get category ids by index.

Overrides the function in CocoDataset.

Parameters

idx (int) – Index of data.

Returns

All categories in the image of the specified index.

Return type

list[int]

get_classes(classes: Union[str, Sequence[str]]) → List[str] [source]

Get class names.

It supports loading pre-defined class splits. The pre-defined class splits are: [‘ALL_CLASSES’, ‘NOVEL_CLASSES’, ‘BASE_CLASSES’]

Parameters

classes (str | Sequence[str]) – Classes used for model training; each class is given a fixed label. When classes is a string, the pre-defined classes in FewShotCocoDataset are loaded. For example: ‘NOVEL_CLASSES’.

Returns

List of class names.

Return type

list[str]

load_annotations(ann_cfg: List[Dict]) → List[Dict] [source]

Load annotations from two types of ann_cfg:

  • type of ‘ann_file’: COCO-style annotation file.

  • type of ‘saved_dataset’: Saved COCO dataset json.

Parameters

ann_cfg (list[dict]) – Config of annotations.

Returns

Annotation infos.

Return type

list[dict]

load_annotations_coco(ann_file: str) → List[Dict] [source]

Load annotations from a COCO-style annotation file.

Parameters

ann_file (str) – Path of annotation file.

Returns

Annotation info from COCO api.

Return type

list[dict]

class mmfewshot.detection.datasets.FewShotVOCDataset(classes: Optional[Union[Sequence[str], str]] = None, num_novel_shots: Optional[int] = None, num_base_shots: Optional[int] = None, ann_shot_filter: Optional[Dict] = None, use_difficult: bool = False, min_bbox_area: Optional[Union[float, int]] = None, dataset_name: Optional[str] = None, test_mode: bool = False, coordinate_offset: List[int] = [-1, -1, 0, 0], **kwargs) [source]

VOC dataset for few shot detection.

Parameters
  • classes (str | Sequence[str]) – Classes used for model training; each class is given a fixed label. When classes is a string, the pre-defined classes in FewShotVOCDataset are loaded. For example: ‘NOVEL_CLASSES_SPLIT1’.

  • num_novel_shots (int | None) – Max number of instances used for each novel class. If None, all annotations are used. Default: None.

  • num_base_shots (int | None) – Max number of instances used for each base class. If None, all annotations are used. Default: None.

  • ann_shot_filter (dict | None) – Used to specify classes and the corresponding maximum number of instances when loading the annotation file. For example: {'dog': 10, 'person': 5}. If set to None, ann_shot_filter is created according to num_novel_shots and num_base_shots. Default: None.

  • use_difficult (bool) – Whether to use the difficult annotations. Default: False.

  • min_bbox_area (int | float | None) – Filter images with bboxes whose area is smaller than min_bbox_area. If set to None, this filter is skipped. Default: None.

  • dataset_name (str | None) – Name of dataset to display. For example: ‘train dataset’ or ‘query dataset’. Default: None.

  • test_mode (bool) – If set True, annotation will not be loaded. Default: False.

  • coordinate_offset (list[int]) – Offsets, corresponding to [x_min, y_min, x_max, y_max], that are added to the bbox annotations during training. During testing, the gt annotations are left unchanged, and the offsets are instead subtracted from the predicted results to invert the data loading logic used in training. Default: [-1, -1, 0, 0].
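The coordinate_offset round trip is shown below on a single box. The offset is added to ground-truth boxes when training data is loaded, and subtracted from predicted boxes at test time so they line up with the unmodified gt. With the default [-1, -1, 0, 0], only the top-left corner moves. apply_offset and undo_offset are hypothetical helpers, and the box values are illustrative.

```python
from typing import List, Sequence


def apply_offset(bbox: Sequence[int], offset: Sequence[int] = (-1, -1, 0, 0)) -> List[int]:
    # Training: shift a ground-truth [x_min, y_min, x_max, y_max] box by the offsets.
    return [c + o for c, o in zip(bbox, offset)]


def undo_offset(pred: Sequence[int], offset: Sequence[int] = (-1, -1, 0, 0)) -> List[int]:
    # Testing: subtract the offsets from a predicted box so it matches the raw GT.
    return [c - o for c, o in zip(pred, offset)]


train_box = apply_offset([48, 240, 195, 371])
print(train_box)  # [47, 239, 195, 371]
print(undo_offset(train_box))  # [48, 240, 195, 371]
```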

evaluate(results: List[Sequence], metric: Union[str, List[str]] = 'mAP', logger: Optional[object] = None, proposal_nums: Sequence[int] = (100, 300, 1000), iou_thr: Optional[Union[float, Sequence[float]]] = 0.5, class_splits: Optional[List[str]] = None) → Dict [source]

Evaluate with the VOC protocol and summarize the results for different class splits.

Parameters
  • results (list[list | tuple]) – Predictions of the model.

  • metric (str | list[str]) – Metrics to be evaluated. Options are ‘mAP’, ‘recall’. Default: ‘mAP’.

  • logger (logging.Logger | None) – Logger used for printing related information during evaluation. Default: None.

  • proposal_nums (Sequence[int]) – Proposal number used for evaluating recalls, such as recall@100, recall@1000. Default: (100, 300, 1000).

  • iou_thr (float | list[float]) – IoU threshold. Default: 0.5.

  • class_splits (list[str] | None) – Calculate metrics for the class splits defined in VOC_SPLIT. For example: [‘BASE_CLASSES_SPLIT1’, ‘NOVEL_CLASSES_SPLIT1’]. Default: None.

Returns

AP/recall metrics.

Return type

dict[str, float]

get_classes(classes: Union[str, Sequence[str]]) → List[str] [source]

Get class names.

It supports loading pre-defined class splits. The pre-defined class splits are: [‘ALL_CLASSES_SPLIT1’, ‘ALL_CLASSES_SPLIT2’, ‘ALL_CLASSES_SPLIT3’, ‘BASE_CLASSES_SPLIT1’, ‘BASE_CLASSES_SPLIT2’, ‘BASE_CLASSES_SPLIT3’, ‘NOVEL_CLASSES_SPLIT1’, ‘NOVEL_CLASSES_SPLIT2’, ‘NOVEL_CLASSES_SPLIT3’]

Parameters

classes (str | Sequence[str]) – Classes used for model training; each class is given a fixed label. When classes is a string, the pre-defined classes in FewShotVOCDataset are loaded. For example: ‘NOVEL_CLASSES_SPLIT1’.

Returns

List of class names.

Return type

list[str]

load_annotations(ann_cfg: List[Dict]) → List[Dict] [source]

Load annotations from two types of ann_cfg.

Parameters
  • ann_cfg (list[dict]) – Supports two types of config:

    • loading annotation from the common ann_file of a dataset, with or without specific classes. Example: dict(type='ann_file', ann_file='path/to/ann_file', ann_classes=['dog', 'cat'])

    • loading annotation from a json file saved by the dataset. Example: dict(type='saved_dataset', ann_file='path/to/ann_file')

Returns

Annotation information.

Return type

list[dict]

load_annotations_xml(ann_file: str, classes: Optional[List[str]] = None) → List[Dict] [source]

Load annotations from an XML-style ann_file.

It supports using either the image id or the image path as the image name when loading the annotation file.

Parameters
  • ann_file (str) – Path of the annotation file.

  • classes (list[str] | None) – Specific classes to load from the XML file. If set to None, the classes of the whole dataset are used. Default: None.

Returns

Annotation info from the XML file.

Return type

list[dict]

class mmfewshot.detection.datasets.GenerateMask(target_size: Tuple[int] = (224, 224)) [source]

Resize the support image and generate a mask.

Parameters

target_size (tuple[int, int]) – Crop and resize to target size. Default: (224, 224).
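The following sketch shows one plausible form of the mask. It assumes the mask is a binary map at target_size that marks the bbox region after the box is rescaled together with the image. The real transform's rounding, axis order and dtype are not documented above, so they are assumptions. bbox_mask is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def bbox_mask(bbox: Sequence[float], img_size: Tuple[int, int],
              target_size: Tuple[int, int] = (224, 224)) -> List[List[int]]:
    """Rescale an (x1, y1, x2, y2) box from img_size=(w, h) to target_size=(w, h)
    and return a target_h x target_w binary mask that is 1 inside the box."""
    sx = target_size[0] / img_size[0]
    sy = target_size[1] / img_size[1]
    x1, y1, x2, y2 = bbox
    tx1, ty1, tx2, ty2 = int(x1 * sx), int(y1 * sy), int(x2 * sx), int(y2 * sy)
    # Row index = y, column index = x.
    return [
        [1 if tx1 <= x < tx2 and ty1 <= y < ty2 else 0 for x in range(target_size[0])]
        for y in range(target_size[1])
    ]


mask = bbox_mask((0, 0, 4, 8), img_size=(8, 8), target_size=(4, 4))
for row in mask:
    print(row)  # [1, 1, 0, 0] on every row
```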

class mmfewshot.detection.datasets.NWayKShotDataloader(query_data_loader: torch.utils.data.dataloader.DataLoader, support_data_loader: torch.utils.data.dataloader.DataLoader) [source]

A dataloader wrapper.

It creates an iterator that generates query and support batches simultaneously. Each batch contains query data and support data, whose lengths are batch_size and (num_support_ways * num_support_shots) respectively.

Parameters
  • query_data_loader (DataLoader) – DataLoader of the query dataset.

  • support_data_loader (DataLoader) – DataLoader of the support dataset.
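The pairing idea is that every query batch travels together with one support batch. In the sketch below, the query loader sets the epoch length and the support loader is restarted whenever it runs out (re-iterating a shuffling DataLoader reshuffles it). Both behaviors, and the keys query_data and support_data, are assumptions about the wrapper. query_support_batches is a hypothetical name, and plain lists stand in for the two loaders.

```python
from typing import Dict, Iterable, Iterator


def query_support_batches(query_loader: Iterable, support_loader: Iterable) -> Iterator[Dict]:
    """Pair every query batch with a support batch; the query loader sets the length."""
    support_iter = iter(support_loader)
    for query_batch in query_loader:
        try:
            support_batch = next(support_iter)
        except StopIteration:
            # Restart the support loader when it is exhausted (assumption).
            support_iter = iter(support_loader)
            support_batch = next(support_iter)
        yield {"query_data": query_batch, "support_data": support_batch}


batches = list(query_support_batches(["q0", "q1", "q2"], ["s0", "s1"]))
print([(b["query_data"], b["support_data"]) for b in batches])
# [('q0', 's0'), ('q1', 's1'), ('q2', 's0')]
```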

class mmfewshot.detection.datasets.NWayKShotDataset(query_dataset: mmfewshot.detection.datasets.base.BaseFewShotDataset, support_dataset: Optional[mmfewshot.detection.datasets.base.BaseFewShotDataset], num_support_ways: int, num_support_shots: int, one_support_shot_per_image: bool = False, num_used_support_shots: int = 200, repeat_times: int = 1) [source]

A dataset wrapper of NWayKShotDataset.

Building NWayKShotDataset requires a query dataset and a support dataset, and its behavior is determined by its mode. In ‘query’ mode, the dataset returns regular images and annotations. In ‘support’ mode, the dataset first builds batch indices, each containing (num_support_ways * num_support_shots) samples. In other words, in support mode every call of __getitem__ returns a batch of samples, so the outer dataloader should set batch_size to 1. The default mode of NWayKShotDataset is ‘query’, and calling convert_query_to_support switches the mode to ‘support’.

Parameters
  • query_dataset (BaseFewShotDataset) – Query dataset to be wrapped.

  • support_dataset (BaseFewShotDataset | None) – Support dataset to be wrapped. If support_dataset is None, a copy of the query dataset is used as the support dataset.

  • num_support_ways (int) – Number of classes for support in mini-batch.

  • num_support_shots (int) – Number of support shot for each class in mini-batch.

  • one_support_shot_per_image (bool) – If True only one annotation will be sampled from each image. Default: False.

  • num_used_support_shots (int | None) – The total number of support shots sampled and used for each class during training. If set to None, all shots in dataset will be used as support shot. Default: 200.

  • shuffle_support (bool) – If allow generate new batch indices for each epoch. Default: False.

  • repeat_times (int) – The length of repeated dataset will be times larger than the original dataset. Default: 1.

convert_query_to_support(support_dataset_len: int)None[source]

Convert query dataset to support dataset.

Parameters

support_dataset_len (int) – Length of the pre-sampled batch indices.

generate_support_batch_indices(dataset_len: int)List[List[Tuple[int]]][source]

Generate batch indices from the support dataset.

The batch indices have the shape [length of dataset * [support ways * support shots]], where dataset_len is the length of the support dataset.

Parameters

dataset_len (int) – Length of batch indices.

Returns

Pre-sampled batch indices.

Return type

list[list[(data_idx, gt_idx)]]
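The nested shape above can be illustrated with a small sketch. It assumes a hypothetical class_to_instances mapping from class id to (data_idx, gt_idx) pairs and uniform random sampling; mmfewshot's actual sampling policy (e.g. how num_used_support_shots or one_support_shot_per_image are applied) may differ.

```python
import random
from typing import Dict, List, Sequence, Tuple


def generate_support_batch_indices(
        class_to_instances: Dict[int, Sequence[Tuple[int, int]]],
        dataset_len: int,
        num_support_ways: int,
        num_support_shots: int,
        seed: int = 0) -> List[List[Tuple[int, int]]]:
    """Pre-sample dataset_len batches of (data_idx, gt_idx) pairs.

    Hypothetical helper: each batch holds num_support_ways classes with
    num_support_shots annotations each, i.e. ways * shots entries.
    """
    rng = random.Random(seed)
    classes = sorted(class_to_instances)
    batches = []
    for _ in range(dataset_len):
        batch = []
        for cls in rng.sample(classes, num_support_ways):
            # Draw distinct annotations of this class without replacement.
            batch.extend(rng.sample(list(class_to_instances[cls]), num_support_shots))
        batches.append(batch)
    return batches


instances = {0: [(0, 0), (1, 0), (2, 1)], 1: [(3, 0), (4, 2)], 2: [(5, 0), (6, 0)]}
batches = generate_support_batch_indices(instances, dataset_len=4,
                                         num_support_ways=2, num_support_shots=2)
```

Each inner list has num_support_ways * num_support_shots entries, which is why a support-mode __getitem__ call yields a whole batch.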

get_support_data_infos()List[Dict][source]

Get support data infos from batch indices.

save_data_infos(output_path: str)None[source]

Save data infos of query and support data.

save_support_data_infos(support_output_path: str)None[source]

Save support data infos.

class mmfewshot.detection.datasets.NumpyEncoder(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, sort_keys=False, indent=None, separators=None, default=None)[source]

Save NumPy array objects to JSON.

default(obj: object)object[source]

Implement this method in a subclass such that it returns a serializable object for o, or calls the base implementation (to raise a TypeError).

For example, to support arbitrary iterators, you could implement default like this:

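```python
import json


# Hypothetical subclass name; the method body follows the example above.
class IterableEncoder(json.JSONEncoder):
    def default(self, o):
        try:
            iterable = iter(o)
        except TypeError:
            pass
        else:
            return list(iterable)
        # Let the base class default method raise the TypeError
        return json.JSONEncoder.default(self, o)
```
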
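Applied to NumPy data, the same override pattern might look like the sketch below. NumpyArrayEncoder is a hypothetical name, and handling np.generic scalars is an assumption; the actual NumpyEncoder may cover a different set of types.

```python
import json

import numpy as np


class NumpyArrayEncoder(json.JSONEncoder):
    """Hypothetical encoder that converts NumPy objects to built-in types."""

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()        # arrays -> nested Python lists
        if isinstance(obj, np.generic):
            return obj.item()          # NumPy scalars -> Python scalars
        return super().default(obj)   # let the base class raise TypeError


payload = {"bbox": np.array([[1.0, 2.0, 3.0, 4.0]]), "label": np.int64(3)}
text = json.dumps(payload, cls=NumpyArrayEncoder)
```

Passing the encoder through the cls argument of json.dumps is the standard way to plug in a custom default().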
class mmfewshot.detection.datasets.QueryAwareDataset(query_dataset: mmfewshot.detection.datasets.base.BaseFewShotDataset, support_dataset: Optional[mmfewshot.detection.datasets.base.BaseFewShotDataset], num_support_ways: int, num_support_shots: int, repeat_times: int = 1)[source]

A wrapper of QueryAwareDataset.

Building QueryAwareDataset requires a query dataset and a support dataset. Every call of __getitem__ first samples a query image and its annotations, then uses the query annotations to sample a batch of positive and negative support images and annotations. Positive images share classes with the query, while the annotations of negative images contain no category from the query.

Parameters
  • query_dataset (BaseFewShotDataset) – Query dataset to be wrapped.

  • support_dataset (BaseFewShotDataset | None) – Support dataset to be wrapped. If None, the support dataset is copied from the query dataset.

  • num_support_ways (int) – Number of support classes in a mini-batch; the first one is always the positive class.

  • num_support_shots (int) – Number of support shots for each class in a mini-batch; the first K shots always come from the positive class.

  • repeat_times (int) – The repeated dataset is repeat_times times larger than the original dataset. Default: 1.

generate_support(idx: int, query_class: int, support_classes: List[int])List[Tuple[int]][source]

Generate support indices of query images.

Parameters
  • idx (int) – Index of query data.

  • query_class (int) – Query class.

  • support_classes (list[int]) – Classes of support data.

Returns

A mini-batch (num_support_ways * num_support_shots) of support data (idx, gt_idx).

Return type

list[tuple(int)]
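The positive-first layout described above can be sketched as follows. This assumes a hypothetical class_to_instances index, treats every class absent from the query image as a negative class, and mimics allow_same_image=False by skipping instances from the query image; it is an illustration, not mmfewshot's sampling code.

```python
import random
from typing import Dict, List, Sequence, Set, Tuple


def sample_query_aware_support(
        query_idx: int,
        query_class: int,
        query_classes: Set[int],
        class_to_instances: Dict[int, Sequence[Tuple[int, int]]],
        num_support_ways: int,
        num_support_shots: int,
        seed: int = 0) -> List[Tuple[int, int]]:
    """Return ways * shots (data_idx, gt_idx) pairs, positive class first."""
    rng = random.Random(seed)
    # Negative classes must not appear anywhere in the query image.
    negatives = [c for c in sorted(class_to_instances) if c not in query_classes]
    support_classes = [query_class] + rng.sample(negatives, num_support_ways - 1)
    batch: List[Tuple[int, int]] = []
    for cls in support_classes:
        # Skip instances taken from the query image itself.
        candidates = [(d, g) for d, g in class_to_instances[cls] if d != query_idx]
        batch.extend(rng.sample(candidates, num_support_shots))
    return batch


instances = {0: [(0, 0), (1, 0), (2, 0)], 1: [(3, 0), (4, 0)], 2: [(5, 0), (6, 0)]}
batch = sample_query_aware_support(0, 0, {0}, instances, 2, 2)
```

The first num_support_shots entries always belong to the query (positive) class, matching the ordering stated for num_support_ways and num_support_shots.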

get_support_data_infos()List[Dict][source]

Return data_infos of support dataset.

sample_support_shots(idx: int, class_id: int, allow_same_image: bool = False)List[Tuple[int]][source]

Generate support indices according to the class id.

Parameters
  • idx (int) – Index of query data.

  • class_id (int) – Support class.

  • allow_same_image (bool) – Whether to allow instances sampled from the same image as the query image. Default: False.

Returns

Support data (num_support_shots) of the specific class.

Return type

list[tuple[int]]

save_data_infos(output_path: str)None[source]

Save data_infos into JSON.

mmfewshot.detection.datasets.build_dataloader(dataset: torch.utils.data.dataset.Dataset, samples_per_gpu: int, workers_per_gpu: int, num_gpus: int = 1, dist: bool = True, shuffle: bool = True, seed: Optional[int] = None, data_cfg: Optional[Dict] = None, use_infinite_sampler: bool = False, **kwargs)torch.utils.data.dataloader.DataLoader[source]

Build PyTorch DataLoader.

In distributed training, each GPU/process has a dataloader. In non-distributed training, there is only one dataloader for all GPUs.

Parameters
  • dataset (Dataset) – A PyTorch dataset.

  • samples_per_gpu (int) – Number of training samples on each GPU, i.e., batch size of each GPU.

  • workers_per_gpu (int) – How many subprocesses to use for data loading for each GPU.

  • num_gpus (int) – Number of GPUs. Only used in non-distributed training. Default: 1.

  • dist (bool) – Distributed training/test or not. Default: True.

  • shuffle (bool) – Whether to shuffle the data at every epoch. Default: True.

  • seed (int) – Random seed. Default: None.

  • data_cfg (dict | None) – Dict of data configuration. Default: None.

  • use_infinite_sampler (bool) – Whether to use the infinite sampler. Note that the infinite sampler keeps the dataloader iterator running forever, which avoids the overhead of worker initialization between epochs. Default: False.

  • kwargs – Any keyword arguments used to initialize DataLoader.

Returns

A PyTorch dataloader.

Return type

DataLoader
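The distributed versus non-distributed distinction mostly affects how large each loader's batch is, and the infinite sampler never ends an epoch. The sketch below assumes the MMDetection convention (one loader per process when dist=True, one loader feeding all GPUs otherwise); loader_sizes and infinite_indices are hypothetical helpers, not mmfewshot functions.

```python
import itertools
import random
from typing import Iterator, Tuple


def loader_sizes(samples_per_gpu: int, workers_per_gpu: int,
                 num_gpus: int = 1, dist: bool = True) -> Tuple[int, int]:
    """Return (batch_size, num_workers) for ONE DataLoader (assumed convention)."""
    if dist:
        # Each process builds its own loader for its own GPU.
        return samples_per_gpu, workers_per_gpu
    # A single loader serves every GPU, so its batch is num_gpus times larger.
    return num_gpus * samples_per_gpu, num_gpus * workers_per_gpu


def infinite_indices(dataset_len: int, shuffle: bool = True,
                     seed: int = 0) -> Iterator[int]:
    """Yield dataset indices forever, reshuffling after each pass."""
    rng = random.Random(seed)
    while True:
        order = list(range(dataset_len))
        if shuffle:
            rng.shuffle(order)
        yield from order


first_seven = list(itertools.islice(infinite_indices(3, shuffle=False), 7))
```

Because the index stream never stops, the loader's iterator (and its workers) can be reused across "epochs", which is the overhead the use_infinite_sampler option avoids.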

mmfewshot.detection.datasets.get_copy_dataset_type(dataset_type: str)str[source]

Return corresponding copy dataset type.

detection.models

mmfewshot.detection.models.build_backbone(cfg)[source]

Build backbone.

mmfewshot.detection.models.build_detector(cfg: mmcv.utils.config.ConfigDict, logger: Optional[object] = None)[source]

Build detector.

mmfewshot.detection.models.build_head(cfg)[source]

Build head.

mmfewshot.detection.models.build_loss(cfg)[source]

Build loss.

mmfewshot.detection.models.build_neck(cfg)[source]

Build neck.

mmfewshot.detection.models.build_roi_extractor(cfg)[source]

Build RoI extractor.

mmfewshot.detection.models.build_shared_head(cfg)[source]

Build shared head.
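These builders follow the MMCV-style pattern of turning a config dict into an object, where the 'type' key selects a registered class and the remaining keys become constructor arguments. The sketch below reproduces that pattern with a toy, hypothetical Registry and ToyBackbone; it is not mmcv's Registry, whose API differs (for example, register_module() is called with parentheses there).

```python
from typing import Dict, Type


class Registry:
    """Minimal, hypothetical stand-in for a config-driven module registry."""

    def __init__(self, name: str):
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self, cls: Type) -> Type:
        self._modules[cls.__name__] = cls
        return cls

    def build(self, cfg: dict):
        cfg = dict(cfg)              # do not mutate the caller's config
        obj_type = cfg.pop("type")   # 'type' selects the registered class
        if obj_type not in self._modules:
            raise KeyError(f"{obj_type} is not registered in {self.name}")
        return self._modules[obj_type](**cfg)  # remaining keys become kwargs


BACKBONES = Registry("backbone")


@BACKBONES.register_module
class ToyBackbone:
    def __init__(self, depth: int = 50):
        self.depth = depth


def build_backbone(cfg: dict):
    return BACKBONES.build(cfg)


cfg = dict(type="ToyBackbone", depth=101)
backbone = build_backbone(cfg)
```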

backbones

class mmfewshot.detection.models.backbones.ResNetWithMetaConv(**kwargs)[source]

ResNet with meta_conv to handle different inputs in Meta R-CNN and FSDetView.

When the input has shape (N, 3, H, W) (images only), the network uses conv1 as a regular ResNet does. When the input has shape (N, 4, H, W) (image + mask), the network replaces conv1 with meta_conv to handle the additional channel.

forward(x: torch.Tensor, use_meta_conv: bool = False)Tuple[torch.Tensor][source]

Forward function.

When the input has shape (N, 3, H, W) (images only), the network uses conv1 as a regular ResNet does. When the input has shape (N, 4, H, W) (image + mask), the network replaces conv1 with meta_conv to handle the additional channel.

Parameters
  • x (Tensor) – Tensor with shape (N, 3, H, W) from images or (N, 4, H, W) from (images + masks).

  • use_meta_conv (bool) – If True, forward the input tensor with meta_conv, which requires a tensor with shape (N, 4, H, W). Otherwise, forward it with conv1, which requires a tensor with shape (N, 3, H, W). Default: False.

Returns

Tuple of features, each item with shape (N, C, H, W).

Return type

tuple[Tensor]
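A 4-channel (image + mask) input can be assembled as in the sketch below. It assumes the mask is a binary map that is 1 inside the support box and 0 elsewhere, and uses NumPy arrays in place of tensors; image_with_box_mask is a hypothetical helper, and the real pipeline's mask construction may differ.

```python
import numpy as np


def image_with_box_mask(images: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """images: (N, 3, H, W); boxes: (N, 4) as [tl_x, tl_y, br_x, br_y] in pixels.

    Returns an (N, 4, H, W) array: RGB channels plus one binary box mask.
    """
    n, _, h, w = images.shape
    masks = np.zeros((n, 1, h, w), dtype=images.dtype)
    for i, (x1, y1, x2, y2) in enumerate(boxes.astype(int)):
        masks[i, 0, y1:y2, x1:x2] = 1  # 1 inside the support box, 0 elsewhere
    return np.concatenate([images, masks], axis=1)


imgs = np.zeros((2, 3, 8, 8))
out = image_with_box_mask(imgs, np.array([[0, 0, 4, 4], [2, 2, 6, 8]]))
```

The resulting (N, 4, H, W) array is the kind of input that use_meta_conv=True expects, while the plain (N, 3, H, W) images go through conv1.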

dense_heads

class mmfewshot.detection.models.dense_heads.AttentionRPNHead(num_support_ways: int, num_support_shots: int, aggregation_layer: Dict = {'aggregator_cfgs': [{'type': 'DepthWiseCorrelationAggregator', 'in_channels': 1024, 'with_fc': False}], 'type': 'AggregationLayer'}, roi_extractor: Dict = {'featmap_strides': [16], 'out_channels': 1024, 'roi_layer': {'output_size': 14, 'sampling_ratio': 0, 'type': 'RoIAlign'}, 'type': 'SingleRoIExtractor'}, **kwargs)[source]

RPN head for Attention RPN.

Parameters
  • num_support_ways (int) – Number of sampled classes (pos + neg).

  • num_support_shots (int) – Number of shots for each class.

  • aggregation_layer (dict) – Config of aggregation_layer.

  • roi_extractor (dict) – Config of roi_extractor.
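The default aggregation_layer uses a DepthWiseCorrelationAggregator. Depth-wise cross-correlation slides each channel of a support kernel only over the matching channel of the query feature map. The NumPy sketch below shows the operation itself for a single pair (valid padding, stride 1); it is an illustration of the technique, not mmfewshot's aggregator code.

```python
import numpy as np


def depthwise_correlation(query: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """query: (C, H, W); kernel: (C, kh, kw). Returns (C, H-kh+1, W-kw+1).

    Each query channel is cross-correlated only with the matching kernel
    channel ("valid" padding, stride 1).
    """
    c, h, w = query.shape
    _, kh, kw = kernel.shape
    out = np.empty((c, h - kh + 1, w - kw + 1), dtype=query.dtype)
    for y in range(out.shape[1]):
        for x in range(out.shape[2]):
            window = query[:, y:y + kh, x:x + kw]
            out[:, y, x] = (window * kernel).sum(axis=(1, 2))
    return out


q = np.arange(18, dtype=float).reshape(2, 3, 3)
scaled = depthwise_correlation(q, np.array([2.0, 3.0]).reshape(2, 1, 1))
summed = depthwise_correlation(q, np.ones((2, 2, 2)))
```

With a 1x1 kernel (for example, a globally pooled support feature), the correlation reduces to channel-wise scaling of the query map.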

extract_roi_feat(feats: List[torch.Tensor], rois: torch.Tensor)torch.Tensor[source]

Extract RoI features.

Parameters
  • feats (list[Tensor]) – Input features, each with shape (N, C, H, W).

  • rois (Tensor) – RoIs with shape (m, 5).
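In MMDetection-style code, an (m, 5) RoI tensor usually stores the batch index of the image in column 0 followed by [tl_x, tl_y, br_x, br_y]. Assuming that convention applies here, the NumPy sketch below stacks per-image boxes into that layout; bbox_list_to_rois is a hypothetical helper modeled on mmdet's bbox2roi.

```python
from typing import List

import numpy as np


def bbox_list_to_rois(bbox_list: List[np.ndarray]) -> np.ndarray:
    """Stack per-image (k_i, 4) boxes into an (m, 5) array.

    Column 0 holds the image index in the batch; columns 1-4 hold
    [tl_x, tl_y, br_x, br_y]; m is the total number of boxes.
    """
    rois = []
    for img_id, bboxes in enumerate(bbox_list):
        ids = np.full((bboxes.shape[0], 1), img_id, dtype=bboxes.dtype)
        rois.append(np.concatenate([ids, bboxes[:, :4]], axis=1))
    return np.concatenate(rois, axis=0)


rois = bbox_list_to_rois([
    np.array([[0.0, 0.0, 10.0, 10.0]]),
    np.array([[1.0, 1.0, 5.0, 5.0], [2.0, 2.0, 6.0, 6.0]]),
])
```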

forward_train(query_feats: List[torch.Tensor], support_feats: List[torch.Tensor], query_gt_bboxes: List[torch.Tensor], query_img_metas: List[Dict], support_gt_bboxes: List[torch.Tensor], query_gt_bboxes_ignore: Optional[List[torch.Tensor]] = None, proposal_cfg: Optional[mmcv.utils.config.ConfigDict] = None, **kwargs)Tuple[Dict, List[Tuple]][source]

Forward function in training phase.

Parameters
  • query_feats (list[Tensor]) – List of query features, each item with shape (N, C, H, W).

  • support_feats (list[Tensor]) – List of support features, each item with shape (N, C, H, W).

  • query_gt_bboxes (list[Tensor]) – List of ground truth bboxes of query image, each item with shape (num_gts, 4).

  • query_img_metas (list[dict]) – List of query image info dict where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • support_gt_bboxes (list[Tensor]) – List of ground truth bboxes of support image, each item with shape (num_gts, 4).

  • query_gt_bboxes_ignore (list[Tensor]) – List of ground truth bboxes to be ignored of query image with shape (num_ignored_gts, 4). Default: None.

  • proposal_cfg (ConfigDict) – Test / postprocessing configuration. If None, test_cfg is used. Default: None.

Returns

Loss components and proposals of each image.

  • losses (dict[str, Tensor]): A dictionary of loss components.

  • proposal_list (list[Tensor]): Proposals of each image.

Return type

tuple

loss(cls_scores: List[torch.Tensor], bbox_preds: List[torch.Tensor], gt_bboxes: List[torch.Tensor], img_metas: List[Dict], gt_labels: Optional[List[torch.Tensor]] = None, gt_bboxes_ignore: Optional[List[torch.Tensor]] = None, pair_flags: Optional[List[bool]] = None)Dict[source]

Compute losses of the RPN head.

Parameters
  • cls_scores (list[Tensor]) – Box scores for each scale level with shape (N, num_anchors * num_classes, H, W).

  • bbox_preds (list[Tensor]) – Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W).

  • gt_bboxes (list[Tensor]) – Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

  • img_metas (list[dict]) – List of image info dicts, where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • gt_labels (list[Tensor]) – Class indices corresponding to each box. Default: None.

  • gt_bboxes_ignore (None | list[Tensor]) – Specify which bounding boxes can be ignored when computing the loss. Default: None.

  • pair_flags (list[bool]) – Indicates whether each prediction comes from a positive pair or a negative pair, with shape (N). Default: None.

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

simple_test(query_feats: List[torch.Tensor], support_feat: torch.Tensor, query_img_metas: List[Dict], rescale: bool = False)List[torch.Tensor][source]

Test function without test time augmentation.

Parameters
  • query_feats (list[Tensor]) – List of query features, each item with shape (N, C, H, W).

  • support_feat (Tensor) – Support features with shape (N, C, H, W).

  • query_img_metas (list[dict]) – List of query image info dicts, where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • rescale (bool) – Whether to rescale the results. Default: False.

Returns

Proposals of each image, each item with shape (n, 5), where the 5 columns represent (tl_x, tl_y, br_x, br_y, score).

Return type

List[Tensor]

class mmfewshot.detection.models.dense_heads.TwoBranchRPNHead(mid_channels: int = 64, **kwargs)[source]

RPN head for MPSR.

Parameters

mid_channels (int) – Input channels of rpn_cls_conv. Default: 64.

forward_auxiliary(feats: List[torch.Tensor])List[torch.Tensor][source]

Forward auxiliary features at multiple scales.

Parameters

feats (list[Tensor]) – List of features at multiple scales, each a 4D tensor.

Returns

Classification scores for all scale levels, each a 4D tensor whose number of channels is num_anchors * num_classes.

Return type

list[Tensor]

forward_auxiliary_single(feat: torch.Tensor)Tuple[torch.Tensor][source]

Forward auxiliary feature map of a single scale level.

forward_single(feat: torch.Tensor)Tuple[torch.Tensor, torch.Tensor][source]

Forward feature map of a single scale level.

forward_train(x: List[torch.Tensor], auxiliary_rpn_feats: List[torch.Tensor], img_metas: List[Dict], gt_bboxes: List[torch.Tensor], gt_labels: Optional[List[torch.Tensor]] = None, gt_bboxes_ignore: Optional[List[torch.Tensor]] = None, proposal_cfg: Optional[mmcv.utils.config.ConfigDict] = None, **kwargs)Tuple[Dict, List[torch.Tensor]][source]

Forward function in training phase.

Parameters
  • x (list[Tensor]) – Features from FPN, each item with shape (N, C, H, W).

  • auxiliary_rpn_feats (list[Tensor]) – Auxiliary features from FPN, each item with shape (N, C, H, W).

  • img_metas (list[dict]) – Meta information of each image, e.g., image size, scaling factor, etc.

  • gt_bboxes (list[Tensor]) – Ground truth bboxes of the image, shape (num_gts, 4).

  • gt_labels (list[Tensor]) – Ground truth labels of each box, shape (num_gts,). Default: None.

  • gt_bboxes_ignore (list[Tensor]) – Ground truth bboxes to be ignored, shape (num_ignored_gts, 4). Default: None.

  • proposal_cfg (ConfigDict) – Test / postprocessing configuration. If None, test_cfg is used. Default: None.

Returns

  • losses (dict[str, Tensor]): A dictionary of loss components.

  • proposal_list (list[Tensor]): Proposals of each image.

Return type

tuple

get_bboxes(cls_scores: List[torch.Tensor], bbox_preds: List[torch.Tensor], img_metas: List[Dict], cfg: Optional[mmcv.utils.config.ConfigDict] = None, rescale: bool = False, with_nms: bool = True)List[torch.Tensor][source]

Transform network output for a batch into bbox predictions.

Parameters
  • cls_scores (list[Tensor]) – Box scores for each scale level with shape (N, num_anchors * num_classes, H, W).

  • bbox_preds (list[Tensor]) – Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W).

  • img_metas (list[dict]) – Meta information of each image, e.g., image size, scaling factor, etc.

  • cfg (ConfigDict | None) – Test / postprocessing configuration. If None, test_cfg is used. Default: None.

  • rescale (bool) – If True, return boxes in original image space. Default: False.

  • with_nms (bool) – If True, apply NMS before returning boxes. Default: True.

Returns

Proposals of each image, each item with shape (n, 5), where the 5 columns represent (tl_x, tl_y, br_x, br_y, score).

Return type

List[Tensor]

loss(cls_scores: List[torch.Tensor], bbox_preds: List[torch.Tensor], gt_bboxes: List[torch.Tensor], gt_labels: List[torch.Tensor], img_metas: List[Dict], gt_bboxes_ignore: Optional[List[torch.Tensor]] = None, auxiliary_cls_scores: Optional[List[torch.Tensor]] = None)Dict[source]

Compute losses of the head.

Parameters
  • cls_scores (list[Tensor]) – Box scores for each scale level, each item with shape (N, num_anchors * num_classes, H, W).

  • bbox_preds (list[Tensor]) – Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W).

  • gt_bboxes (list[Tensor]) – Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

  • gt_labels (list[Tensor]) – Class indices corresponding to each box.

  • img_metas (list[dict]) – Meta information of each image, e.g., image size, scaling factor, etc.

  • gt_bboxes_ignore (list[Tensor] | None) – Specify which bounding boxes can be ignored when computing the loss. Default: None.

  • auxiliary_cls_scores (list[Tensor] | None) – Box scores for each scale level, each item with shape (N, num_anchors * num_classes, H, W). Default: None.

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

loss_bbox_single(bbox_pred: torch.Tensor, anchors: torch.Tensor, bbox_targets: torch.Tensor, bbox_weights: torch.Tensor, num_total_samples: int)Tuple[Dict][source]

Compute the loss of a single scale level.

Parameters
  • bbox_pred (Tensor) – Box energies / deltas for each scale level with shape (N, num_anchors * 4, H, W).

  • anchors (Tensor) – Box reference for each scale level with shape (N, num_total_anchors, 4).

  • bbox_targets (Tensor) – BBox regression targets of each anchor with shape (N, num_total_anchors, 4).

  • bbox_weights (Tensor) – BBox regression loss weights of each anchor with shape (N, num_total_anchors, 4).

  • num_total_samples (int) – If sampling is used, it equals the total number of anchors; otherwise, it is the number of positive anchors.

Returns

A dictionary of loss components.

Return type

tuple[dict[str, Tensor]]

loss_cls_single(cls_score: torch.Tensor, labels: torch.Tensor, label_weights: torch.Tensor, num_total_samples: int)Tuple[Dict][source]

Compute the loss of a single scale level.

Parameters
  • cls_score (Tensor) – Box scores for each scale level with shape (N, num_anchors * num_classes, H, W).

  • labels (Tensor) – Labels of each anchor with shape (N, num_total_anchors).

  • label_weights (Tensor) – Label weights of each anchor with shape (N, num_total_anchors).

  • num_total_samples (int) – If sampling is used, it equals the total number of anchors; otherwise, it is the number of positive anchors.

Returns

A dictionary of loss components.

Return type

tuple[dict[str, Tensor]]
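num_total_samples acts as the normalizer of the per-level loss: the weighted per-anchor losses are summed and divided by it (MMDetection passes a value like this as avg_factor). The NumPy sketch below shows that normalization with a weighted binary cross-entropy; the choice of BCE and the helper name are assumptions for illustration only.

```python
import numpy as np


def weighted_bce(probs, labels, label_weights, num_total_samples):
    """Weighted binary cross-entropy summed over anchors / num_total_samples."""
    eps = 1e-12
    probs = np.clip(np.asarray(probs, dtype=float), eps, 1 - eps)
    labels = np.asarray(labels, dtype=float)
    weights = np.asarray(label_weights, dtype=float)
    per_anchor = -(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
    # Normalize by the sample count rather than by the number of anchors.
    return float((per_anchor * weights).sum() / max(num_total_samples, 1))


one_anchor = weighted_bce([0.5, 0.5], [1, 0], [1, 0], num_total_samples=1)
two_anchors = weighted_bce([0.5, 0.5], [1, 0], [1, 1], num_total_samples=2)
```

Anchors with zero label weight contribute nothing, so only the sampled (or positive) anchors drive the average.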

detectors

class mmfewshot.detection.models.detectors.AttentionRPNDetector(backbone: mmcv.utils.config.ConfigDict, neck: Optional[mmcv.utils.config.ConfigDict] = None, support_backbone: Optional[mmcv.utils.config.ConfigDict] = None, support_neck: Optional[mmcv.utils.config.ConfigDict] = None, rpn_head: Optional[mmcv.utils.config.ConfigDict] = None, roi_head: Optional[mmcv.utils.config.ConfigDict] = None, train_cfg: Optional[mmcv.utils.config.ConfigDict] = None, test_cfg: Optional[mmcv.utils.config.ConfigDict] = None, pretrained: Optional[mmcv.utils.config.ConfigDict] = None, init_cfg: Optional[mmcv.utils.config.ConfigDict] = None)[source]

Implementation of AttentionRPN.

Parameters
  • backbone (dict) – Config of the backbone for query data.

  • neck (dict | None) – Config of the neck for query data and probably for support data. Default: None.

  • support_backbone (dict | None) – Config of the backbone for support data only. If None, support and query data will share same backbone. Default: None.

  • support_neck (dict | None) – Config of the neck for support data only. Default: None.

  • rpn_head (dict | None) – Config of rpn_head. Default: None.

  • roi_head (dict | None) – Config of roi_head. Default: None.

  • train_cfg (dict | None) – Training config. Default: None.

  • test_cfg (dict | None) – Testing config. Default: None.

  • pretrained (str | None) – Path to the pretrained model. Default: None.

  • init_cfg (dict | list[dict] | None) – Initialization config dict. Default: None.

extract_support_feat(img: torch.Tensor)List[torch.Tensor][source]

Extract features of support data.

Parameters

img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

Returns

Features of support images, each item with shape (N, C, H, W).

Return type

list[Tensor]

forward_model_init(img: torch.Tensor, img_metas: List[Dict], gt_bboxes: Optional[List[torch.Tensor]] = None, gt_labels: Optional[List[torch.Tensor]] = None, **kwargs)Dict[source]

Extract and save support features for model initialization.

Parameters
  • img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

  • img_metas (list[dict]) – List of image info dicts, where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • gt_bboxes (list[Tensor]) – Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. Default: None.

  • gt_labels (list[Tensor]) – Class indices corresponding to each box. Default: None.

Returns

A dict containing the following keys:

  • gt_labels (Tensor): Class indices corresponding to each feature.

  • res4_roi_feat (Tensor): RoI features of the res4 layer.

  • res5_roi_feat (Tensor): RoI features of the res5 layer.

Return type

dict

model_init()None[source]

Process the saved support features for model initialization.

simple_test(img: torch.Tensor, img_metas: List[Dict], proposals: Optional[List[torch.Tensor]] = None, rescale: bool = False)List[List[numpy.ndarray]][source]

Test without augmentation.

Parameters
  • img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

  • img_metas (list[dict]) – List of image info dicts, where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • proposals (list[Tensor] | None) – Override RPN proposals with custom proposals. Use when with_rpn is False. Default: None.

  • rescale (bool) – If True, return boxes in original image space. Default: False.

Returns

BBox results of each image and classes.

The outer list corresponds to each image. The inner list corresponds to each class.

Return type

list[list[np.ndarray]]
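The forward_model_init / model_init split amounts to collecting support features with their gt_labels and then reducing them to one entry per class. The NumPy sketch below assumes that reduction is a per-class mean (a common prototype choice); average_support_features is a hypothetical helper, and the detector's actual post-processing may differ.

```python
from collections import defaultdict

import numpy as np


def average_support_features(labels, feats):
    """labels: (K,) class ids; feats: (K, C). Returns {class_id: (C,) mean}."""
    grouped = defaultdict(list)
    for label, feat in zip(labels, feats):
        grouped[int(label)].append(feat)  # collect features per class
    return {cls: np.mean(np.stack(items), axis=0) for cls, items in grouped.items()}


labels = np.array([0, 1, 0])
feats = np.array([[1.0, 1.0], [4.0, 4.0], [3.0, 5.0]])
prototypes = average_support_features(labels, feats)
```

Caching one reduced feature per class means test-time inference no longer needs to re-encode the support set.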

class mmfewshot.detection.models.detectors.FSCE(backbone, neck=None, rpn_head=None, roi_head=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None)[source]

Implementation of FSCE.

class mmfewshot.detection.models.detectors.FSDetView(backbone: mmcv.utils.config.ConfigDict, neck: Optional[mmcv.utils.config.ConfigDict] = None, support_backbone: Optional[mmcv.utils.config.ConfigDict] = None, support_neck: Optional[mmcv.utils.config.ConfigDict] = None, rpn_head: Optional[mmcv.utils.config.ConfigDict] = None, roi_head: Optional[mmcv.utils.config.ConfigDict] = None, train_cfg: Optional[mmcv.utils.config.ConfigDict] = None, test_cfg: Optional[mmcv.utils.config.ConfigDict] = None, pretrained: Optional[mmcv.utils.config.ConfigDict] = None, init_cfg: Optional[mmcv.utils.config.ConfigDict] = None)[source]

Implementation of FSDetView.

class mmfewshot.detection.models.detectors.MPSR(rpn_select_levels: List[int], roi_select_levels: List[int], *args, **kwargs)[source]

Implementation of MPSR.

Parameters
  • rpn_select_levels (list[int]) – The level of FPN features selected for each image scale. The selected features are fed into the RPN head.

  • roi_select_levels (list[int]) – The level of FPN features selected for each image scale. The selected features are fed into the RoI head.

extract_auxiliary_feat(auxiliary_img_list: List[torch.Tensor])Tuple[List[torch.Tensor], List[torch.Tensor]][source]

Extract and select features from a data list at multiple scales.

Parameters

auxiliary_img_list (list[Tensor]) – List of data at different scales. In most cases, each dict contains: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore.

Returns

  • rpn_feats (list[Tensor]): Features at multiple scales used for RPN head training.

  • roi_feats (list[Tensor]): Features at multiple scales used for RoI head training.

Return type

tuple

extract_feat(img: torch.Tensor)List[torch.Tensor][source]

Directly extract features from the backbone+neck.

forward(main_data: Dict = None, auxiliary_data: Dict = None, img: List[torch.Tensor] = None, img_metas: List[Dict] = None, return_loss: bool = True, **kwargs)Dict[source]

Calls either forward_train() or forward_test() depending on whether return_loss is True.

Note that this setting changes the expected inputs. When return_loss=True, the inputs are main and auxiliary data for training; when return_loss=False, the inputs are img and img_metas for testing.

Parameters
  • main_data (dict) – Used for forward_train(). Dict of data and data info, where each dict has: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore. Default: None.

  • auxiliary_data (dict) – Used for forward_train(). Dict of data and data info at multiple scales, where each key uses a different suffix to indicate its scale, for example img_scale_i, img_metas_scale_i, gt_bboxes_scale_i, gt_labels_scale_i, gt_bboxes_ignore_scale_i, where i ranges from 0 to the number of scales minus one. Default: None.

  • img (list[Tensor]) – Used for forward_test() or forward_model_init(). List of tensors of shape (1, C, H, W). Typically these should be mean centered and std scaled. Default: None.

  • img_metas (list[dict]) – Used for forward_test() or forward_model_init(). List of image info dicts, where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys, see mmdet.datasets.pipelines.Collect. Default: None.

  • return_loss (bool) – If True, call forward_train(); otherwise, call forward_test(). Default: True.
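The suffix convention for auxiliary_data can be unpacked into one dict per scale, which matches the per-scale dicts that forward_train() takes as auxiliary_data_list. The sketch below assumes keys follow the <name>_scale_<i> pattern exactly; split_by_scale is a hypothetical helper, not the method MPSR uses internally.

```python
import re
from typing import Any, Dict, List


def split_by_scale(auxiliary_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Group keys such as 'img_scale_0' into per-scale dicts ordered by i."""
    per_scale: Dict[int, Dict[str, Any]] = {}
    pattern = re.compile(r"^(.*)_scale_(\d+)$")
    for key, value in auxiliary_data.items():
        match = pattern.match(key)
        if match is None:
            continue  # ignore keys without a scale suffix
        name, scale = match.group(1), int(match.group(2))
        per_scale.setdefault(scale, {})[name] = value
    return [per_scale[i] for i in sorted(per_scale)]


data = {"img_scale_1": "b", "img_scale_0": "a",
        "gt_labels_scale_0": [1], "gt_bboxes_ignore_scale_1": None}
auxiliary_data_list = split_by_scale(data)
```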

forward_train(main_data: Dict, auxiliary_data_list: List[Dict], **kwargs)Dict[source]

Forward function in training phase.

Parameters
  • main_data (dict) – In most cases, dict of main data contains: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore.

  • auxiliary_data_list (list[dict]) – List of data at different scales. In most cases, each dict contains: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore.

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

train_step(data: Dict, optimizer: Union[object, Dict])Dict[source]

The iteration step during training.

This method defines an iteration step during training, except for the back propagation and optimizer updating, which are done in an optimizer hook. In some complicated models, such as GANs, the whole process including back propagation and optimizer updating is also defined in this method.

Parameters
  • data (dict) – The output of dataloader.

  • optimizer (torch.optim.Optimizer | dict) – The optimizer of runner is passed to train_step(). This argument is unused and reserved.

Returns

It should contain at least 3 keys: loss, log_vars, num_samples.

  • loss is a tensor for back propagation, which can be a weighted sum of multiple losses.

  • log_vars contains all the variables to be sent to the logger.

  • num_samples indicates the batch size (when the model is DDP, it means the batch size on each GPU), which is used for averaging the logs.

Return type

dict
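The three returned keys can be assembled as in the sketch below. It assumes the MMDetection convention that every entry whose key contains 'loss' is added to the total, and that a list-valued loss (e.g. one per FPN level) is summed; plain floats stand in for tensors, and parse_losses / train_step_outputs are hypothetical helpers.

```python
from typing import Dict, List, Tuple, Union


def parse_losses(losses: Dict[str, Union[float, List[float]]]) -> Tuple[float, Dict[str, float]]:
    log_vars: Dict[str, float] = {}
    for name, value in losses.items():
        # A list of per-level values is summed into one number.
        log_vars[name] = sum(value) if isinstance(value, list) else float(value)
    # Only entries whose key contains 'loss' contribute to the total.
    total = sum(v for k, v in log_vars.items() if "loss" in k)
    log_vars["loss"] = total
    return total, log_vars


def train_step_outputs(losses, num_images: int) -> dict:
    loss, log_vars = parse_losses(losses)
    return dict(loss=loss, log_vars=log_vars, num_samples=num_images)


out = train_step_outputs({"loss_cls": 0.5, "loss_bbox": [0.25, 0.25], "acc": 90.0}, 4)
```

Metrics such as accuracy are logged through log_vars without affecting the value used for back propagation.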

val_step(data: Dict, optimizer: Optional[Union[object, Dict]] = None)Dict[source]

The iteration step during validation.

This method shares the same signature as train_step(), but is used during val epochs. Note that the evaluation after training epochs is not implemented with this method, but with an evaluation hook.

class mmfewshot.detection.models.detectors.MetaRCNN(backbone: mmcv.utils.config.ConfigDict, neck: Optional[mmcv.utils.config.ConfigDict] = None, support_backbone: Optional[mmcv.utils.config.ConfigDict] = None, support_neck: Optional[mmcv.utils.config.ConfigDict] = None, rpn_head: Optional[mmcv.utils.config.ConfigDict] = None, roi_head: Optional[mmcv.utils.config.ConfigDict] = None, train_cfg: Optional[mmcv.utils.config.ConfigDict] = None, test_cfg: Optional[mmcv.utils.config.ConfigDict] = None, pretrained: Optional[mmcv.utils.config.ConfigDict] = None, init_cfg: Optional[mmcv.utils.config.ConfigDict] = None)[source]

Implementation of Meta R-CNN.

Parameters
  • backbone (dict) – Config of the backbone for query data.

  • neck (dict | None) – Config of the neck for query data and probably for support data. Default: None.

  • support_backbone (dict | None) – Config of the backbone for support data only. If None, support and query data will share same backbone. Default: None.

  • support_neck (dict | None) – Config of the neck for support data only. Default: None.

  • rpn_head (dict | None) – Config of rpn_head. Default: None.

  • roi_head (dict | None) – Config of roi_head. Default: None.

  • train_cfg (dict | None) – Training config. Default: None.

  • test_cfg (dict | None) – Testing config. Default: None.

  • pretrained (str | None) – Path to the pretrained model. Default: None.

  • init_cfg (dict | list[dict] | None) – Initialization config dict. Default: None.

extract_support_feat(img)[source]

Extract features from support data.

Parameters

img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

Returns

Features of input images, each item with shape (N, C, H, W).

Return type

list[Tensor]

forward_model_init(img: torch.Tensor, img_metas: List[Dict], gt_bboxes: Optional[List[torch.Tensor]] = None, gt_labels: Optional[List[torch.Tensor]] = None, **kwargs)[source]

Extract and save support features for model initialization.

Parameters
  • img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

  • img_metas (list[dict]) – List of image info dicts, where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • gt_bboxes (list[Tensor]) – Ground truth bboxes for each image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. Default: None.

  • gt_labels (list[Tensor]) – Class indices corresponding to each box. Default: None.

Returns

A dict containing the following keys:

  • gt_labels (Tensor): Class indices corresponding to each feature.

  • res5_rois (list[Tensor]): RoI features of the res5 layer.

Return type

dict

model_init()[source]

Process the saved support features for model initialization.

simple_test(img: torch.Tensor, img_metas: List[Dict], proposals: Optional[List[torch.Tensor]] = None, rescale: bool = False)[source]

Test without augmentation.

Parameters
  • img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

  • img_metas (list[dict]) – List of image info dicts, where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • proposals (list[Tensor] | None) – Override RPN proposals with custom proposals. Use when with_rpn is False. Default: None.

  • rescale (bool) – If True, return boxes in original image space. Default: False.

Returns

BBox results of each image and classes.

The outer list corresponds to each image. The inner list corresponds to each class.

Return type

list[list[np.ndarray]]

class mmfewshot.detection.models.detectors.QuerySupportDetector(backbone: mmcv.utils.config.ConfigDict, neck: Optional[mmcv.utils.config.ConfigDict] = None, support_backbone: Optional[mmcv.utils.config.ConfigDict] = None, support_neck: Optional[mmcv.utils.config.ConfigDict] = None, rpn_head: Optional[mmcv.utils.config.ConfigDict] = None, roi_head: Optional[mmcv.utils.config.ConfigDict] = None, train_cfg: Optional[mmcv.utils.config.ConfigDict] = None, test_cfg: Optional[mmcv.utils.config.ConfigDict] = None, pretrained: Optional[mmcv.utils.config.ConfigDict] = None, init_cfg: Optional[mmcv.utils.config.ConfigDict] = None)[source]

Base class for two-stage detectors in query-support fashion.

Query-support detectors typically consist of a region proposal network and a task-specific regression head. There are two separate pipelines for query data and support data.

Parameters
  • backbone (dict) – Config of the backbone for query data.

  • neck (dict | None) – Config of the neck for query data and possibly also for support data. Default: None.

  • support_backbone (dict | None) – Config of the backbone for support data only. If None, support and query data share the same backbone. Default: None.

  • support_neck (dict | None) – Config of the neck for support data only. Default: None.

  • rpn_head (dict | None) – Config of rpn_head. Default: None.

  • roi_head (dict | None) – Config of roi_head. Default: None.

  • train_cfg (dict | None) – Training config. Default: None.

  • test_cfg (dict | None) – Testing config. Default: None.

  • pretrained (str | None) – Path to the pretrained model. Default: None.

  • init_cfg (dict | list[dict] | None) – Initialization config dict. Default: None.

aug_test(**kwargs)[source]

Test with augmentation.

extract_feat(img: torch.Tensor) → List[torch.Tensor][source]

Extract features of query data.

Parameters

img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

Returns

Features of query images.

Return type

list[Tensor]

extract_query_feat(img: torch.Tensor) → List[torch.Tensor][source]

Extract features of query data.

Parameters

img (Tensor) – Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled.

Returns

Features of query images, each item with shape (N, C, H, W).

Return type

list[Tensor]

abstract extract_support_feat(img: torch.Tensor)[source]

Extract features of support data.

forward(query_data: Optional[Dict] = None, support_data: Optional[Dict] = None, img: Optional[List[torch.Tensor]] = None, img_metas: Optional[List[Dict]] = None, mode: typing_extensions.Literal['train', 'model_init', 'test'] = 'train', **kwargs) → Dict[source]

Calls one of forward_train(), forward_test() or forward_model_init() according to mode. The inputs of the forward function change with the mode.

  • When mode is ‘train’, the inputs are query and support data for training.

  • When mode is ‘model_init’, the input is support template data, including at least (img, img_metas).

  • When mode is ‘test’, the input is test data, including at least (img, img_metas).

Parameters
  • query_data (dict) – Used for forward_train(). Dict of query data and data info where each dict has: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore. Default: None.

  • support_data (dict) – Used for forward_train(). Dict of support data and data info dict where each dict has: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore. Default: None.

  • img (list[Tensor]) – Used for forward_test() or forward_model_init(). List of tensors of shape (1, C, H, W). Typically these should be mean centered and std scaled. Default: None.

  • img_metas (list[dict]) – Used for forward_test() or forward_model_init(). List of image info dict where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys, see mmdet.datasets.pipelines.Collect. Default: None.

  • mode (str) – Indicate which function to call. Options are ‘train’, ‘model_init’ and ‘test’. Default: ‘train’.
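The mode-based dispatch can be sketched with a small stand-in class. `ToyQuerySupportDetector` is hypothetical: it only records which branch was taken, and raising ValueError for an unknown mode is this sketch's own choice rather than documented library behavior.

```python
from typing import Any, Dict, Optional


class ToyQuerySupportDetector:
    """Hypothetical stand-in mimicking the documented mode dispatch.

    The real detector runs backbones and heads; each branch here only
    reports which function was called so the control flow is visible.
    """

    def forward(self, query_data: Optional[Dict] = None,
                support_data: Optional[Dict] = None,
                img: Optional[Any] = None,
                img_metas: Optional[Any] = None,
                mode: str = 'train', **kwargs) -> Dict:
        if mode == 'train':
            # training consumes query and support data together
            return self.forward_train(query_data, support_data, **kwargs)
        elif mode == 'model_init':
            # support templates only: extract and cache their features
            return self.forward_model_init(img, img_metas, **kwargs)
        elif mode == 'test':
            # plain test images
            return self.forward_test(img, img_metas, **kwargs)
        raise ValueError(f'invalid forward mode {mode!r}, expected '
                         "'train', 'model_init' or 'test'")

    def forward_train(self, query_data, support_data, **kwargs):
        return {'called': 'forward_train'}

    def forward_model_init(self, img, img_metas, **kwargs):
        return {'called': 'forward_model_init'}

    def forward_test(self, img, img_metas, **kwargs):
        return {'called': 'forward_test'}
```

Judging from the descriptions of forward_model_init() and model_init(), a few-shot evaluation presumably passes the support templates with mode='model_init' first, calls model_init() to process the cached features, and only then passes test images with mode='test'.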

abstract forward_model_init(img: torch.Tensor, img_metas: List[Dict], gt_bboxes: Optional[List[torch.Tensor]] = None, gt_labels: Optional[List[torch.Tensor]] = None, **kwargs)[source]

Extract and save support features for model initialization.

forward_train(query_data: Dict, support_data: Dict, proposals: Optional[List] = None, **kwargs) → Dict[source]

Forward function for training.

Parameters
  • query_data (dict) – In most cases, dict of query data contains: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore.

  • support_data (dict) – In most cases, dict of support data contains: img, img_metas, gt_bboxes, gt_labels, gt_bboxes_ignore.

  • proposals (list) – Override rpn proposals with custom proposals. Use when with_rpn is False. Default: None.

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

abstract model_init(**kwargs)[source]

Process the saved support features for model initialization.

simple_test(img: torch.Tensor, img_metas: List[Dict], proposals: Optional[List[torch.Tensor]] = None, rescale: bool = False)[source]

Test without augmentation.

train_step(data: Dict, optimizer: Union[object, Dict]) → Dict[source]

The iteration step during training.

This method defines an iteration step during training, except for the back propagation and optimizer updating, which are done in an optimizer hook. Note that in some complicated cases or models, such as GANs, the whole process including back propagation and optimizer updating is also defined in this method. For most query-support detectors, the batch size denotes the batch size of query data.

Parameters
  • data (dict) – The output of dataloader.

  • optimizer (torch.optim.Optimizer | dict) – The optimizer of runner is passed to train_step(). This argument is unused and reserved.

Returns

It should contain at least 3 keys: loss, log_vars, num_samples.

  • loss is a tensor for back propagation, which can be a weighted sum of multiple losses.

  • log_vars contains all the variables to be sent to the logger.

  • num_samples indicates the batch size (when the model is DDP, it means the batch size on each GPU), which is used for averaging the logs.

Return type

dict
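The three documented output keys can be illustrated with plain floats. The sketch assumes the common MMDetection convention that list-valued losses are summed and that only keys containing 'loss' contribute to the total; `parse_losses` and `train_step_outputs` are hypothetical helpers, and the real method operates on tensors.

```python
from typing import Dict, List, Union

LossValue = Union[float, List[float]]


def parse_losses(losses: Dict[str, LossValue]):
    """Reduce a dict of loss components into (total_loss, log_vars)."""
    log_vars: Dict[str, float] = {}
    for name, value in losses.items():
        # a list of per-stage or per-level losses is reduced to one number
        log_vars[name] = sum(value) if isinstance(value, list) else value
    # only entries whose key contains 'loss' are optimized; others, such as
    # accuracies, are only logged
    total = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = total
    return total, log_vars


def train_step_outputs(losses: Dict[str, LossValue],
                       query_img_metas: List[dict]) -> Dict:
    loss, log_vars = parse_losses(losses)
    # num_samples counts query images, following the train_step description
    # that the batch size refers to the query data
    return dict(loss=loss, log_vars=log_vars, num_samples=len(query_img_metas))


out = train_step_outputs(
    {'loss_cls': 0.5, 'loss_bbox': [0.25, 0.25], 'acc': 90.0}, [{}, {}])
print(out['loss'], out['num_samples'])  # 1.0 2
```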

val_step(data: Dict, optimizer: Optional[Union[object, Dict]] = None) → Dict[source]

The iteration step during validation.

This method shares the same signature as train_step(), but is used during validation epochs. Note that the evaluation after training epochs is not implemented by this method but by an evaluation hook.

class mmfewshot.detection.models.detectors.TFA(backbone, neck=None, rpn_head=None, roi_head=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None)[source]

Implementation of TFA (Frustratingly Simple Few-Shot Object Detection).

losses

class mmfewshot.detection.models.losses.SupervisedContrastiveLoss(temperature: float = 0.2, iou_threshold: float = 0.5, reweight_type: typing_extensions.Literal['none', 'exp', 'linear'] = 'none', reduction: typing_extensions.Literal['none', 'mean', 'sum'] = 'mean', loss_weight: float = 1.0)[source]

Supervised contrastive loss.

This part of code is modified from https://github.com/MegviiDetection/FSCE.

Parameters
  • temperature (float) – A constant by which the cosine similarity is divided, to enlarge its magnitude. Default: 0.2.

  • iou_threshold (float) – IoU threshold for selecting proposals: only proposals with higher IoU (i.e. higher credibility) contribute to the loss, which increases consistency. Default: 0.5.

  • reweight_type (str) – Reweight function for contrastive loss. Options are (‘none’, ‘exp’, ‘linear’). Default: ‘none’.

  • reduction (str) – The method used to reduce the loss into a scalar. Default: ‘mean’. Options are “none”, “mean” and “sum”.

  • loss_weight (float) – Weight of loss. Default: 1.0.

forward(features: torch.Tensor, labels: torch.Tensor, ious: torch.Tensor, decay_rate: Optional[float] = None, weight: Optional[torch.Tensor] = None, avg_factor: Optional[int] = None, reduction_override: Optional[str] = None) → torch.Tensor[source]

Forward function.

Parameters
  • features (tensor) – Shape of (N, K) where N is the number of features to be compared and K is the number of channels.

  • labels (tensor) – Shape of (N).

  • ious (tensor) – Shape of (N).

  • decay_rate (float | None) – The decay rate for total loss. Default: None.

  • weight (Tensor | None) – The weight of loss for each prediction with shape of (N). Default: None.

  • avg_factor (int | None) – Average factor that is used to average the loss. Default: None.

  • reduction_override (str | None) – The reduction method used to override the original reduction method of the loss. Options are “none”, “mean” and “sum”. Default: None.

Returns

The calculated loss.

Return type

Tensor
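To make the roles of temperature and iou_threshold concrete, here is a pure-Python sketch of a supervised contrastive loss with IoU filtering. It follows the standard SupCon formulation under these assumptions: features are L2-normalized inside the function, self-pairs are excluded, only reweight_type='none' is modelled, and kept samples are averaged. The FSCE-derived code may differ in details such as how positives are counted, so treat this as illustrative only.

```python
import math
from typing import List, Sequence


def _l2_normalize(v: Sequence[float], eps: float = 1e-12) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + eps) for x in v]


def supervised_contrastive_loss(features, labels, ious,
                                temperature=0.2, iou_threshold=0.5):
    """Simplified supervised contrastive loss with IoU filtering (sketch)."""
    feats = [_l2_normalize(f) for f in features]
    n = len(feats)
    per_sample = []
    for i in range(n):
        if ious[i] < iou_threshold:
            continue  # low-IoU proposals do not contribute
        others = [j for j in range(n) if j != i]
        positives = [j for j in others if labels[j] == labels[i]]
        if not positives:
            continue  # no positive partner for this anchor
        # cosine similarity divided by the temperature
        sims = [sum(a * b for a, b in zip(feats[i], feats[j])) / temperature
                for j in range(n)]
        # log-sum-exp over all other samples, shifted by the max for stability
        m = max(sims[j] for j in others)
        log_denom = m + math.log(sum(math.exp(sims[j] - m) for j in others))
        mean_log_prob = sum(sims[j] - log_denom for j in positives) / len(positives)
        per_sample.append(-mean_log_prob)
    return sum(per_sample) / len(per_sample) if per_sample else 0.0
```

Features that cluster by label give a small loss, while features that mix the labels give a larger one; samples below iou_threshold are simply skipped.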

roi_heads

class mmfewshot.detection.models.roi_heads.ContrastiveBBoxHead(mlp_head_channels: int = 128, with_weight_decay: bool = False, loss_contrast: Dict = {'iou_threshold': 0.5, 'loss_weight': 1.0, 'reweight_type': 'none', 'temperature': 0.1, 'type': 'SupervisedContrastiveLoss'}, scale: int = 20, learnable_scale: bool = False, eps: float = 1e-05, *args, **kwargs)[source]

BBoxHead for FSCE.

Parameters
  • mlp_head_channels (int) – Output channels of contrast branch mlp. Default: 128.

  • with_weight_decay (bool) – Whether to decay loss weight. Default: False.

  • loss_contrast (dict) – Config of contrast loss.

  • scale (int) – Scaling factor of cls_score. Default: 20.

  • learnable_scale (bool) – Whether the global scaling factor is learnable. Default: False.

  • eps (float) – Small constant to avoid division by zero. Default: 1e-05.

forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Forward function.

Parameters

x (Tensor) – Shape of (num_proposals, C, H, W).

Returns

  • cls_score (Tensor): Cls scores, with shape (num_proposals, num_classes).

  • bbox_pred (Tensor): Box energies / deltas, with shape (num_proposals, 4).

  • contrast_feat (Tensor): Box features for contrast loss, with shape (num_proposals, C).

Return type

tuple

loss_contrast(contrast_feat: torch.Tensor, proposal_ious: torch.Tensor, labels: torch.Tensor, reduction_override: Optional[str] = None) → Dict[source]

Contrastive loss.

Parameters
  • contrast_feat (tensor) – BBox features with shape (N, C) used for contrast loss.

  • proposal_ious (tensor) – IoU between each proposal and its ground truth, one value per BBox feature, with shape (N).

  • labels (tensor) – Label of each BBox feature, with shape (N).

  • reduction_override (str | None) – The reduction method used to override the original reduction method of the loss. Options are “none”, “mean” and “sum”. Default: None.

Returns

The calculated loss.

Return type

Dict

set_decay_rate(decay_rate: float) → None[source]

Set the decay rate of the contrast loss weight. The contrast loss weight decay hook calls this method to update decay_rate according to the current iteration.

Parameters

decay_rate (float) – Decay rate for weight decay.

class mmfewshot.detection.models.roi_heads.ContrastiveRoIHead(bbox_roi_extractor=None, bbox_head=None, mask_roi_extractor=None, mask_head=None, shared_head=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None)[source]

RoI head for FSCE.

class mmfewshot.detection.models.roi_heads.CosineSimBBoxHead(scale: int = 20, learnable_scale: bool = False, eps: float = 1e-05, *args, **kwargs)[source]

BBoxHead for TFA.

The code is modified from the official implementation https://github.com/ucbdrive/few-shot-object-detection/

Parameters
  • scale (int) – Scaling factor of cls_score. Default: 20.

  • learnable_scale (bool) – Whether the global scaling factor is learnable. Default: False.

  • eps (float) – Small constant to avoid division by zero. Default: 1e-05.

forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Forward function.

Parameters

x (Tensor) – Shape of (num_proposals, C, H, W).

Returns

  • cls_score (Tensor): Cls scores, with shape (num_proposals, num_classes).

  • bbox_pred (Tensor): Box energies / deltas, with shape (num_proposals, 4).

Return type

tuple
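The role of scale and eps in a cosine-similarity classifier can be sketched for a single RoI feature. This sketch assumes both the feature and each class weight vector are divided by (L2 norm + eps) before the dot product, so every score lies in [-scale, scale]. `cosine_cls_scores` is a hypothetical helper, not the head's actual code.

```python
import math
from typing import List, Sequence


def cosine_cls_scores(x: Sequence[float],
                      class_weights: Sequence[Sequence[float]],
                      scale: float = 20.0, eps: float = 1e-5) -> List[float]:
    """Scaled cosine-similarity scores of one RoI feature against each class."""
    def normalize(v):
        norm = math.sqrt(sum(t * t for t in v))
        return [t / (norm + eps) for t in v]  # eps guards against zero norm

    xn = normalize(x)
    return [scale * sum(a * b for a, b in zip(xn, normalize(w)))
            for w in class_weights]


scores = cosine_cls_scores([3.0, 4.0], [[3.0, 4.0], [-4.0, 3.0], [-3.0, -4.0]])
```

Here the three scores are roughly 20, 0 and -20: without the scale factor, cosine similarities bounded in [-1, 1] would give a very flat softmax, which is why the scores are multiplied by scale.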

class mmfewshot.detection.models.roi_heads.FSDetViewRoIHead(aggregation_layer: Optional[Dict] = None, **kwargs)[source]

RoI head for FSDetView.

Parameters

aggregation_layer (dict) – Config of aggregation_layer. Default: None.

class mmfewshot.detection.models.roi_heads.MetaRCNNResLayer(*args, **kwargs)[source]

Shared ResLayer for MetaRCNN and FSDetView.

It provides different forward logic for query and support images.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward function for query images.

Parameters

x (Tensor) – Features from backbone with shape (N, C, H, W).

Returns

Shape of (N, C).

Return type

Tensor

forward_support(x: torch.Tensor) → torch.Tensor[source]

Forward function for support images.

Parameters

x (Tensor) – Features from backbone with shape (N, C, H, W).

Returns

Shape of (N, C).

Return type

Tensor

class mmfewshot.detection.models.roi_heads.MetaRCNNRoIHead(aggregation_layer: Optional[mmcv.utils.config.ConfigDict] = None, **kwargs)[source]

RoI head for MetaRCNN.

Parameters

aggregation_layer (ConfigDict) – Config of aggregation_layer. Default: None.

extract_query_roi_feat(feats: List[torch.Tensor], rois: torch.Tensor) → torch.Tensor[source]

Extract query BBox features; used in both training and testing.

Parameters
  • feats (list[Tensor]) – List of query features, each item with shape (N, C, H, W).

  • rois (Tensor) – RoIs with shape (m, 5).

Returns

RoI features with shape (N, C).

Return type

Tensor

extract_support_feats(feats: List[torch.Tensor]) → List[torch.Tensor][source]

Forward support features through shared layers.

Parameters

feats (list[Tensor]) – List of support features, each item with shape (N, C, H, W).

Returns

List of support features, each item with shape (N, C).

Return type

list[Tensor]

forward_train(query_feats: List[torch.Tensor], support_feats: List[torch.Tensor], proposals: List[torch.Tensor], query_img_metas: List[Dict], query_gt_bboxes: List[torch.Tensor], query_gt_labels: List[torch.Tensor], support_gt_labels: List[torch.Tensor], query_gt_bboxes_ignore: Optional[List[torch.Tensor]] = None, **kwargs) → Dict[source]

Forward function for training.

Parameters
  • query_feats (list[Tensor]) – List of query features, each item with shape (N, C, H, W).

  • support_feats (list[Tensor]) – List of support features, each item with shape (N, C, H, W).

  • proposals (list[Tensor]) – List of region proposals with positive and negative pairs.

  • query_img_metas (list[dict]) – List of query image info dict where each dict has: ‘img_shape’, ‘scale_factor’, ‘flip’, and may also contain ‘filename’, ‘ori_shape’, ‘pad_shape’, and ‘img_norm_cfg’. For details on the values of these keys see mmdet/datasets/pipelines/formatting.py:Collect.

  • query_gt_bboxes (list[Tensor]) – Ground truth bboxes for each query image, each item with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

  • query_gt_labels (list[Tensor]) – Class indices corresponding to each box of query images, each item with shape (num_gts).

  • support_gt_labels (list[Tensor]) – Class indices corresponding to each box of support images, each item with shape (1).

  • query_gt_bboxes_ignore (list[Tensor] | None) – Specify which bounding boxes can be ignored when computing the loss. Default: None.

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

simple_test(query_feats: List[torch.Tensor], support_feats_dict: Dict, proposal_list: List[torch.Tensor], query_img_metas: List[Dict], rescale: bool = False) → List[List[numpy.ndarray]][source]

Test without augmentation.

Parameters
  • query_feats (list[Tensor]) – Features of query image, each item with shape (N, C, H, W).

  • support_feats_dict (dict[int, Tensor]) – Used for inference only; each key is the class id and the value is the support template features with shape (1, C).

  • proposal_list (list[Tensor]) – List of region proposals.

  • query_img_metas (list[dict]) – list of image info dict where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • rescale (bool) – Whether to rescale the results. Default: False.

Returns

BBox results of each image and class.

The outer list corresponds to each image. The inner list corresponds to each class.

Return type

list[list[np.ndarray]]

simple_test_bboxes(query_feats: List[torch.Tensor], support_feats_dict: Dict, query_img_metas: List[Dict], proposals: List[torch.Tensor], rcnn_test_cfg: mmcv.utils.config.ConfigDict, rescale: bool = False) → Tuple[List[torch.Tensor], List[torch.Tensor]][source]

Test only det bboxes without augmentation.

Parameters
  • query_feats (list[Tensor]) – Features of query image, each item with shape (N, C, H, W).

  • support_feats_dict (dict[int, Tensor]) – Used for inference only; each key is the class id and the value is the support template features with shape (1, C).

  • query_img_metas (list[dict]) – list of image info dict where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • proposals (list[Tensor]) – Region proposals.

  • rcnn_test_cfg (ConfigDict) – test_cfg of R-CNN.

  • rescale (bool) – If True, return boxes in original image space. Default: False.

Returns

Each tensor in the first list has shape (num_boxes, 4) and each tensor in the second list has shape (num_boxes, ). The length of both lists should be equal to batch_size.

Return type

tuple[list[Tensor], list[Tensor]]

class mmfewshot.detection.models.roi_heads.MultiRelationBBoxHead(patch_relation: bool = True, local_correlation: bool = True, global_relation: bool = True, *args, **kwargs)[source]

BBox head for Attention RPN.

Parameters
  • patch_relation (bool) – Whether to use the patch relation head for classification. Following the official implementation, patch_relation should always be True, because only the patch relation head contains a regression head. Default: True.

  • local_correlation (bool) – Whether to use the local correlation head for classification. Default: True.

  • global_relation (bool) – Whether to use the global relation head for classification. Default: True.

forward(query_feat: torch.Tensor, support_feat: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Forward function.

Parameters
  • query_feat (Tensor) – Shape of (num_proposals, C, H, W).

  • support_feat (Tensor) – Shape of (1, C, H, W).

Returns

  • cls_score (Tensor): Cls scores, with shape (num_proposals, num_classes).

  • bbox_pred (Tensor): Box energies / deltas, with shape (num_proposals, 4).

Return type

tuple

loss(cls_scores: torch.Tensor, bbox_preds: torch.Tensor, rois: torch.Tensor, labels: torch.Tensor, label_weights: torch.Tensor, bbox_targets: torch.Tensor, bbox_weights: torch.Tensor, num_pos_pair_samples: int, reduction_override: Optional[str] = None, sample_fractions: Sequence[Union[int, float]] = (1, 2, 1)) → Dict[source]

Compute losses of the head.

Parameters
  • cls_scores (Tensor) – Box scores with shape of (num_proposals, num_classes).

  • bbox_preds (Tensor) – Box energies / deltas with shape of (num_proposals, num_classes * 4).

  • rois (Tensor) – RoIs with shape (N, 4) or (N, 5).

  • labels (Tensor) – Labels of proposals with shape (num_proposals).

  • label_weights (Tensor) – Label weights of proposals with shape (num_proposals).

  • bbox_targets (Tensor) – BBox regression targets of each proposal, with shape (num_proposals, num_classes * 4).

  • bbox_weights (Tensor) – BBox regression loss weights of each proposal with shape (num_proposals, num_classes * 4).

  • num_pos_pair_samples (int) – Number of samples from positive pairs.

  • reduction_override (str | None) – The reduction method used to override the original reduction method of the loss. Options are “none”, “mean” and “sum”. Default: None.

  • sample_fractions (Sequence[int | float]) – Fractions of positive samples, negative samples from positive pair, negative samples from negative pair. Default: (1, 2, 1).

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

class mmfewshot.detection.models.roi_heads.MultiRelationRoIHead(num_support_ways: int = 2, num_support_shots: int = 5, sample_fractions: Sequence[Union[int, float]] = (1, 2, 1), **kwargs)[source]

RoI head for Attention RPN.

Parameters
  • num_support_ways (int) – Number of sampled classes (pos + neg). Default: 2.

  • num_support_shots (int) – Number of shots for each class. Default: 5.

  • sample_fractions (Sequence[int | float]) – Fractions of positive samples, negative samples from positive pair, negative samples from negative pair. Default: (1, 2, 1).

extract_roi_feat(feats: List[torch.Tensor], rois: torch.Tensor) → torch.Tensor[source]

Extract BBox features; used in both training and testing.

Parameters
  • feats (list[Tensor]) – Features from backbone, each item with shape (N, C, H, W).

  • rois (Tensor) – RoIs with shape (num_proposals, 5).

Returns

RoI features with shape (num_proposals, C).

Return type

Tensor

forward_train(query_feats: List[torch.Tensor], support_feats: List[torch.Tensor], proposals: List[torch.Tensor], query_img_metas: List[Dict], query_gt_bboxes: List[torch.Tensor], query_gt_labels: List[torch.Tensor], support_gt_bboxes: List[torch.Tensor], query_gt_bboxes_ignore: Optional[List[torch.Tensor]] = None, **kwargs) → Dict[source]

Forward function for training. All arguments except proposals are passed as a tuple of (query, support).

Parameters
  • query_feats (list[Tensor]) – List of query features, each item with shape (N, C, H, W).

  • support_feats (list[Tensor]) – List of support features, each item with shape (N, C, H, W).

  • proposals (list[Tensor]) – List of region proposals with positive and negative query-support pairs.

  • query_img_metas (list[dict]) – List of query image info dict where each dict has: ‘img_shape’, ‘scale_factor’, ‘flip’, and may also contain ‘filename’, ‘ori_shape’, ‘pad_shape’, and ‘img_norm_cfg’. For details on the values of these keys see mmdet/datasets/pipelines/formatting.py:Collect.

  • query_gt_bboxes (list[Tensor]) – Ground truth bboxes for each query image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

  • query_gt_labels (list[Tensor]) – Class indices corresponding to each bbox from query image.

  • support_gt_bboxes (list[Tensor]) – Ground truth bboxes for each support image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

  • query_gt_bboxes_ignore (None | list[Tensor]) – Specify which bounding boxes from query image can be ignored when computing the loss. Default: None.

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

simple_test(query_feats: List[torch.Tensor], support_feat: torch.Tensor, proposals: List[torch.Tensor], query_img_metas: List[Dict], rescale: bool = False) → List[List[numpy.ndarray]][source]

Test without augmentation.

Parameters
  • query_feats (list[Tensor]) – List of query features, each item with shape (N, C, H, W).

  • support_feat (Tensor) – Support features with shape (N, C, H, W).

  • proposals (Tensor or list[Tensor]) – List of region proposals.

  • query_img_metas (list[dict]) – list of query image info dict where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • rescale (bool) – Whether to rescale the results. Default: False.

Returns

BBox results of each image and class.

The outer list corresponds to each image. The inner list corresponds to each class.

Return type

list[list[np.ndarray]]

simple_test_bboxes(query_feats: List[torch.Tensor], support_feat: torch.Tensor, query_img_metas: List[Dict], proposals: List[torch.Tensor], rcnn_test_cfg: mmcv.utils.config.ConfigDict, rescale: bool = False) → Tuple[List[torch.Tensor], List[torch.Tensor]][source]

Test only det bboxes without augmentation.

Parameters
  • query_feats (list[Tensor]) – List of query features, each item with shape (N, C, H, W).

  • support_feat (Tensor) – Support feature with shape (N, C, H, W).

  • query_img_metas (list[dict]) – list of image info dict where each dict has: img_shape, scale_factor, flip, and may also contain filename, ori_shape, pad_shape, and img_norm_cfg. For details on the values of these keys see mmdet.datasets.pipelines.Collect.

  • proposals (list[Tensor]) – Region proposals.

  • rcnn_test_cfg (ConfigDict) – test_cfg of R-CNN.

  • rescale (bool) – If True, return boxes in original image space. Default: False.

Returns

BBoxes of shape [N, num_bboxes, 5] and class labels of shape [N, num_bboxes].

Return type

tuple[Tensor, Tensor]

class mmfewshot.detection.models.roi_heads.TwoBranchRoIHead(bbox_roi_extractor=None, bbox_head=None, mask_roi_extractor=None, mask_head=None, shared_head=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None)[source]

RoI head for MPSR.

forward_auxiliary_train(feats: Tuple[torch.Tensor], gt_labels: List[torch.Tensor]) → Dict[source]

Forward function that calculates the loss for auxiliary data in training.

Parameters
  • feats (tuple[Tensor]) – Tuple of features at multiple scales, each is a 4D-tensor.

  • gt_labels (list[Tensor]) – List of class indices corresponding to each feature, each is a 4D-tensor.

Returns

A dictionary of loss components.

Return type

dict[str, Tensor]

utils

detection.utils

class mmfewshot.detection.utils.ContrastiveLossDecayHook(decay_steps: Sequence[int], decay_rate: float = 0.5)

Hook for contrast loss weight decay used in FSCE.

Parameters
  • decay_steps (list[int] | tuple[int]) – Each item in the list is the step to decay the loss weight.

  • decay_rate (float) – Decay rate. Default: 0.5.
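A step decay of the contrast loss weight can be sketched as a pure function of the iteration. The sketch assumes the multiplier starts at 1.0 and is multiplied by decay_rate once for every entry in decay_steps that the current iteration has passed; `contrast_decay_rate` is a hypothetical helper. The hook would then hand the value to the bbox head via set_decay_rate().

```python
from typing import Sequence


def contrast_decay_rate(cur_iter: int, decay_steps: Sequence[int],
                        decay_rate: float = 0.5) -> float:
    """Multiplier for the contrast loss weight at iteration cur_iter."""
    rate = 1.0
    for step in decay_steps:
        if cur_iter > step:  # each passed step applies one more decay
            rate *= decay_rate
    return rate


# With decay_steps=(1000, 2000): 1.0 before step 1000, 0.5 after it,
# and 0.25 after step 2000.
```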

mmfewshot.utils

class mmfewshot.utils.DistributedInfiniteGroupSampler(dataset: Iterable, samples_per_gpu: int = 1, num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0, shuffle: bool = True)[source]

Similar to InfiniteGroupSampler, but for distributed training.

The length of the sampler is set to the actual length of the dataset, so the length of the dataloader is still determined by the dataset. The implementation logic follows https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

Parameters
  • dataset (Iterable) – The dataset.

  • samples_per_gpu (int) – Number of training samples on each GPU, i.e., batch size of each GPU. Default: 1.

  • num_replicas (int | None) – Number of processes participating in distributed training. Default: None.

  • rank (int | None) – Rank of current process. Default: None.

  • seed (int) – Random seed. Default: 0.

  • shuffle (bool) – Whether to shuffle the indices of a dummy epoch. Note that setting shuffle to False does not guarantee sequential indices, because all indices in a batch must belong to the same group. Default: True.

class mmfewshot.utils.DistributedInfiniteSampler(dataset: Iterable, num_replicas: Optional[int] = None, rank: Optional[int] = None, seed: int = 0, shuffle: bool = True)[source]

Similar to InfiniteSampler, but for distributed training.

The length of the sampler is set to the actual length of the dataset, so the length of the dataloader is still determined by the dataset. The implementation logic follows https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

Parameters
  • dataset (Iterable) – The dataset.

  • num_replicas (int | None) – Number of processes participating in distributed training. Default: None.

  • rank (int | None) – Rank of current process. Default: None.

  • seed (int) – Random seed. Default: 0.

  • shuffle (bool) – Whether to shuffle the dataset. Default: True.
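One common way to split an infinite index stream across processes, used for example by detectron2's training sampler, is for every rank to generate the same seeded stream and keep every num_replicas-th index starting at its own rank. The sketch below assumes that strategy; `infinite_indices` and `rank_indices` are hypothetical helpers.

```python
import itertools
import random
from typing import Iterator


def infinite_indices(size: int, seed: int = 0,
                     shuffle: bool = True) -> Iterator[int]:
    """Endless index stream; identical on every rank because the seed is shared."""
    rng = random.Random(seed)
    while True:
        order = list(range(size))
        if shuffle:
            rng.shuffle(order)
        yield from order


def rank_indices(size: int, rank: int, num_replicas: int,
                 seed: int = 0, shuffle: bool = True) -> Iterator[int]:
    """Indices seen by one rank: every num_replicas-th item, offset by rank."""
    return itertools.islice(infinite_indices(size, seed, shuffle),
                            rank, None, num_replicas)
```

Because all ranks slice the same permutation at different offsets, the first pass over the data is split into disjoint shards without any communication beyond agreeing on the seed.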

class mmfewshot.utils.InfiniteEpochBasedRunner(model, batch_processor=None, optimizer=None, work_dir=None, logger=None, meta=None, max_iters=None, max_epochs=None)[source]

Epoch-based runner that supports dataloaders with InfiniteSampler.

The dataloader workers are re-initialized every time a new dataloader iterator is created. InfiniteSampler is designed to avoid this time-consuming operation: an iterator backed by InfiniteSampler never reaches the end, so it never has to be recreated.
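The idea of keeping one iterator alive across epochs can be sketched without any real dataloader. `CountingLoader` is a hypothetical stand-in whose iterator is infinite (like one backed by InfiniteSampler) and which counts how often an iterator is created, standing in for the cost of restarting workers; `run_epochs` is likewise hypothetical.

```python
import itertools
from typing import List


class CountingLoader:
    """Hypothetical loader that counts iterator creations."""

    def __init__(self, data: List[int]):
        self.data = data
        self.iterators_created = 0

    def __len__(self):
        return len(self.data)  # one "epoch" still means one pass over the data

    def __iter__(self):
        self.iterators_created += 1  # stands in for re-initializing workers
        return itertools.cycle(self.data)  # never raises StopIteration


def run_epochs(loader: CountingLoader, max_epochs: int) -> List[List[int]]:
    """Create the iterator once, then pull len(loader) batches per epoch."""
    data_iter = iter(loader)
    epochs = []
    for _ in range(max_epochs):
        epochs.append([next(data_iter) for _ in range(len(loader))])
    return epochs
```

Running three epochs over CountingLoader([1, 2, 3]) yields [1, 2, 3] in each epoch while creating the iterator only once.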

class mmfewshot.utils.InfiniteGroupSampler(dataset: Iterable, samples_per_gpu: int = 1, seed: int = 0, shuffle: bool = True)[source]

Similar to InfiniteSampler, but all indices in a batch must belong to the same flag group.

The length of the sampler is set to the actual length of the dataset, so the length of the dataloader is still determined by the dataset. The implementation logic follows https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

Parameters
  • dataset (Iterable) – The dataset.

  • samples_per_gpu (int) – Number of training samples on each GPU, i.e., batch size of each GPU. Default: 1.

  • seed (int) – Random seed. Default: 0.

  • shuffle (bool) – Whether to shuffle the indices of a dummy epoch. Note that setting shuffle to False does not guarantee sequential indices, because all indices in a batch must belong to the same group. Default: True.
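Group-aware batching can be sketched in the style of detectron2's grouped batch sampler: indices from an endless stream are buffered per group, and a batch is emitted whenever one group's buffer reaches samples_per_gpu. Here `flags[i]` is assumed to be the group id of sample i (often an aspect-ratio flag in MMDetection-style datasets), and `infinite_group_batches` is a hypothetical helper.

```python
import random
from typing import Dict, Iterator, List, Sequence


def infinite_group_batches(flags: Sequence[int], samples_per_gpu: int,
                           seed: int = 0,
                           shuffle: bool = True) -> Iterator[List[int]]:
    """Yield batches of dataset indices forever; each batch shares one flag."""
    rng = random.Random(seed)
    buffers: Dict[int, List[int]] = {}  # group id -> indices waiting

    def index_stream():
        while True:  # one "dummy epoch" per pass
            order = list(range(len(flags)))
            if shuffle:
                rng.shuffle(order)
            yield from order

    for idx in index_stream():
        buf = buffers.setdefault(flags[idx], [])
        buf.append(idx)
        if len(buf) == samples_per_gpu:
            yield buf[:]  # emit a copy, then reuse the buffer
            buf.clear()
```

This also shows why shuffle=False does not produce sequential indices: with flags [0, 1, 0, 1] and samples_per_gpu=2, the first batches are [0, 2] and [1, 3].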

class mmfewshot.utils.InfiniteSampler(dataset: Iterable, seed: int = 0, shuffle: bool = True)[source]

Return an infinite stream of indices.

The length of the sampler is set to the actual length of the dataset, so the length of the dataloader is still determined by the dataset. The implementation logic follows https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

Parameters
  • dataset (Iterable) – The dataset.

  • seed (int) – Random seed. Default: 0.

  • shuffle (bool) – Whether to shuffle the dataset. Default: True.
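The key property of such a sampler, an endless iterator combined with a finite __len__, fits in a few lines. `InfiniteIndexSampler` is a hypothetical class illustrating the idea, not mmfewshot's implementation; it reshuffles with a seeded generator at the start of each dummy epoch.

```python
import random
from typing import Iterator, Sized


class InfiniteIndexSampler:
    """Sketch of an infinite sampler whose __len__ is still the dataset size."""

    def __init__(self, dataset: Sized, seed: int = 0, shuffle: bool = True):
        self.size = len(dataset)
        self.seed = seed
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        rng = random.Random(self.seed)
        while True:  # never raises StopIteration
            order = list(range(self.size))
            if self.shuffle:
                rng.shuffle(order)
            yield from order  # one dummy epoch per pass

    def __len__(self) -> int:
        # one pass over the dataset, so len(dataloader) stays finite
        return self.size
```

Every consecutive block of len(sampler) indices is a permutation of the dataset, so epoch-level bookkeeping still works even though the iterator never ends.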

mmfewshot.utils.compat_cfg(cfg)[source]

Modify some fields of the config to keep it compatible.

For example, it moves arguments that will be deprecated to their correct fields.

mmfewshot.utils.local_numpy_seed(seed: Optional[int] = None) → None[source]

Run numpy codes with a local random seed.

If seed is None, the default random state will be used.
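A context manager with this behavior can be written using NumPy's global-state functions np.random.get_state, np.random.seed and np.random.set_state. The sketch below is named `local_numpy_seed_sketch` to make clear it is an illustration, not mmfewshot's source; in particular, the choice to leave the state untouched when seed is None is this sketch's reading of the description.

```python
import contextlib
from typing import Iterator, Optional

import numpy as np


@contextlib.contextmanager
def local_numpy_seed_sketch(seed: Optional[int] = None) -> Iterator[None]:
    """Temporarily seed NumPy's global RNG and restore the old state on exit."""
    if seed is None:
        yield  # keep using the current (default) random state
        return
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)  # restored even if the body raises
```

Code inside the block is reproducible, while random draws outside the block continue as if the block had never run.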

mmfewshot.utils.multi_pipeline_collate_fn(batch, samples_per_gpu: int = 1)[source]

Puts each data field into a tensor/DataContainer with outer dimension batch size. This is designed for datasets whose __getitem__() returns more than one image, such as the query-support dataloader. The main difference from collate_fn() in mmcv is that it can process list[list[DataContainer]].

Extends default_collate to add support for mmcv.parallel.DataContainer. There are 3 cases:

  1. cpu_only = True, e.g., meta data.

  2. cpu_only = False, stack = True, e.g., images tensors.

  3. cpu_only = False, stack = False, e.g., gt bboxes.

Parameters
  • batch (list[list[mmcv.parallel.DataContainer]] | list[mmcv.parallel.DataContainer]) – Data of a single batch.

  • samples_per_gpu (int) – The number of samples on a single GPU.
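The extra level of nesting can be pictured with plain Python objects. The sketch assumes each dataset item is a list of samples (for example one query plus its supports) and that, as in mmcv's collate, samples are then grouped into chunks of samples_per_gpu. `flatten_multi_pipeline_batch` and `split_per_gpu` are hypothetical, simplified helpers; the real function works on DataContainers and tensors.

```python
from typing import Any, List, Sequence


def flatten_multi_pipeline_batch(batch: Sequence[Any]) -> List[Any]:
    """Turn list[list[sample]] (or a mix with plain samples) into list[sample]."""
    flat: List[Any] = []
    for item in batch:
        if isinstance(item, (list, tuple)):
            flat.extend(item)  # one dataset item produced several samples
        else:
            flat.append(item)
    return flat


def split_per_gpu(samples: List[Any], samples_per_gpu: int) -> List[List[Any]]:
    """Chunk the flattened samples into groups of samples_per_gpu."""
    return [samples[i:i + samples_per_gpu]
            for i in range(0, len(samples), samples_per_gpu)]
```

After flattening, an ordinary single-level collate can stack or list the samples as usual.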

mmfewshot.utils.sync_random_seed(seed=None, device='cuda')[source]

Propagating the seed of rank 0 to all other ranks.

Make sure different ranks share the same seed. All workers must call this function, otherwise it will deadlock. This method is generally used in DistributedSampler, because the seed should be identical across all processes in the distributed group. In distributed sampling, different ranks should sample non-overlapping data in the dataset. Therefore, this function is used to make sure that each rank shuffles the data indices in the same order based on the same seed. Then different ranks can use different indices to select non-overlapping data from the same data list.

Parameters
  • seed (int, Optional) – The seed. Default: None.

  • device (str) – The device where the seed will be put on. Default: ‘cuda’.

Returns

Seed to be used.

Return type

int
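The reasoning above can be simulated in a single process. The sketch assumes the broadcast simply makes every rank adopt rank 0's seed, after which each rank shuffles the same index list with that seed and takes an interleaved slice. `broadcast_seed` and `rank_shard` are hypothetical helpers; the real function communicates through torch.distributed.

```python
import random
from typing import List


def broadcast_seed(per_rank_seeds: List[int]) -> List[int]:
    """Simulated broadcast: every rank adopts rank 0's seed."""
    return [per_rank_seeds[0]] * len(per_rank_seeds)


def rank_shard(num_samples: int, rank: int, world_size: int,
               seed: int) -> List[int]:
    """Indices read by one rank after a shuffle with the shared seed."""
    order = list(range(num_samples))
    random.Random(seed).shuffle(order)  # same permutation on every rank
    return order[rank::world_size]  # disjoint, interleaved slices


seeds = broadcast_seed([7, 99, 3])  # -> [7, 7, 7]
shards = [rank_shard(10, r, 3, s) for r, s in enumerate(seeds)]
```

If each rank instead used its own seed, the permutations would differ and the slices could overlap or miss samples, which is exactly what synchronizing the seed prevents.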
