| """FCDM DiffAE encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior. |
| |
| No input RMSNorm (use_other_outer_rms_norms=False during training). |
| Post-bottleneck RMSNorm (affine=False) on the mean branch. |
| Encoder outputs posterior mode by default: alpha * RMSNorm(mean). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from .fcdm_block import FCDMBlock |
| from .norms import ChannelWiseRMSNorm |
| from .straight_through_encoder import Patchify |
|
|
|
|
@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    mean: Clean signal branch mu [B, bottleneck_dim, h, w]
    logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w]
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32."""
        return torch.sqrt(torch.sigmoid(self.logsnr.float()))

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32."""
        return torch.sqrt(torch.sigmoid(self.logsnr.float().neg()))

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean, computed in float32."""
        out_dtype = self.mean.dtype
        return self.mean.float().mul(self.alpha).to(dtype=out_dtype)

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Sample from posterior: alpha * mean + sigma * eps, computed in float32.

        Pass `generator` for reproducible draws; the result is cast back to
        the dtype of `mean`.
        """
        mu = self.mean.float()
        noise = torch.randn(
            mu.shape, device=mu.device, dtype=torch.float32, generator=generator
        )
        reparam = self.alpha * mu + self.sigma * noise
        return reparam.to(dtype=self.mean.dtype)
|
|
|
|
class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        """Build the encoder.

        Args:
            in_channels: Input image channels (typically 3).
            patch_size: Spatial patch size for Patchify.
            model_dim: Trunk channel width.
            depth: Number of FCDMBlocks in the trunk.
            bottleneck_dim: Latent channel width.
            mlp_ratio: MLP expansion ratio inside each FCDMBlock.
            depthwise_kernel_size: Depthwise conv kernel size per block.
            bottleneck_posterior_kind: "diagonal_gaussian" doubles the
                projection width (mean + logsnr); anything else keeps a
                deterministic bottleneck_dim-channel projection.
            bottleneck_norm_mode: "channel_wise" applies a non-affine RMSNorm
                to the mean branch; otherwise no normalization.
        """
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        # Gaussian posterior needs two heads (mean, logsnr) packed channel-wise.
        out_dim = (
            2 * self.bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else self.bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            self.norm_out = ChannelWiseRMSNorm(
                self.bottleneck_dim, eps=1e-6, affine=False
            )
        else:
            self.norm_out = nn.Identity()

    def _features(self, images: Tensor) -> Tensor:
        """Shared trunk: patchify the images and run every FCDMBlock."""
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        return z

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Raises:
            ValueError: If bottleneck_posterior_kind != "diagonal_gaussian"
                (the projection then has no logsnr half to split off).
        """
        if self.bottleneck_posterior_kind != "diagonal_gaussian":
            raise ValueError(
                "encode_posterior() requires bottleneck_posterior_kind="
                f"'diagonal_gaussian', got {self.bottleneck_posterior_kind!r}"
            )
        projection = self.to_bottleneck(self._features(images))
        mean, logsnr = projection.chunk(2, dim=1)
        return EncoderPosterior(mean=self.norm_out(mean), logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        projection = self.to_bottleneck(self._features(images))
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            mean, logsnr = projection.chunk(2, dim=1)
            mean = self.norm_out(mean)
            # Compute the VP mode alpha * mean in float32 for stability,
            # then cast back to the working dtype.
            alpha = torch.sigmoid(logsnr.to(torch.float32)).sqrt()
            return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
        return self.norm_out(projection)
|
|