SegNeXt/bricks.py
2023-04-07 22:29:25 +08:00

310 lines
9.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
class DropPath(nn.Module):
    """Stochastic depth: randomly drop whole samples and rescale survivors.

    With probability ``drop_prob`` an entire sample in the batch is zeroed;
    surviving samples are scaled by ``1 / keep_prob`` so the expectation is
    unchanged. Acts as the identity in eval mode or when ``drop_prob == 0``.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity when dropping is disabled or the module is in eval mode.
        if self.drop_prob == 0. or not self.training:
            return x
        survival = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = survival + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        mask.floor_()  # binarize: 1 with prob keep_prob, else 0
        return x.div(survival) * mask
"""
逐层卷积
"""
class DepthwiseConv(nn.Module):
"""
in_channels: 输入通道数
out_channels: 输出通道数
kernel_size: 卷积核大小,元组类型
padding: 补充
stride: 步长
"""
def __init__(self, in_channels, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1), bias=False):
super(DepthwiseConv, self).__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
stride=stride,
groups=in_channels,
bias=bias
)
def forward(self, x):
out = self.conv(x)
return out
"""
逐点卷积
"""
class PointwiseConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(PointwiseConv, self).__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0)
)
def forward(self, x):
out = self.conv(x)
return out
"""
深度可分离卷积
"""
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1)):
super(DepthwiseSeparableConv, self).__init__()
self.conv1 = DepthwiseConv(
in_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
stride=stride
)
self.conv2 = PointwiseConv(
in_channels=in_channels,
out_channels=out_channels
)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
return out
"""
下采样
[batch_size, in_channels, height, width] -> [batch_size, out_channels, height // stride, width // stride]
"""
class DownSampling(nn.Module):
"""
in_channels: 输入通道数
out_channels: 输出通道数
kernel_size: 卷积核大小
stride: 步长
norm_layer: 正则化层如果为None使用BatchNorm
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, norm_layer=None):
super(DownSampling, self).__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size[0] // 2, kernel_size[-1] // 2)
)
if norm_layer is None:
self.norm = nn.BatchNorm2d(num_features=out_channels)
else:
self.norm = norm_layer
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
return out
class _MatrixDecomposition2DBase(nn.Module):
def __init__(
self,
args=json.dumps(
{
"SPATIAL": True,
"MD_S": 1,
"MD_D": 512,
"MD_R": 64,
"TRAIN_STEPS": 6,
"EVAL_STEPS": 7,
"INV_T": 100,
"ETA": 0.9,
"RAND_INIT": True,
"return_bases": False,
"device": "cuda"
}
)
):
super(_MatrixDecomposition2DBase, self).__init__()
args: dict = json.loads(args)
for k, v in args.items():
setattr(self, k, v)
@abstractmethod
def _build_bases(self, batch_size):
pass
@abstractmethod
def local_step(self, x, bases, coef):
pass
@torch.no_grad()
def local_inference(self, x, bases):
# (batch_size * MD_S, MD_D, N)^T @ (batch_size * MD_S, MD_D, MD_R) -> (batchszie * MD_S, N, MD_R)
coef = torch.bmm(x.transpose(1, 2), bases)
coef = F.softmax(self.INV_T * coef, dim=-1)
steps = self.TRAIN_STEPS if self.training else self.EVAL_STEPS
for _ in range(steps):
bases, coef = self.local_step(x, bases, coef)
return bases, coef
@abstractmethod
def compute_coef(self, x, bases, coef):
pass
def forward(self, x):
batch_size, channels, height, width = x.shape
# (batch_size, channels, height, width) -> (batch_size * MD_S, MD_D, N)
if self.SPATIAL:
self.MD_D = channels // self.MD_S
N = height * width
x = x.view(batch_size * self.MD_S, self.MD_D, N)
else:
self.MD_D = height * width
N = channels // self.MD_S
x = x.view(batch_size * self.MD_S, N, self.MD_D).transpose(1, 2)
if not self.RAND_INIT and not hasattr(self, 'bases'):
bases = self._build_bases(1)
self.register_buffer('bases', bases)
# (MD_S, MD_D, MD_R) -> (batch_size * MD_S, MD_D, MD_R)
if self.RAND_INIT:
bases = self._build_bases(batch_size)
else:
bases = self.bases.repeat(batch_size, 1, 1)
bases, coef = self.local_inference(x, bases)
# (batch_size * MD_S, N, MD_R)
coef = self.compute_coef(x, bases, coef)
# (batch_size * MD_S, MD_D, MD_R) @ (batch_size * MD_S, N, MD_R)^T -> (batch_size * MD_S, MD_D, N)
x = torch.bmm(bases, coef.transpose(1, 2))
# (batch_size * MD_S, MD_D, N) -> (batch_size, channels, height, width)
if self.SPATIAL:
x = x.view(batch_size, channels, height, width)
else:
x = x.transpose(1, 2).view(batch_size, channels, height, width)
# (batch_size * height, MD_D, MD_R) -> (batch_size, height, N, MD_D)
bases = bases.view(batch_size, self.MD_S, self.MD_D, self.MD_R)
if self.return_bases:
return x, bases
return x
class NMF2D(_MatrixDecomposition2DBase):
    """Non-negative matrix factorization X ~ bases @ coef^T, solved with
    multiplicative updates (random non-negative init keeps all factors
    non-negative throughout).
    """

    def __init__(
        self,
        args=json.dumps(
            {
                "SPATIAL": True,
                "MD_S": 1,
                "MD_D": 512,
                "MD_R": 64,
                "TRAIN_STEPS": 6,
                "EVAL_STEPS": 7,
                "INV_T": 1,
                "ETA": 0.9,
                "RAND_INIT": True,
                "return_bases": False,
                "device": "cuda"
            }
        )
    ):
        super(NMF2D, self).__init__(args)

    def _build_bases(self, batch_size):
        # Uniform [0, 1) bases, L2-normalized along the feature dimension.
        bases = torch.rand((batch_size * self.MD_S, self.MD_D, self.MD_R)).to(self.device)
        return F.normalize(bases, dim=1)

    def local_step(self, x, bases, coef):
        """One multiplicative-update step for both `coef` and `bases`."""
        # coef <- coef * (X^T B) / (coef (B^T B))
        # (B*MD_S, MD_D, N)^T @ (B*MD_S, MD_D, MD_R) -> (B*MD_S, N, MD_R)
        coef_num = torch.bmm(x.transpose(1, 2), bases)
        coef_den = coef.bmm(bases.transpose(1, 2).bmm(bases))
        coef = coef * coef_num / (coef_den + 1e-6)
        # bases <- bases * (X coef) / (bases (coef^T coef))
        # (B*MD_S, MD_D, N) @ (B*MD_S, N, MD_R) -> (B*MD_S, MD_D, MD_R)
        bases_num = torch.bmm(x, coef)
        bases_den = bases.bmm(coef.transpose(1, 2).bmm(coef))
        bases = bases * bases_num / (bases_den + 1e-6)
        return bases, coef

    def compute_coef(self, x, bases, coef):
        """Final coefficient refresh (same rule as local_step, coef only)."""
        # (B*MD_S, MD_D, N)^T @ (B*MD_S, MD_D, MD_R) -> (B*MD_S, N, MD_R)
        numerator = torch.bmm(x.transpose(1, 2), bases)
        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
        return coef * numerator / (denominator + 1e-6)
if __name__ == "__main__":
a = torch.ones(2, 3, 128, 128).to(device="cuda")
n = NMF2D(
json.dumps(
{
"SPATIAL": True,
"MD_S": 1,
"MD_D": 512,
"MD_R": 16,
"TRAIN_STEPS": 6,
"EVAL_STEPS": 7,
"INV_T": 1,
"ETA": 0.9,
"RAND_INIT": True,
"return_bases": False,
"device": "cuda"
}
)
)
print(n(a).shape)