
mmfewshot.classification.models.backbones.conv4 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import torch.nn as nn
from mmcls.models.builder import BACKBONES
from torch import Tensor

class ConvBlock(nn.Module):
    """Basic block: 3x3 conv, batch norm, ReLU and optional 2x2 max pooling.

    Args:
        in_channels (int): Number of channels of the input feature map.
        out_channels (int): Number of channels produced by the conv layer.
        is_pooling (bool): Whether to append a 2x2 max pooling layer.
            Default: True.
        padding (int): Padding of the 3x3 conv layer. Default: 1.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 is_pooling: bool = True,
                 padding: int = 1) -> None:
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # NOTE(review): the lines after the conv layer were lost in the
        # extracted listing. BatchNorm2d and max pooling follow from
        # ConvNet.init_weights and the ConvNet docstring; the activation is
        # presumed to be ReLU -- confirm against upstream.
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if is_pooling:
            layers.append(nn.MaxPool2d(2))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block's layers to ``x`` and return the result."""
        out = self.layers(x)
        return out

@BACKBONES.register_module()
class ConvNet(nn.Module):
    """Plain convolutional backbone built from stacked `ConvBlock` modules.

    Args:
        depth (int): How many `ConvBlock` modules to stack.
        pooling_blocks (Sequence[int]): Indices of the blocks that apply
            2x2 max pooling.
        padding_blocks (Sequence[int]): Indices of the blocks whose conv
            layer uses padding.
        flatten (bool): Whether to reshape the output from (N, C, H, W) to
            (N, C*H*W). Default: True.
    """

    def __init__(self,
                 depth: int,
                 pooling_blocks: Sequence[int],
                 padding_blocks: Sequence[int],
                 flatten: bool = True) -> None:
        super().__init__()
        self.flatten = flatten
        # The first block takes 3-channel input; every block outputs 64
        # channels. Pooling and padding are enabled per block index.
        blocks = [
            ConvBlock(
                3 if idx == 0 else 64,
                64,
                is_pooling=idx in pooling_blocks,
                padding=int(idx in padding_blocks))
            for idx in range(depth)
        ]
        self.layers = nn.Sequential(*blocks)
        self.init_weights()

    def forward(self, x: Tensor) -> Tensor:
        """Run the stacked blocks and optionally flatten the features."""
        # e.g. (N, 64, 5, 5) per the original note -- presumably for 84x84
        # inputs with four padded, pooled blocks; confirm against caller.
        feats = self.layers(x)
        if not self.flatten:
            return feats
        # Collapse all non-batch dimensions, e.g. (N, 1600).
        return feats.view(feats.size(0), -1)

    def init_weights(self) -> None:
        """Initialise conv weights (Kaiming normal) and batch-norm affine."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='leaky_relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
@BACKBONES.register_module()
class Conv4(ConvNet):
    """`ConvNet` whose defaults give four blocks, all pooled and padded.

    Args:
        depth (int): Number of blocks. Default: 4.
        pooling_blocks (Sequence[int]): Indices of the blocks that use
            pooling. Default: (0, 1, 2, 3).
        padding_blocks (Sequence[int]): Indices of the blocks whose conv
            layer uses padding. Default: (0, 1, 2, 3).
        flatten (bool): Whether to flatten the output features.
            Default: True.
    """

    def __init__(self,
                 depth: int = 4,
                 pooling_blocks: Sequence[int] = (0, 1, 2, 3),
                 padding_blocks: Sequence[int] = (0, 1, 2, 3),
                 flatten: bool = True) -> None:
        # Arguments are forwarded unchanged to ConvNet.
        super().__init__(depth, pooling_blocks, padding_blocks, flatten)
@BACKBONES.register_module()
class Conv4NoPool(ConvNet):
    """Used for RelationNet.

    Args:
        depth (int): Number of blocks. Default: 4.
        pooling_blocks (Sequence[int]): Indices of the blocks that use
            pooling. Default: (0, 1).
        padding_blocks (Sequence[int]): Indices of the blocks whose conv
            layer uses padding. Default: (2, 3).
        flatten (bool): Whether to flatten the output features.
            Default: False.
    """

    def __init__(self,
                 depth: int = 4,
                 pooling_blocks: Sequence[int] = (0, 1),
                 padding_blocks: Sequence[int] = (2, 3),
                 flatten: bool = False) -> None:
        # Arguments are forwarded unchanged to ConvNet.
        super().__init__(depth, pooling_blocks, padding_blocks, flatten)
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.