File size: 3,376 Bytes
433bab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""FCDM block: ConvNeXt-style conv block with GRN and scale+gate AdaLN."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .norms import ChannelWiseRMSNorm


class GRN(nn.Module):
    """Global Response Normalization (GRN) over NCHW feature maps.

    Measures each channel's spatial L2 response, divides by the mean
    response across channels, and applies a learned affine re-scaling on
    top of a residual connection.  ``gamma`` and ``beta`` are
    zero-initialized, so the layer is an exact identity at init.
    """

    def __init__(self, channels: int, *, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps: float = float(eps)
        num_channels = int(channels)
        param_shape = (1, num_channels, 1, 1)
        # Zero init -> gamma * (x * n) vanishes and beta adds nothing,
        # leaving the residual path only.
        self.gamma = nn.Parameter(torch.zeros(param_shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(param_shape, dtype=torch.float32))

    def forward(self, x: Tensor) -> Tensor:
        # Per-channel spatial L2 response; the divide is done in fp32 for
        # numerical stability, then cast back to the input dtype.
        response = torch.linalg.vector_norm(x, ord=2, dim=(2, 3), keepdim=True)
        response = response.to(dtype=torch.float32)
        relative = response / (response.mean(dim=1, keepdim=True) + self.eps)
        relative = relative.to(dtype=x.dtype)
        scale = self.gamma.to(device=x.device, dtype=x.dtype)
        shift = self.beta.to(device=x.device, dtype=x.dtype)
        return scale * (x * relative) + shift + x


class FCDMBlock(nn.Module):
    """ConvNeXt-style block with scale+gate AdaLN and GRN.

    Two modes:
    - Unconditioned (encoder): uses learned layer-scale for near-identity init.
    - External AdaLN (decoder): receives packed [B, 2*C] modulation (scale, gate).
      The gate is applied raw (no tanh).
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
        layer_scale_init: float = 1e-3,
    ) -> None:
        super().__init__()
        self.channels: int = int(channels)
        self.mlp_ratio: float = float(mlp_ratio)

        # Depthwise spatial mixing (same-padding), then norm, then a
        # channel-wise 1x1 MLP with GRN in the middle.
        pad = depthwise_kernel_size // 2
        self.dwconv = nn.Conv2d(
            channels,
            channels,
            kernel_size=depthwise_kernel_size,
            padding=pad,
            stride=1,
            groups=channels,
            bias=True,
        )
        self.norm = ChannelWiseRMSNorm(channels, eps=float(norm_eps), affine=False)
        hidden = max(int(float(channels) * float(mlp_ratio)), 1)
        self.pwconv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
        self.grn = GRN(hidden, eps=1e-6)
        self.pwconv2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=True)

        if use_external_adaln:
            # Gate arrives via adaln_m at forward time; no learned scale.
            self.register_parameter("layer_scale", None)
        else:
            init = torch.full((channels,), float(layer_scale_init))
            self.layer_scale = nn.Parameter(init)

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        batch, ch = x.shape[0], x.shape[1]

        # Unpack the external modulation, if provided, into per-sample
        # per-channel (scale, gate) halves along the last dim.
        scale: Tensor | None = None
        gate: Tensor | None = None
        if adaln_m is not None:
            packed = adaln_m.to(device=x.device, dtype=x.dtype)
            scale, gate = torch.chunk(packed, 2, dim=-1)

        y = self.norm(self.dwconv(x))
        if scale is not None:
            # AdaLN scale modulation around identity (1 + scale).
            y = y * (1.0 + scale.view(batch, ch, 1, 1))

        y = self.pwconv2(self.grn(F.gelu(self.pwconv1(y))))

        if gate is None:
            # Unconditioned mode: learned per-channel layer-scale on the
            # residual branch for a near-identity start.
            residual_scale = self.layer_scale.view(1, ch, 1, 1).to(  # type: ignore[union-attr]
                device=y.device, dtype=y.dtype
            )
        else:
            # External mode: raw gate (no tanh), per-sample per-channel.
            residual_scale = gate.view(batch, ch, 1, 1)

        return x + residual_scale * y