# fcdm_diffae/encoder.py
"""FCDM DiffAE encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.
No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""
from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    mean: Clean-signal branch mu, shape [B, bottleneck_dim, h, w].
    logsnr: Per-element log signal-to-noise ratio, shape [B, bottleneck_dim, h, w].
    """

    mean: Tensor
    logsnr: Tensor

    @property
def alpha(self) -> Tensor:
"""VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32."""
logsnr_fp32 = self.logsnr.to(torch.float32)
return torch.sigmoid(logsnr_fp32).sqrt()

    @property
def sigma(self) -> Tensor:
"""VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32."""
logsnr_fp32 = self.logsnr.to(torch.float32)
return torch.sigmoid(-logsnr_fp32).sqrt()

    def mode(self) -> Tensor:
"""Posterior mode in token space: alpha * mean, computed in float32."""
return (self.alpha * self.mean.to(torch.float32)).to(dtype=self.mean.dtype)

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
"""Sample from posterior: alpha * mean + sigma * eps, computed in float32."""
mean_fp32 = self.mean.to(torch.float32)
eps = torch.randn(
mean_fp32.shape, device=mean_fp32.device, dtype=torch.float32,
generator=generator,
)
return (self.alpha * mean_fp32 + self.sigma * eps).to(dtype=self.mean.dtype)
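
# Example (a minimal sketch; shapes and values are illustrative, not from training):
# with logsnr == 0 everywhere, sigmoid(0) == 0.5, so alpha == sigma == sqrt(0.5)
# and alpha**2 + sigma**2 == 1 (the VP constraint).
#
#     post = EncoderPosterior(
#         mean=torch.zeros(1, 4, 8, 8), logsnr=torch.zeros(1, 4, 8, 8)
#     )
#     z_mode = post.mode()    # alpha * mean (== 0 here)
#     z_samp = post.sample()  # alpha * mean + sigma * eps, eps ~ N(0, I)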


class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With a diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    forward() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
self,
in_channels: int,
patch_size: int,
model_dim: int,
depth: int,
bottleneck_dim: int,
mlp_ratio: float,
depthwise_kernel_size: int,
bottleneck_posterior_kind: str = "diagonal_gaussian",
bottleneck_norm_mode: str = "disabled",
) -> None:
super().__init__()
self.bottleneck_dim = int(bottleneck_dim)
self.bottleneck_posterior_kind = bottleneck_posterior_kind
self.bottleneck_norm_mode = bottleneck_norm_mode
self.patchify = Patchify(in_channels, patch_size, model_dim)
self.blocks = nn.ModuleList(
[
FCDMBlock(
model_dim,
mlp_ratio,
depthwise_kernel_size=depthwise_kernel_size,
use_external_adaln=False,
)
for _ in range(depth)
]
)
out_dim = (
2 * bottleneck_dim
if bottleneck_posterior_kind == "diagonal_gaussian"
else bottleneck_dim
)
self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
if bottleneck_norm_mode == "channel_wise":
self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
else:
self.norm_out = nn.Identity()

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Only valid when bottleneck_posterior_kind == "diagonal_gaussian".
        """
        if self.bottleneck_posterior_kind != "diagonal_gaussian":
            raise ValueError(
                "encode_posterior requires a diagonal_gaussian posterior, "
                f"got {self.bottleneck_posterior_kind!r}"
            )
z = self.patchify(images)
for block in self.blocks:
z = block(z)
projection = self.to_bottleneck(z)
mean, logsnr = projection.chunk(2, dim=1)
mean = self.norm_out(mean)
return EncoderPosterior(mean=mean, logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns the posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
z = self.patchify(images)
for block in self.blocks:
z = block(z)
projection = self.to_bottleneck(z)
if self.bottleneck_posterior_kind == "diagonal_gaussian":
mean, logsnr = projection.chunk(2, dim=1)
mean = self.norm_out(mean)
logsnr_fp32 = logsnr.to(torch.float32)
alpha = torch.sigmoid(logsnr_fp32).sqrt()
return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
z = self.norm_out(projection)
return z
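

if __name__ == "__main__":
    # Minimal smoke-test sketch. Hyperparameters and spatial sizes below are
    # illustrative (not the training config) and assume Patchify downsamples
    # H and W by patch_size.
    encoder = Encoder(
        in_channels=3,
        patch_size=8,
        model_dim=256,
        depth=4,
        bottleneck_dim=16,
        mlp_ratio=4.0,
        depthwise_kernel_size=7,
        bottleneck_norm_mode="channel_wise",
    )
    images = torch.randn(2, 3, 64, 64)            # [B, 3, H, W]; training inputs are in [-1, 1]
    posterior = encoder.encode_posterior(images)  # mean/logsnr each [B, 16, H/8, W/8]
    latents = encoder(images)                     # posterior mode: alpha * RMSNorm(mean)
    print(posterior.mean.shape, posterior.logsnr.shape, latents.shape)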