534 lines
16 KiB
Python
534 lines
16 KiB
Python
import json
|
|
import math
|
|
|
|
import torch.nn as nn
|
|
import torch
|
|
import bricks
|
|
import torch.nn.functional as F
|
|
from abc import *
|
|
import utils
|
|
|
|
"""
|
|
[batch_size, in_channels, height, width] -> [batch_size, out_channels, height // 4, width // 4]
|
|
"""
|
|
class StemConv(nn.Module):
    """Stem made of two chained stride-2 ``bricks.DownSampling`` layers.

    Shape contract (as stated by the note preceding this class; it presumes
    each ``bricks.DownSampling`` step halves height and width):
        [batch_size, in_channels, height, width]
            -> [batch_size, out_channels, height // 4, width // 4]

    Args:
        in_channels: channel count fed to the first down-sampling layer.
        out_channels: channel count produced by the second layer; the first
            layer produces ``out_channels // 2`` channels.
        norm_layer: passed through unchanged to both down-sampling layers.
    """

    def __init__(self, in_channels, out_channels, norm_layer=None):
        super().__init__()

        half_channels = out_channels // 2
        # Both steps share the same kernel, stride and normalisation choice.
        common = dict(kernel_size=(3, 3), stride=(2, 2), norm_layer=norm_layer)

        first_step = bricks.DownSampling(
            in_channels=in_channels,
            out_channels=half_channels,
            **common
        )
        second_step = bricks.DownSampling(
            in_channels=half_channels,
            out_channels=out_channels,
            **common
        )
        self.proj = nn.Sequential(first_step, second_step)

    def forward(self, x):
        """Run ``x`` through both down-sampling steps and return the result."""
        return self.proj(x)
