Shortcuts

Source code for mmfewshot.detection.models.losses.supervised_contrastive_loss

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn
from mmdet.models import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
from torch import Tensor
from typing_extensions import Literal


@LOSSES.register_module()
class SupervisedContrastiveLoss(nn.Module):
    """`Supervised Contrastive LOSS <https://arxiv.org/abs/2004.11362>`_.

    Adapted from the FSCE implementation
    (https://github.com/MegviiDetection/FSCE).

    Args:
        temperature (float): Constant that the cosine similarity is divided
            by to enlarge its magnitude. Default: 0.2.
        iou_threshold (float): Only proposals whose IoU reaches this value
            contribute to the loss. Default: 0.5.
        reweight_type (str): Reweight function for the contrastive loss.
            Options are ('none', 'exp', 'linear'). Default: 'none'.
        reduction (str): Method used to reduce the loss into a scalar.
            Options are "none", "mean" and "sum". Default: 'mean'.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 temperature: float = 0.2,
                 iou_threshold: float = 0.5,
                 reweight_type: Literal['none', 'exp', 'linear'] = 'none',
                 reduction: Literal['none', 'mean', 'sum'] = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        assert temperature > 0, 'temperature should be a positive number.'
        self.temperature = temperature
        self.iou_threshold = iou_threshold
        self.reweight_func = self._get_reweight_func(reweight_type)
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                features: Tensor,
                labels: Tensor,
                ious: Tensor,
                decay_rate: Optional[float] = None,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the supervised contrastive loss.

        Args:
            features (Tensor): Shape (N, K); N feature vectors with K
                channels each.
            labels (Tensor): Shape (N); class label of each feature.
            ious (Tensor): Shape (N); IoU credibility of each proposal.
            decay_rate (float | None): Extra decay multiplier applied to
                the total loss. Default: None.
            weight (Tensor | None): Per-prediction loss weight of shape
                (N). Default: None.
            avg_factor (int | None): Average factor used to average the
                loss. Default: None.
            reduction_override (str | None): Overrides the configured
                reduction method. Options are "none", "mean" and "sum".
                Default: None.

        Returns:
            Tensor: The calculated loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        scale = self.loss_weight
        if decay_rate is not None:
            scale = self.loss_weight * decay_rate

        assert features.shape[0] == labels.shape[0] == ious.shape[0]
        if len(labels.shape) == 1:
            labels = labels.reshape(-1, 1)

        # pos_mask[i, j] == 1 iff samples i and j carry the same label.
        pos_mask = torch.eq(labels, labels.T).float().to(features.device)

        sim = torch.matmul(features, features.T) / self.temperature
        # Subtract the per-row maximum for numerical stability; detach so
        # the shift does not enter the gradient.
        row_max = sim.max(dim=1, keepdim=True).values
        sim = sim - row_max.detach()

        # Zero the diagonal: a sample is never contrasted with itself.
        non_self = torch.ones_like(sim)
        non_self.fill_diagonal_(0)

        exp_sim = sim.exp() * non_self
        logp = sim - torch.log(exp_sim.sum(dim=1, keepdim=True))
        mean_logp = (logp * non_self * pos_mask).sum(1) / pos_mask.sum(1)

        keep = ious >= self.iou_threshold
        if keep.sum() == 0:
            # No proposal is credible enough: return a zero that stays
            # connected to the autograd graph.
            return mean_logp.sum() * 0
        loss = -mean_logp[keep]

        coeff = self.reweight_func(ious)[keep]
        if weight is not None:
            weight = weight[keep]
        loss = loss * coeff
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return scale * loss

    @staticmethod
    def _get_reweight_func(
            reweight_type: Literal['none', 'exp', 'linear'] = 'none'
    ) -> callable:
        """Return the reweight function matching `reweight_type`.

        Args:
            reweight_type (str): Reweight function for contrastive loss.
                Options are ('none', 'exp', 'linear'). Default: 'none'.

        Returns:
            callable: Maps an IoU tensor to per-sample loss coefficients.
        """
        assert reweight_type in ('none', 'exp', 'linear'), \
            f'not support `reweight_type` {reweight_type}.'

        def trivial(iou):
            # Uniform weighting: every kept proposal counts equally.
            return torch.ones_like(iou)

        def linear(iou):
            # Weight grows linearly with IoU.
            return iou

        def exp_decay(iou):
            # exp(iou) - 1 emphasizes high-IoU proposals.
            return torch.exp(iou) - 1

        return {'none': trivial, 'linear': linear,
                'exp': exp_decay}[reweight_type]
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.