Shortcuts

Source code for mmfewshot.classification.datasets.utils

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import numpy as np
import torch
from torch import Tensor


def label_wrapper(labels: Union[Tensor, np.ndarray, List],
                  class_ids: List[int]) -> Union[Tensor, np.ndarray, list]:
    """Map input labels into range of 0 to numbers of classes-1.

    It is usually used in the meta testing phase, in which the class ids are
    random sampled and discontinuous. Each label is replaced by the position
    of its class id in ``class_ids``.

    Args:
        labels (Tensor | np.ndarray | list | tuple): The labels to be wrapped.
            Tensors and arrays may have any shape (including 0-d); the output
            keeps the input shape. Lists/tuples must be flat.
        class_ids (list[int]): All class ids of labels.

    Returns:
        (Tensor | np.ndarray | list): A Tensor (same dtype and device as the
        input) for Tensor input, an ndarray (same dtype) for ndarray input,
        and a list for list or tuple input.

    Raises:
        KeyError: If a label does not appear in ``class_ids``.
        TypeError: If ``labels`` is not a Tensor, ndarray, list or tuple.
    """
    class_id_map = {class_id: i for i, class_id in enumerate(class_ids)}

    def _lookup(label):
        try:
            return class_id_map[label]
        except KeyError:
            raise KeyError(
                f'label {label!r} is not in class_ids') from None

    if isinstance(labels, torch.Tensor):
        # A single ``tolist()`` copies all values to the host at once,
        # instead of one device sync per element via ``.item()``.
        mapped = [_lookup(label) for label in labels.reshape(-1).tolist()]
        # Build as int64 first and then cast, matching the original
        # ``type_as`` conversion semantics.
        wrapped_labels = torch.tensor(mapped).to(
            device=labels.device, dtype=labels.dtype)
        return wrapped_labels.reshape(labels.shape)
    if isinstance(labels, np.ndarray):
        mapped = [_lookup(label) for label in labels.reshape(-1).tolist()]
        wrapped_labels = np.array(mapped).astype(labels.dtype)
        return wrapped_labels.reshape(labels.shape)
    if isinstance(labels, (tuple, list)):
        return [_lookup(label) for label in labels]
    raise TypeError('only support torch.Tensor, np.ndarray and list')
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.