
mmfewshot.detection.models.backbones.resnet_with_meta_conv 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

from mmcv.cnn import build_conv_layer
from mmdet.models import ResNet
from mmdet.models.builder import BACKBONES
from torch import Tensor

@BACKBONES.register_module()
class ResNetWithMetaConv(ResNet):
    """ResNet with `meta_conv` to handle different inputs in metarcnn and
    fsdetview.

    When input with shape (N, 3, H, W) from images, the network will use
    `conv1` as regular ResNet. When input with shape (N, 4, H, W) from
    (image + mask) the network will replace `conv1` with `meta_conv` to
    handle additional channel.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to
            ``ResNet.__init__``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # `meta_conv` is a drop-in replacement for `conv1`: its output is
        # passed to the same `norm1` in `forward`, so both convolutions
        # must produce the same number of channels.
        # NOTE(review): assumes the parent ResNet exposes its stem width as
        # `self.stem_channels` - confirm against the installed mmdet. The
        # previously hard-coded 64 is kept as the fallback.
        stem_channels = getattr(self, 'stem_channels', None) or 64
        self.meta_conv = build_conv_layer(
            self.conv_cfg,  # from config of ResNet
            4,  # image (3 channels) + mask (1 channel)
            stem_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)

    def forward(self,
                x: Tensor,
                use_meta_conv: bool = False) -> Tuple[Tensor, ...]:
        """Forward function.

        When input with shape (N, 3, H, W) from images, the network will use
        `conv1` as regular ResNet. When input with shape (N, 4, H, W) from
        (image + mask) the network will replace `conv1` with `meta_conv` to
        handle additional channel.

        Args:
            x (Tensor): Tensor with shape (N, 3, H, W) from images
                or (N, 4, H, W) from (images + masks).
            use_meta_conv (bool): If set True, forward input tensor with
                `meta_conv` which require tensor with shape (N, 4, H, W).
                Otherwise, forward input tensor with `conv1` which require
                tensor with shape (N, 3, H, W). Default: False.

        Returns:
            tuple[Tensor]: Tuple of features, each item with shape
                (N, C, H, W). One item is produced for every stage whose
                index is listed in `self.out_indices`.
        """
        # Only the first convolution differs between the two input kinds;
        # everything after it is shared.
        stem_conv = self.meta_conv if use_meta_conv else self.conv1
        x = stem_conv(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            # Collect the feature map of every requested stage.
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.