
mmfewshot.classification.models.heads.neg_margin_head 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.models.builder import HEADS
from mmcls.models.heads import ClsHead
from torch import Tensor

@HEADS.register_module()
class NegMarginHead(ClsHead):
    """Classification head for `NegMargin <https://arxiv.org/abs/2003.12060>`_.

    The head holds one learnable weight vector per class and scores a
    feature by its similarity to every class weight. During training the
    similarity of the ground-truth class is shifted by ``margin`` before
    being scaled by ``temperature``.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        temperature (float): Scaling factor of `cls_score`. Default: 30.0.
        margin (float): Margin of `cls_score`. Must be ``<= 0``.
            Default: 0.0.
        metric_type (str): The way to calculate similarity.
            Options: ['cosine', 'softmax']. Default: 'cosine'.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 temperature: float = 30.0,
                 margin: float = 0.0,
                 metric_type: str = 'cosine',
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert num_classes > 0, \
            f'num_classes={num_classes} must be a positive integer'
        assert margin <= 0, f'margin = {margin} should <= 0'
        assert metric_type in ['cosine', 'softmax']

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.metric_type = metric_type
        self.temperature = temperature
        self.margin = margin
        self.init_layers()

    def init_layers(self) -> None:
        """Create (or re-create) the ``(num_classes, in_channels)`` weight.

        The weight is drawn with Kaiming-uniform initialisation. When the
        head already owns a weight (i.e. this is a re-initialisation), the
        new parameter is placed on the same device as the one it replaces.
        """
        previous = getattr(self, 'weight', None)
        # Initialise on CPU so the random stream matches a fresh head, then
        # move to wherever the old weight lived.
        weight = torch.empty(
            self.num_classes, self.in_channels, dtype=torch.float32)
        nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
        if previous is not None:
            weight = weight.to(previous.device)
        self.weight = nn.Parameter(weight)

    def _compute_similarity(self, x: Tensor) -> Tensor:
        """Return the similarity between ``x`` and every class weight.

        Args:
            x (Tensor): Input features; multiplied against ``self.weight``
                via ``F.linear``.

        Returns:
            Tensor: For ``'cosine'`` the dot product of the normalised
            features and normalised weights; for ``'softmax'`` the plain
            linear response shifted so each row's minimum is zero.

        Raises:
            ValueError: If ``self.metric_type`` is not a supported option.
        """
        if self.metric_type == 'cosine':
            return F.linear(F.normalize(x), F.normalize(self.weight))
        if self.metric_type == 'softmax':
            similarity = F.linear(x, self.weight)
            # Out-of-place shift: avoids mutating a tensor tracked by
            # autograd while producing the same values.
            return similarity - similarity.min(dim=1, keepdim=True)[0]
        raise ValueError(f'metric type {self.metric_type} not supported')

    def forward_train(self, x: Tensor, gt_label: Tensor, **kwargs) -> Dict:
        """Forward training data.

        Args:
            x (Tensor): Input features.
            gt_label (Tensor): Ground-truth class index per sample; used as
                a scatter index, so it is expected to be a 1-D integer
                tensor.

        Returns:
            Dict: Whatever ``self.loss(cls_score, gt_label)`` returns.
        """
        similarity = self._compute_similarity(x)

        # Boolean one-hot mask of the ground-truth class of every sample.
        # ``torch.where`` expects a bool condition; byte masks are
        # deprecated and rejected by recent PyTorch releases.
        target_mask = torch.zeros((gt_label.size(0), self.num_classes),
                                  dtype=torch.bool,
                                  device=gt_label.device)
        target_mask.scatter_(1, gt_label.unsqueeze(1), True)

        # Only the ground-truth entry has the margin subtracted; since the
        # margin is <= 0 this never lowers that entry.
        similarity = torch.where(target_mask, similarity - self.margin,
                                 similarity)
        cls_score = similarity * self.temperature
        return self.loss(cls_score, gt_label)

    def forward_support(self, x: Tensor, gt_label: Tensor, **kwargs) -> Dict:
        """Forward support data in meta testing.

        Delegates to :meth:`forward_train` with the same arguments.
        """
        return self.forward_train(x, gt_label, **kwargs)

    def forward_query(self, x: Tensor, **kwargs) -> List:
        """Forward query data in meta testing.

        Args:
            x (Tensor): Input features.

        Returns:
            List: One numpy array of softmax scores per input row. No
            margin is applied here.
        """
        cls_score = self._compute_similarity(x) * self.temperature
        pred = F.softmax(cls_score, dim=1)
        return list(pred.detach().cpu().numpy())

    def before_forward_support(self) -> None:
        """Used in meta testing.

        This function will be called before model forward support data
        during meta testing. It re-initialises the class weights and puts
        the head in training mode.
        """
        self.init_layers()
        self.train()

    def before_forward_query(self) -> None:
        """Used in meta testing.

        This function will be called before model forward query data
        during meta testing. It puts the head in evaluation mode.
        """
        self.eval()
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.