| """FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN.""" |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
|
|
| from .norms import ChannelWiseRMSNorm |
|
|
|
|
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt-V2 style) for NCHW tensors.

    With gamma and beta zero-initialized the layer is the identity at init:
    output = gamma * (x * n) + beta + x.
    """

    def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = float(eps)
        shape = (1, int(channels), 1, 1)
        # Zero init keeps the layer a no-op until training moves gamma/beta.
        self.gamma = nn.Parameter(torch.zeros(shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(shape, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        # Per-channel spatial L2 norm, promoted to fp32 for numerical stability.
        norms = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        norms = norms.to(dtype=torch.float32)
        # Divisive normalization: each channel's norm relative to the mean norm.
        rel = norms / (norms.mean(dim=1, keepdim=True) + self.eps)
        n = rel.to(dtype=x.dtype)
        weight = self.gamma.to(device=x.device, dtype=x.dtype)
        bias = self.beta.to(device=x.device, dtype=x.dtype)
        return weight * (x * n) + bias + x
|
|
|
|
| class FCDMBlock(nn.Module): |
| """ConvNeXt-style block with scale+gate AdaLN and GRN. |
| |
| Two modes: |
| - Unconditioned (encoder): uses learned layer-scale for near-identity init. |
| - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate). |
| The gate is applied raw (no tanh). |
| """ |
|
|
| def __init__( |
| self, |
| channels: int, |
| mlp_ratio: float, |
| *, |
| depthwise_kernel_size: int = 7, |
| use_external_adaln: bool = False, |
| norm_eps: float = 1e-6, |
| layer_scale_init: float = 1e-3, |
| ) -> None: |
| super().__init__() |
| self.channels: int = int(channels) |
| self.mlp_ratio: float = float(mlp_ratio) |
|
|
| self.dwconv = nn.Conv2d( |
| channels, |
| channels, |
| kernel_size=depthwise_kernel_size, |
| padding=depthwise_kernel_size // 2, |
| stride=1, |
| groups=channels, |
| bias=True, |
| ) |
| self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False) |
| hidden = max(int(float(channels) * float(mlp_ratio)), 1) |
| self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True) |
| self.grn = GRN(hidden, eps=1e-6) |
| self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True) |
|
|
| if not use_external_adaln: |
| self.layer_scale = nn.Parameter( |
| torch.full((channels,), float(layer_scale_init)) |
| ) |
| else: |
| self.register_parameter("layer_scale", None) |
|
|
| def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor: |
| b, c, _, _ = x.shape |
|
|
| if adaln_m is not None: |
| m = adaln_m.to(device=x.device, dtype=x.dtype) |
| scale, gate = m.chunk(2, dim=-1) |
| else: |
| scale = gate = None |
|
|
| h = self.dwconv(x) |
| h = self.norm(h) |
|
|
| if scale is not None: |
| h = h * (1.0 + scale.view(b, c, 1, 1)) |
|
|
| h = self.pwconv1(h) |
| h = F.gelu(h) |
| h = self.grn(h) |
| h = self.pwconv2(h) |
|
|
| if gate is not None: |
| gate_view = gate.view(b, c, 1, 1) |
| else: |
| gate_view = self.layer_scale.view(1, c, 1, 1).to( |
| device=h.device, dtype=h.dtype |
| ) |
|
|
| return x + gate_view * h |
|
|