|
|
|
|
|
|
class MSCA(nn.Module):
    """Attention-style gate built from a 5x5 conv and three strip-conv branches.

    The input first goes through a 5x5 ``bricks.DepthwiseConv``. That result
    is fed to three branches, each a (1, k) conv followed by a (k, 1) conv
    with k = 7, 11 and 21. The 5x5 result and the three branch results are
    summed, mixed by a 1x1 convolution, and multiplied element-wise with the
    original input.

    Args:
        in_channels: channel count of the input; every layer keeps it.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.conv = bricks.DepthwiseConv(
            in_channels=in_channels,
            kernel_size=(5, 5),
            padding=(2, 2),
            bias=True
        )

        # Registration order (conv7, conv11, conv21, fc) is kept on purpose:
        # it determines state-dict ordering and parameter-init order.
        self.conv7 = self._strip_pair(in_channels, 7)
        self.conv11 = self._strip_pair(in_channels, 11)
        self.conv21 = self._strip_pair(in_channels, 21)

        self.fc = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1))

    @staticmethod
    def _strip_pair(channels, length):
        """Build a (1, length) conv followed by a (length, 1) conv.

        Padding is ``length // 2`` along the long side of each kernel, which
        gives 3, 5 and 10 for the lengths 7, 11 and 21 used by this class.
        """
        pad = length // 2
        return nn.Sequential(
            bricks.DepthwiseConv(
                in_channels=channels,
                kernel_size=(1, length),
                padding=(0, pad),
                bias=True
            ),
            bricks.DepthwiseConv(
                in_channels=channels,
                kernel_size=(length, 1),
                padding=(pad, 0),
                bias=True
            )
        )

    def forward(self, x):
        """Return ``fc(base + branches) * x`` where ``base = conv(x)``."""
        base = self.conv(x)

        narrow = self.conv7(base)
        medium = self.conv11(base)
        wide = self.conv21(base)

        gate = self.fc(base + narrow + medium + wide)
        return gate * x
|
|
|
|
|
|
class Attention(nn.Module):
    """1x1 conv + GELU, then ``MSCA``, then another 1x1 conv.

    Args:
        in_channels: channel count of the input; all three layers keep it.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.fc1 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1))
        self.msca = MSCA(in_channels=in_channels)
        self.fc2 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Apply ``fc2(msca(gelu(fc1(x))))``."""
        activated = F.gelu(self.fc1(x))
        attended = self.msca(activated)
        return self.fc2(attended)
|
|
|
|
|
|
class FFN(nn.Module):
    """Convolutional feed-forward: 1x1 conv, 3x3 depthwise conv + GELU, 1x1 conv, dropout.

    Args:
        in_features: channel count of the input.
        hidden_features: channel count between the two 1x1 convolutions.
        out_features: channel count of the output.
        drop_prob: probability handed to the trailing ``nn.Dropout``.
    """

    def __init__(self, in_features, hidden_features, out_features, drop_prob=0.):
        super().__init__()

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=(1, 1))
        # NOTE(review): no padding is passed here, so the spatial size after
        # this 3x3 conv depends on bricks.DepthwiseConv's default padding --
        # TODO confirm it preserves height/width.
        self.dw = bricks.DepthwiseConv(
            in_channels=hidden_features,
            kernel_size=(3, 3),
            bias=True
        )
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=(1, 1))
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        """Apply ``dropout(fc2(gelu(dw(fc1(x)))))``."""
        hidden = self.dw(self.fc1(x))
        projected = self.fc2(F.gelu(hidden))
        return self.dropout(projected)
|
|
|
|
class Block(nn.Module):
    """Residual block: (BatchNorm -> Attention) and (BatchNorm -> FFN), each scaled and added back.

    Each branch output is multiplied by a learnable per-channel scale
    (initialised to 1e-2), passed through ``drop_path`` and added to the
    branch input.

    Args:
        in_channels: channel count of the input; the block keeps it.
        expand_ratio: ``int(expand_ratio * in_channels)`` is the hidden width
            of the FFN.
        drop_prob: dropout probability forwarded to the FFN.
        drop_path_prob: probability forwarded to ``bricks.DropPath``; when it
            is not positive an ``nn.Identity`` is used instead.
    """

    def __init__(self, in_channels, expand_ratio, drop_prob=0., drop_path_prob=0.):
        super(Block, self).__init__()

        self.norm1 = nn.BatchNorm2d(num_features=in_channels)
        self.attention = Attention(in_channels=in_channels)
        # The conditional must select the *module*, not the value handed to
        # DropPath: with the condition inside the argument, DropPath was
        # always built and could receive the nn.Identity class as its
        # probability. A non-positive probability now yields a real
        # nn.Identity() instance.
        if drop_path_prob > 0.:
            self.drop_path = bricks.DropPath(drop_prob=drop_path_prob)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = nn.BatchNorm2d(num_features=in_channels)
        self.ffn = FFN(
            in_features=in_channels,
            hidden_features=int(expand_ratio * in_channels),
            out_features=in_channels,
            drop_prob=drop_prob
        )

        layer_scale_init_value = 1e-2
        self.layer_scale1 = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channels),
            requires_grad=True
        )
        self.layer_scale2 = nn.Parameter(
            layer_scale_init_value * torch.ones(in_channels),
            requires_grad=True
        )

    def forward(self, x):
        """Apply the attention branch, then the FFN branch, each as a scaled residual.

        The two trailing ``unsqueeze(-1)`` calls turn the ``[channels]``
        scale vectors into ``[channels, 1, 1]`` so they broadcast over the
        last two dimensions of the branch output.
        """
        out = self.norm1(x)
        out = self.attention(out)
        out = x + self.drop_path(
            self.layer_scale1.unsqueeze(-1).unsqueeze(-1) * out
        )
        # The result of the first residual is the skip input of the second.
        x = out

        out = self.norm2(out)
        out = self.ffn(out)
        out = x + self.drop_path(
            self.layer_scale2.unsqueeze(-1).unsqueeze(-1) * out
        )

        return out
|
|
|
|
class Stage(nn.Module):
    """One encoder stage: down-sampling, a stack of ``Block``s, then LayerNorm over channels.

    Args:
        stage_id: 0 selects ``StemConv`` as the down-sampler; any other value
            selects a single stride-2 ``bricks.DownSampling``.
        in_channels: channel count entering the down-sampler.
        out_channels: channel count leaving the down-sampler and used by
            every block and by the final norm.
        expand_ratio: forwarded to each ``Block``.
        blocks_num: number of blocks; must equal ``len(drop_path_prob)``.
        drop_prob: forwarded to each ``Block``.
        drop_path_prob: one drop-path probability per block.
    """

    def __init__(
            self,
            stage_id,
            in_channels,
            out_channels,
            expand_ratio,
            blocks_num,
            drop_prob=0.,
            drop_path_prob=[0.]
    ):
        super().__init__()

        assert blocks_num == len(drop_path_prob)

        if stage_id != 0:
            self.down_sampling = bricks.DownSampling(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(2, 2)
            )
        else:
            self.down_sampling = StemConv(
                in_channels=in_channels,
                out_channels=out_channels
            )

        block_list = []
        for block_idx in range(blocks_num):
            block_list.append(
                Block(
                    in_channels=out_channels,
                    expand_ratio=expand_ratio,
                    drop_prob=drop_prob,
                    drop_path_prob=drop_path_prob[block_idx]
                )
            )
        self.blocks = nn.Sequential(*block_list)

        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        """Down-sample, run the blocks, and layer-normalise each position over channels.

        The block output is unpacked as a 4-D
        [batch_size, channels, height, width] tensor and the result has the
        same shape.
        """
        feat = self.blocks(self.down_sampling(x))
        batch_size, channels, height, width = feat.shape

        # [B, C, H, W] -> [B, C, H * W] -> [B, H * W, C] so that LayerNorm
        # sees the channel axis last.
        tokens = torch.transpose(feat.view(batch_size, channels, -1), -2, -1)
        tokens = self.norm(tokens)

        # [B, H * W, C] -> [B, C, H * W] -> [B, C, H, W]
        restored = torch.transpose(tokens, -2, -1)
        return restored.view(batch_size, -1, height, width)
|
|
|
|
class MSCAN(nn.Module):
    """Encoder made of consecutive ``Stage``s with a linearly increasing drop-path schedule.

    Args:
        embed_dims: channel counts; stage ``i`` maps ``embed_dims[i]`` to
            ``embed_dims[i + 1]``.
        expand_ratios: FFN expansion ratio per stage.
        depths: number of blocks per stage; its length sets the stage count.
        drop_prob: dropout probability forwarded to every stage.
        drop_path_prob: upper end of the drop-path schedule, which runs
            linearly from 0 over all ``sum(depths)`` blocks.
    """

    def __init__(
            self,
            embed_dims=[3, 32, 64, 160, 256],
            expand_ratios=[8, 8, 4, 4],
            depths=[3, 3, 5, 2],
            drop_prob=0.1,
            drop_path_prob=0.1
    ):
        super().__init__()

        # One probability per block across the whole encoder.
        schedule = torch.linspace(0, drop_path_prob, sum(depths)).tolist()

        stage_list = []
        first_block = 0
        for stage_id, depth in enumerate(depths):
            last_block = first_block + depth
            stage_list.append(
                Stage(
                    stage_id=stage_id,
                    in_channels=embed_dims[stage_id],
                    out_channels=embed_dims[stage_id + 1],
                    expand_ratio=expand_ratios[stage_id],
                    blocks_num=depth,
                    drop_prob=drop_prob,
                    drop_path_prob=schedule[first_block:last_block]
                )
            )
            first_block = last_block
        self.stages = nn.Sequential(*stage_list)

    def forward(self, x):
        """Return ``[x, out_of_stage1, out_of_stage2, ...]``.

        The output of stage 0 is fed onward but not collected. According to
        the original comments (with the default four stages) the collected
        maps are at 1/8, 1/16 and 1/32 of the input resolution with
        ``embed_dims[2]``, ``embed_dims[3]`` and ``embed_dims[4]`` channels.
        """
        collected = [x]
        feat = x

        for idx, stage in enumerate(self.stages):
            feat = stage(feat)
            if idx == 0:
                # Stage 0 only feeds the next stage.
                continue
            collected.append(feat)

        return collected
|
|
|
|
class Hamburger(nn.Module):
    """1x1 conv, ``bricks.NMF2D``, then 1x1 conv + GroupNorm, wrapped in a ReLU residual.

    Args:
        hamburger_channels: channel count kept by every layer; it is also
            given to ``nn.GroupNorm`` with 32 groups.
        nmf2d_config: JSON string passed unchanged as ``args`` to
            ``bricks.NMF2D``. The default is serialised once, when the class
            body is executed.
    """

    def __init__(
            self,
            hamburger_channels=256,
            nmf2d_config=json.dumps(
                {
                    "SPATIAL": True,
                    "MD_S": 1,
                    "MD_D": 512,
                    "MD_R": 64,
                    "TRAIN_STEPS": 6,
                    "EVAL_STEPS": 7,
                    "INV_T": 1,
                    "ETA": 0.9,
                    "RAND_INIT": True,
                    "return_bases": False,
                    "device": "cuda"
                }
            )
    ):
        super().__init__()

        entry_conv = nn.Conv2d(
            hamburger_channels,
            hamburger_channels,
            kernel_size=(1, 1)
        )
        self.ham_in = nn.Sequential(entry_conv)

        self.ham = bricks.NMF2D(args=nmf2d_config)

        exit_conv = nn.Conv2d(
            hamburger_channels,
            hamburger_channels,
            kernel_size=(1, 1),
            bias=False
        )
        exit_norm = nn.GroupNorm(num_groups=32, num_channels=hamburger_channels)
        self.ham_out = nn.Sequential(exit_conv, exit_norm)

    def forward(self, x):
        """Return ``relu(x + ham_out(ham(ham_in(x))))``."""
        decomposed = self.ham(self.ham_in(x))
        residual = self.ham_out(decomposed)
        return F.relu(x + residual)
|
|
|
|
|
|
class LightHamHead(nn.Module):
    """Decoder head: fuse multi-scale features, apply ``Hamburger``, classify, up-sample.

    Args:
        in_channels_list: channel counts of the feature maps that get
            concatenated; their sum is the input width of ``squeeze``.
        hidden_channels: width after ``squeeze`` and inside ``Hamburger``.
        out_channels: width after ``align`` and into the classifier.
        classes_num: number of output channels of the classifier.
        drop_prob: probability of the ``nn.Dropout2d`` before the classifier.
        nmf2d_config: JSON string forwarded to ``Hamburger``. The default is
            serialised once, when the class body is executed.
    """

    def __init__(
            self,
            in_channels_list=[64, 160, 256],
            hidden_channels=256,
            out_channels=256,
            classes_num=150,
            drop_prob=0.1,
            nmf2d_config=json.dumps(
                {
                    "SPATIAL": True,
                    "MD_S": 1,
                    "MD_D": 512,
                    "MD_R": 64,
                    "TRAIN_STEPS": 6,
                    "EVAL_STEPS": 7,
                    "INV_T": 1,
                    "ETA": 0.9,
                    "RAND_INIT": True,
                    "return_bases": False,
                    "device": "cuda"
                }
            )
    ):
        super().__init__()

        # Sub-modules are created in the order cls_seg, squeeze, hamburger,
        # align so that parameter registration and init order stay as before.
        classifier = nn.Conv2d(out_channels, classes_num, kernel_size=(1, 1))
        self.cls_seg = nn.Sequential(nn.Dropout2d(drop_prob), classifier)

        self.squeeze = nn.Sequential(
            nn.Conv2d(
                sum(in_channels_list),
                hidden_channels,
                kernel_size=(1, 1),
                bias=False
            ),
            nn.GroupNorm(num_groups=32, num_channels=hidden_channels),
            nn.ReLU()
        )

        self.hamburger = Hamburger(
            hamburger_channels=hidden_channels,
            nmf2d_config=nmf2d_config
        )

        self.align = nn.Sequential(
            nn.Conv2d(
                hidden_channels,
                out_channels,
                kernel_size=(1, 1),
                bias=False
            ),
            nn.GroupNorm(num_groups=32, num_channels=out_channels),
            nn.ReLU()
        )

    def forward(self, inputs):
        """Produce per-pixel class scores at the resolution of ``inputs[0]``.

        Args:
            inputs: sequence ``[x, x_1, x_2, ...]``. ``x`` is only used for
                its height and width; the remaining 4-D feature maps are
                resized to the spatial size of ``x_1`` and concatenated.

        Returns:
            Tensor of shape [batch_size, height * width, channels], where
            height/width come from ``x`` and channels is whatever
            ``cls_seg`` outputs (``classes_num`` as built by ``__init__``).
        """
        assert len(inputs) >= 2
        image = inputs[0]
        batch_size, _, standard_height, standard_width = inputs[1].shape
        target_size = (standard_height, standard_width)

        resized = []
        for feature in inputs[1:]:
            resized.append(
                F.interpolate(
                    input=feature,
                    size=target_size,
                    mode="bilinear",
                    align_corners=False
                )
            )

        # fused: [batch_size, channels_1 + channels_2 + ..., standard_height, standard_width]
        fused = torch.cat(resized, dim=1)

        # squeeze maps the concatenated channels to hidden_channels, and
        # align maps hidden_channels to out_channels.
        hidden = self.hamburger(self.squeeze(fused))
        aligned = self.align(hidden)

        # logits: [batch_size, classes_num, standard_height, standard_width]
        logits = self.cls_seg(aligned)

        _, _, original_height, original_width = image.shape
        logits = F.interpolate(
            input=logits,
            size=(original_height, original_width),
            mode="bilinear",
            align_corners=False
        )

        # [B, classes, H, W] -> [B, classes, H * W] -> [B, H * W, classes]
        flat = logits.view(batch_size, -1, original_height * original_width)
        return torch.transpose(flat, -2, -1)
|
|
|
|
|
|
class SegNeXt(nn.Module):
    """Full model: ``MSCAN`` encoder followed by a ``LightHamHead`` decoder.

    Args:
        embed_dims: forwarded to the encoder; its last three entries are the
            decoder's ``in_channels_list``.
        expand_rations: forwarded to the encoder as ``expand_ratios``. The
            keyword keeps its existing spelling so current callers still work.
        depths: forwarded to the encoder.
        drop_prob_of_encoder: encoder ``drop_prob``.
        drop_path_prob: encoder ``drop_path_prob``.
        hidden_channels: decoder ``hidden_channels``.
        out_channels: decoder ``out_channels``.
        classes_num: decoder ``classes_num``.
        drop_prob_of_decoder: decoder ``drop_prob``.
        nmf2d_config: JSON string forwarded to the decoder. Note that this
            default sets ``"RAND_INIT"`` to false, unlike the defaults of
            ``Hamburger`` and ``LightHamHead``.
    """

    def __init__(
            self,
            embed_dims=[3, 32, 64, 160, 256],
            expand_rations=[8, 8, 4, 4],
            depths=[3, 3, 5, 2],
            drop_prob_of_encoder=0.1,
            drop_path_prob=0.1,
            hidden_channels=256,
            out_channels=256,
            classes_num=150,
            drop_prob_of_decoder=0.1,
            nmf2d_config=json.dumps(
                {
                    "SPATIAL": True,
                    "MD_S": 1,
                    "MD_D": 512,
                    "MD_R": 64,
                    "TRAIN_STEPS": 6,
                    "EVAL_STEPS": 7,
                    "INV_T": 1,
                    "ETA": 0.9,
                    "RAND_INIT": False,
                    "return_bases": False,
                    "device": "cuda"
                }
            )
    ):
        super().__init__()

        encoder_kwargs = dict(
            embed_dims=embed_dims,
            expand_ratios=expand_rations,
            depths=depths,
            drop_prob=drop_prob_of_encoder,
            drop_path_prob=drop_path_prob
        )
        self.encoder = MSCAN(**encoder_kwargs)

        decoder_kwargs = dict(
            # The decoder consumes the three deepest encoder widths.
            in_channels_list=embed_dims[-3:],
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            classes_num=classes_num,
            drop_prob=drop_prob_of_decoder,
            nmf2d_config=nmf2d_config
        )
        self.decoder = LightHamHead(**decoder_kwargs)

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for the encoder features."""
        features = self.encoder(x)
        return self.decoder(features)