• Docs >
  • Module code >
  • mmfewshot.detection.models.roi_heads.shared_heads.meta_rcnn_res_layer
Shortcuts

Source code for mmfewshot.detection.models.roi_heads.shared_heads.meta_rcnn_res_layer

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmdet.models.builder import SHARED_HEADS
from mmdet.models.roi_heads import ResLayer
from torch import Tensor


@SHARED_HEADS.register_module()
class MetaRCNNResLayer(ResLayer):
    """Shared res layer used by Meta R-CNN and FSDetView.

    Query images and support images go through the same residual stage but
    use separate entry points: :meth:`forward` for query features and
    :meth:`forward_support` for support features.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent res layer and the support-branch extras.

        All positional and keyword arguments are passed straight through to
        the parent ``ResLayer`` constructor.
        """
        super().__init__(*args, **kwargs)
        # Only the support branch uses these two modules.
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Run the query branch.

        Args:
            x (Tensor): Features from backbone with shape (N, C, H, W).

        Returns:
            Tensor: Spatially averaged features with shape (N, C).
        """
        # The residual stage is looked up by the name derived from
        # ``self.stage``.
        stage_name = f'layer{self.stage + 1}'
        feats = getattr(self, stage_name)(x)
        # Average over the last axis (dim 3) first, then over dim 2.
        return feats.mean(3).mean(2)

    def forward_support(self, x: Tensor) -> Tensor:
        """Run the support branch.

        Unlike :meth:`forward`, the input is max-pooled (kernel size 2)
        before the residual stage, and a sigmoid is applied to the stage
        output before spatial averaging.

        Args:
            x (Tensor): Features from backbone with shape (N, C, H, W).

        Returns:
            Tensor: Spatially averaged features with shape (N, C).
        """
        pooled = self.max_pool(x)
        stage_name = f'layer{self.stage + 1}'
        activated = self.sigmoid(getattr(self, stage_name)(pooled))
        # Average over the last axis (dim 3) first, then over dim 2.
        return activated.mean(3).mean(2)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.