"""FCDM DiffAE encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior. No input RMSNorm (use_other_outer_rms_norms=False during training). Post-bottleneck RMSNorm (affine=False) on the mean branch. Encoder outputs posterior mode by default: alpha * RMSNorm(mean). """ from __future__ import annotations from dataclasses import dataclass import torch from torch import Tensor, nn from .fcdm_block import FCDMBlock from .norms import ChannelWiseRMSNorm from .straight_through_encoder import Patchify @dataclass(frozen=True) class EncoderPosterior: """VP-parameterized diagonal Gaussian posterior. mean: Clean signal branch mu [B, bottleneck_dim, h, w] logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w] """ mean: Tensor logsnr: Tensor @property def alpha(self) -> Tensor: """VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32.""" logsnr_fp32 = self.logsnr.to(torch.float32) return torch.sigmoid(logsnr_fp32).sqrt() @property def sigma(self) -> Tensor: """VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32.""" logsnr_fp32 = self.logsnr.to(torch.float32) return torch.sigmoid(-logsnr_fp32).sqrt() def mode(self) -> Tensor: """Posterior mode in token space: alpha * mean, computed in float32.""" return (self.alpha * self.mean.to(torch.float32)).to(dtype=self.mean.dtype) def sample(self, *, generator: torch.Generator | None = None) -> Tensor: """Sample from posterior: alpha * mean + sigma * eps, computed in float32.""" mean_fp32 = self.mean.to(torch.float32) eps = torch.randn( mean_fp32.shape, device=mean_fp32.device, dtype=torch.float32, generator=generator, ) return (self.alpha * mean_fp32 + self.sigma * eps).to(dtype=self.mean.dtype) class Encoder(nn.Module): """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w]. With diagonal_gaussian posterior, the to_bottleneck projection outputs 2 * bottleneck_dim channels, split into mean and logsnr. The default encode() returns the posterior mode: alpha * RMSNorm(mean). """ def __init__( self, in_channels: int, patch_size: int, model_dim: int, depth: int, bottleneck_dim: int, mlp_ratio: float, depthwise_kernel_size: int, bottleneck_posterior_kind: str = "diagonal_gaussian", bottleneck_norm_mode: str = "disabled", ) -> None: super().__init__() self.bottleneck_dim = int(bottleneck_dim) self.bottleneck_posterior_kind = bottleneck_posterior_kind self.bottleneck_norm_mode = bottleneck_norm_mode self.patchify = Patchify(in_channels, patch_size, model_dim) self.blocks = nn.ModuleList( [ FCDMBlock( model_dim, mlp_ratio, depthwise_kernel_size=depthwise_kernel_size, use_external_adaln=False, ) for _ in range(depth) ] ) out_dim = ( 2 * bottleneck_dim if bottleneck_posterior_kind == "diagonal_gaussian" else bottleneck_dim ) self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True) if bottleneck_norm_mode == "channel_wise": self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False) else: self.norm_out = nn.Identity() def encode_posterior(self, images: Tensor) -> EncoderPosterior: """Encode images and return the full posterior (mean + logsnr). Only valid when bottleneck_posterior_kind == "diagonal_gaussian". """ z = self.patchify(images) for block in self.blocks: z = block(z) projection = self.to_bottleneck(z) mean, logsnr = projection.chunk(2, dim=1) mean = self.norm_out(mean) return EncoderPosterior(mean=mean, logsnr=logsnr) def forward(self, images: Tensor) -> Tensor: """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]. Returns posterior mode (alpha * mean) for diagonal_gaussian, or deterministic latents otherwise. """ z = self.patchify(images) for block in self.blocks: z = block(z) projection = self.to_bottleneck(z) if self.bottleneck_posterior_kind == "diagonal_gaussian": mean, logsnr = projection.chunk(2, dim=1) mean = self.norm_out(mean) logsnr_fp32 = logsnr.to(torch.float32) alpha = torch.sigmoid(logsnr_fp32).sqrt() return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype) z = self.norm_out(projection) return z