Shortcuts

Source code for mmfewshot.classification.models.heads.prototype_head

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, List

import torch
import torch.nn.functional as F
from mmcls.models.builder import HEADS
from torch import Tensor

from mmfewshot.classification.datasets import label_wrapper
from .base_head import BaseFewShotHead


@HEADS.register_module()
class PrototypeHead(BaseFewShotHead):
    """Classification head for `ProtoNet <https://arxiv.org/abs/1703.05175>`_.

    Query samples are classified by their negative euclidean distance to the
    mean support feature (prototype) of each class.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # State below is only used in meta testing: support features/labels
        # are accumulated batch by batch, then condensed into prototypes.
        self.support_feats_list: List[Tensor] = []
        self.support_labels_list: List[Tensor] = []
        # Sorted ids of classes seen in the support set; set in
        # `before_forward_query`.
        self.class_ids = None
        # Per-class mean support features with shape (num_classes, C).
        self.prototype_feats = None

    @staticmethod
    def _compute_prototypes(feats: Tensor, labels: Tensor,
                            class_ids) -> Tensor:
        """Average the features of each class into one prototype.

        Args:
            feats (Tensor): Features with shape (N, C).
            labels (Tensor): Labels of the features with shape (N).
            class_ids (Sequence): Class ids to build prototypes for.

        Returns:
            Tensor: Prototype features with shape (num_classes, C), ordered
                as `class_ids`.
        """
        return torch.cat([
            feats[labels == class_id].mean(0, keepdim=True)
            for class_id in class_ids
        ], dim=0)

    def forward_train(self, support_feats: Tensor, support_labels: Tensor,
                      query_feats: Tensor, query_labels: Tensor,
                      **kwargs) -> Dict:
        """Forward training data.

        Args:
            support_feats (Tensor): Features of support data with shape
                (N, C).
            support_labels (Tensor): Labels of support data with shape (N).
            query_feats (Tensor): Features of query data with shape (N, C).
            query_labels (Tensor): Labels of query data with shape (N).

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        class_ids = torch.unique(support_labels).cpu().tolist()
        prototype_feats = self._compute_prototypes(support_feats,
                                                   support_labels, class_ids)
        # Negative euclidean distances to the prototypes act as class logits.
        cls_scores = -1 * torch.cdist(
            query_feats.unsqueeze(0), prototype_feats.unsqueeze(0)).squeeze(0)
        # Remap episode labels into [0, num_classes) to match score columns.
        query_labels = label_wrapper(query_labels, class_ids)
        losses = self.loss(cls_scores, query_labels)
        return losses

    def forward_support(self, x: Tensor, gt_label: Tensor, **kwargs) -> None:
        """Forward support data in meta testing.

        Features and labels are only accumulated here; prototypes are built
        lazily in `before_forward_query`.
        """
        self.support_feats_list.append(x)
        self.support_labels_list.append(gt_label)

    def forward_query(self, x: Tensor, **kwargs) -> List:
        """Forward query data in meta testing.

        Args:
            x (Tensor): Features of query data with shape (N, C).

        Returns:
            list: Softmax probabilities over the support classes, one numpy
                array per query sample.
        """
        assert self.prototype_feats is not None
        cls_scores = -1 * torch.cdist(
            x.unsqueeze(0), self.prototype_feats.unsqueeze(0)).squeeze(0)
        pred = F.softmax(cls_scores, dim=1)
        pred = list(pred.detach().cpu().numpy())
        return pred

    def before_forward_support(self) -> None:
        """Used in meta testing.

        This function will be called before model forward support data during
        meta testing; it resets all prototype state for the new task.
        """
        self.support_feats_list.clear()
        self.support_labels_list.clear()
        self.prototype_feats = None
        self.class_ids = None

    def before_forward_query(self) -> None:
        """Used in meta testing.

        This function will be called before model forward query data during
        meta testing; it condenses the accumulated support features into one
        prototype per class.
        """
        feats = torch.cat(self.support_feats_list, dim=0)
        labels = torch.cat(self.support_labels_list, dim=0)
        self.class_ids, _ = torch.unique(labels).sort()
        self.prototype_feats = self._compute_prototypes(
            feats, labels, self.class_ids)
        # Scores in forward_query are indexed by prototype position, so
        # non-contiguous class ids would misalign predictions with labels.
        if max(self.class_ids) + 1 != len(self.class_ids):
            warnings.warn(
                f'the max class id is {max(self.class_ids)}, while '
                f'the number of different number of classes is '
                f'{len(self.class_ids)}, it will cause label '
                f'mismatching problem.', UserWarning)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.