-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
fpn_head.py
68 lines (57 loc) · 2.36 KB
/
fpn_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.registry import MODELS
from ..utils import Upsample, resize
from .decode_head import BaseDecodeHead
@MODELS.register_module()
class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks.

    This head is the implementation of `Semantic FPN
    <https://arxiv.org/abs/1901.02446>`_.

    Every selected input level gets its own "scale head": a stack of 3x3
    ``ConvModule`` layers. For every level whose stride differs from the
    first level's stride, each of those convs is followed by a 2x bilinear
    ``Upsample``. In ``forward`` the per-level outputs are resized to the
    spatial size of the first level, summed, and classified by
    ``self.cls_seg``.

    Args:
        feature_strides (tuple[int]): The strides of the input feature maps,
            one per entry of ``in_channels``. All strides are supposed to be
            powers of 2. The first stride must be the smallest one, i.e. the
            first feature map has the largest resolution.
        **kwargs: Forwarded to ``BaseDecodeHead``. This head always passes
            ``input_transform='multiple_select'`` itself, so that key must
            not be supplied in ``kwargs``.
    """

    def __init__(self, feature_strides, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        # ``self.in_channels``, ``self.channels``, ``self.conv_cfg``,
        # ``self.norm_cfg``, ``self.act_cfg`` and ``self.align_corners`` are
        # set by ``BaseDecodeHead.__init__``.
        assert len(feature_strides) == len(self.in_channels), (
            'feature_strides must have exactly one entry per in_channels '
            'entry')
        assert min(feature_strides) == feature_strides[0], (
            'the first feature stride must be the smallest one')
        self.feature_strides = feature_strides
        base_stride = feature_strides[0]
        self.scale_heads = nn.ModuleList()
        for level_channels, stride in zip(self.in_channels, feature_strides):
            # One conv (plus one 2x upsample) per doubling needed to reach
            # the first level's resolution, and at least one conv per level.
            # Taking log2 of the ratio keeps power-of-two ratios exact,
            # whereas subtracting two rounded logs can fall just below an
            # integer and be truncated one step too low by ``int``.
            head_length = max(1, int(np.log2(stride / base_stride)))
            scale_head = []
            for k in range(head_length):
                scale_head.append(
                    ConvModule(
                        # Only the first conv sees the level's own channel
                        # count; later convs work at ``self.channels``.
                        level_channels if k == 0 else self.channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if stride != base_stride:
                    scale_head.append(
                        Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*scale_head))

    def forward(self, inputs):
        """Fuse the multi-level features and predict segmentation logits.

        Args:
            inputs: Multi-level features. ``self._transform_inputs`` selects
                the levels used by this head (presumably a list of NCHW
                tensors; TODO confirm against the caller).

        Returns:
            Output of ``self.cls_seg`` applied to the summed scale-head
            outputs, at the spatial size of the first level's scale-head
            output.
        """
        x = self._transform_inputs(inputs)
        output = self.scale_heads[0](x[0])
        for i in range(1, len(self.feature_strides)):
            # Deliberately out-of-place (``output + ...`` rather than
            # ``+=``), presumably so autograd never sees an in-place update.
            # The final resize absorbs any remaining size mismatch, e.g.
            # from odd input sizes or strides that are not powers of two.
            output = output + resize(
                self.scale_heads[i](x[i]),
                size=output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        output = self.cls_seg(output)
        return output