File size: 4,877 Bytes
9a54ff6
433bab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3196863
 
 
433bab6
 
 
3196863
 
 
433bab6
 
3196863
 
433bab6
 
3196863
 
 
 
 
 
 
433bab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3196863
 
 
433bab6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""FCDM DiffAE encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.

No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


@dataclass(frozen=True)
class EncoderPosterior:
    """Diagonal Gaussian posterior with variance-preserving (VP) parameters.

    Attributes:
        mean: Clean-signal branch mu, shape [B, bottleneck_dim, h, w].
        logsnr: Per-element log signal-to-noise ratio, same shape as mean.
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """Signal coefficient sqrt(sigmoid(logsnr)), evaluated in float32."""
        snr_fp32 = self.logsnr.to(torch.float32)
        return torch.sqrt(torch.sigmoid(snr_fp32))

    @property
    def sigma(self) -> Tensor:
        """Noise coefficient sqrt(sigmoid(-logsnr)), evaluated in float32."""
        snr_fp32 = self.logsnr.to(torch.float32)
        return torch.sqrt(torch.sigmoid(-snr_fp32))

    def mode(self) -> Tensor:
        """Most likely latent: alpha * mean (float32 math, mean's dtype out)."""
        scaled = self.alpha * self.mean.to(torch.float32)
        return scaled.to(dtype=self.mean.dtype)

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Draw alpha * mean + sigma * eps, eps ~ N(0, I), in float32."""
        mu = self.mean.to(torch.float32)
        noise = torch.randn(
            mu.shape, device=mu.device, dtype=torch.float32,
            generator=generator,
        )
        return (self.alpha * mu + self.sigma * noise).to(dtype=self.mean.dtype)


class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        """Build the patchify -> FCDMBlock stack -> bottleneck projection.

        Args:
            in_channels: Input image channels.
            patch_size: Spatial patch size used by Patchify.
            model_dim: Width of the FCDMBlock trunk.
            depth: Number of FCDMBlocks.
            bottleneck_dim: Latent channel count.
            mlp_ratio: MLP expansion ratio inside each FCDMBlock.
            depthwise_kernel_size: Depthwise conv kernel size per block.
            bottleneck_posterior_kind: "diagonal_gaussian" doubles the
                projection width to carry (mean, logsnr); any other value is
                treated as a deterministic bottleneck.
            bottleneck_norm_mode: "channel_wise" applies a non-affine RMSNorm
                to the mean branch; any other value disables it.
        """
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        # diagonal_gaussian needs twice the channels: [mean | logsnr].
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def _encode_trunk(self, images: Tensor) -> Tensor:
        """Shared trunk: patchify, run every FCDMBlock, project to bottleneck."""
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        return self.to_bottleneck(z)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Raises:
            ValueError: If bottleneck_posterior_kind != "diagonal_gaussian",
                since the projection then has no logsnr half to split off.
        """
        if self.bottleneck_posterior_kind != "diagonal_gaussian":
            raise ValueError(
                "encode_posterior requires bottleneck_posterior_kind="
                f"'diagonal_gaussian', got {self.bottleneck_posterior_kind!r}"
            )
        projection = self._encode_trunk(images)
        mean, logsnr = projection.chunk(2, dim=1)
        mean = self.norm_out(mean)
        return EncoderPosterior(mean=mean, logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        projection = self._encode_trunk(images)
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            mean, logsnr = projection.chunk(2, dim=1)
            mean = self.norm_out(mean)
            # VP mode: alpha = sqrt(sigmoid(logsnr)); fp32 for stability,
            # then cast back to the mean's dtype.
            alpha = torch.sigmoid(logsnr.to(torch.float32)).sqrt()
            return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
        return self.norm_out(projection)