Shortcuts

mmfewshot.detection.datasets.dataloader_wrappers 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Iterator

from torch.utils.data import DataLoader


class NWayKShotDataloader:
    """A dataloader wrapper that yields query and support batches together.

    It creates an iterator that draws one batch from the query dataloader
    and one batch from the support dataloader at every step. Each yielded
    item contains query data and support data, whose lengths are batch_size
    and (num_support_ways * num_support_shots) respectively.

    Args:
        query_data_loader (DataLoader): DataLoader of query dataset.
        support_data_loader (DataLoader): DataLoader of support datasets.
    """

    def __init__(self, query_data_loader: DataLoader,
                 support_data_loader: DataLoader) -> None:
        # Expose the query loader's dataset and sampler on the wrapper itself.
        self.dataset = query_data_loader.dataset
        self.sampler = query_data_loader.sampler
        self.query_data_loader = query_data_loader
        self.support_data_loader = support_data_loader

    def __iter__(self) -> Iterator:
        """Create fresh iterators over both dataloaders and return self."""
        # if infinite sampler is used, this part of code only run once
        self.query_iter = iter(self.query_data_loader)
        self.support_iter = iter(self.support_data_loader)
        return self

    def __next__(self) -> Dict:
        """Return the next pair of query and support batches.

        Returns:
            dict: ``{'query_data': ..., 'support_data': ...}``.

        Raises:
            StopIteration: If either underlying iterator is exhausted.
        """
        # Use the builtin ``next()`` instead of the Python 2 style ``.next()``
        # method: iterators are only guaranteed to implement ``__next__``, so
        # calling ``.next()`` raises AttributeError on iterators without it.
        query_data = next(self.query_iter)
        support_data = next(self.support_iter)
        return {'query_data': query_data, 'support_data': support_data}

    def __len__(self) -> int:
        """Return the number of batches of the query dataloader."""
        return len(self.query_data_loader)
class TwoBranchDataloader:
    """Wrap two dataloaders and step through them simultaneously.

    Every step yields one batch from the main dataloader and one batch from
    the auxiliary dataloader. Whenever one of them is exhausted it is
    restarted, so iteration never stops on its own. For that reason
    `TwoBranchDataloader` does not support `EpochBasedRunner`. The reported
    length is the length of the main dataloader.

    Args:
        main_data_loader (DataLoader): DataLoader of main dataset.
        auxiliary_data_loader (DataLoader): DataLoader of auxiliary dataset.
    """

    def __init__(self, main_data_loader: DataLoader,
                 auxiliary_data_loader: DataLoader) -> None:
        # The wrapper exposes the main loader's dataset as its own.
        self.dataset = main_data_loader.dataset
        self.main_data_loader = main_data_loader
        self.auxiliary_data_loader = auxiliary_data_loader

    def __iter__(self) -> Iterator:
        """Start a fresh pass over both dataloaders and return self."""
        # With an infinite sampler this setup runs only a single time.
        self.main_iter = iter(self.main_data_loader)
        self.auxiliary_iter = iter(self.auxiliary_data_loader)
        return self

    def _next_or_restart(self, iter_attr: str, data_loader: DataLoader):
        """Fetch the next batch, restarting ``data_loader`` once exhausted.

        Args:
            iter_attr (str): Name of the attribute holding the live iterator.
            data_loader (DataLoader): Loader used to rebuild that iterator.
        """
        try:
            return next(getattr(self, iter_attr))
        except StopIteration:
            restarted = iter(data_loader)
            setattr(self, iter_attr, restarted)
            return next(restarted)

    def __next__(self) -> Dict:
        """Return the next pair of main and auxiliary batches.

        The iterator effectively has infinite length. It can NOT be used in
        `EpochBasedRunner`, because that runner would enumerate the
        dataloader forever.
        """
        main_batch = self._next_or_restart('main_iter',
                                           self.main_data_loader)
        auxiliary_batch = self._next_or_restart('auxiliary_iter',
                                                self.auxiliary_data_loader)
        return {'main_data': main_batch, 'auxiliary_data': auxiliary_batch}

    def __len__(self) -> int:
        """Return the number of batches of the main dataloader."""
        return len(self.main_data_loader)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.