# semdisdiffae/fcdm_diffae/fcdm_block.py
# (uploaded via huggingface_hub by data-archetype, commit 433bab6 verified)
"""FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from .norms import ChannelWiseRMSNorm
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXt V2 style) for NCHW tensors.

    Each channel's spatial L2 energy is divided by the mean energy across
    channels, and the input is re-weighted by that ratio through a learned
    residual branch: ``x + gamma * (x * n) + beta``.
    """

    def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = float(eps)
        shape = (1, int(channels), 1, 1)
        # Zero init makes the layer start as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(shape, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        # Per-channel spatial L2 norm: [B, C, 1, 1].
        energy = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        # Normalize by the cross-channel mean in fp32 for stability under AMP,
        # then cast the ratio back to the activation dtype.
        energy32 = energy.to(dtype=torch.float32)
        ratio = energy32 / (energy32.mean(dim=1, keepdim=True) + self.eps)
        n = ratio.to(dtype=x.dtype)
        # Parameters may live in fp32 while activations are half precision.
        g = self.gamma.to(device=x.device, dtype=x.dtype)
        b = self.beta.to(device=x.device, dtype=x.dtype)
        return g * (x * n) + b + x
class FCDMBlock(nn.Module):
    """ConvNeXt-style block with scale+gate AdaLN and GRN.

    Two modes:
    - Unconditioned (encoder): uses learned layer-scale for near-identity init.
    - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
      The gate is applied raw (no tanh).

    Args:
        channels: Number of input/output channels C.
        mlp_ratio: Hidden width multiplier for the pointwise MLP.
        depthwise_kernel_size: Kernel size of the depthwise conv (odd, so
            ``kernel_size // 2`` padding preserves spatial size).
        use_external_adaln: If True, ``forward`` requires an ``adaln_m``
            tensor and no layer-scale parameter is created.
        norm_eps: Epsilon for the channel-wise RMS norm.
        layer_scale_init: Initial value of the per-channel layer scale
            (unconditioned mode only); small value gives near-identity init.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
        layer_scale_init: float = 1e-3,
    ) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self.mlp_ratio: float = float(mlp_ratio)
        # Depthwise spatial mixing; padding keeps H and W unchanged.
        self.dwconv = nn.Conv2d(
            channels,
            channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            stride=1,
            groups=channels,
            bias=True,
        )
        # Non-affine norm: any scaling comes from AdaLN or layer-scale.
        self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
        hidden = max(int(float(channels) * float(mlp_ratio)), 1)
        self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
        self.grn = GRN(hidden, eps=1e-6)
        self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)
        if not use_external_adaln:
            self.layer_scale = nn.Parameter(
                torch.full((channels,), float(layer_scale_init))
            )
        else:
            # Gate comes from the external modulation instead.
            self.register_parameter("layer_scale", None)

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        """Apply the block to ``x`` [B, C, H, W].

        Args:
            x: Input feature map of shape [B, C, H, W].
            adaln_m: Optional packed modulation [B, 2*C] = (scale, gate).
                Required when the block was built with
                ``use_external_adaln=True``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``adaln_m`` is missing in external-AdaLN mode,
                or has a shape other than [B, 2*C].
        """
        b, c, _, _ = x.shape
        if adaln_m is not None:
            m = adaln_m.to(device=x.device, dtype=x.dtype)
            # Fail early with a clear message instead of a cryptic view() error.
            if m.shape != (b, 2 * c):
                raise ValueError(
                    f"adaln_m must have shape ({b}, {2 * c}); "
                    f"got {tuple(m.shape)}"
                )
            scale, gate = m.chunk(2, dim=-1)
        else:
            if self.layer_scale is None:
                # Previously this crashed later with an opaque AttributeError.
                raise ValueError(
                    "FCDMBlock was constructed with use_external_adaln=True; "
                    "forward() requires adaln_m"
                )
            scale = gate = None
        h = self.dwconv(x)
        h = self.norm(h)
        if scale is not None:
            # AdaLN scale in (1 + s) form so s == 0 is the identity.
            h = h * (1.0 + scale.view(b, c, 1, 1))
        h = self.pwconv1(h)
        h = F.gelu(h)
        h = self.grn(h)
        h = self.pwconv2(h)
        if gate is not None:
            # Raw gate (no tanh), per the class contract.
            gate_view = gate.view(b, c, 1, 1)
        else:
            gate_view = self.layer_scale.view(1, c, 1, 1).to(
                device=h.device, dtype=h.dtype
            )
        return x + gate_view * h