"""FCDM DiffAE encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.
No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    mean: Clean signal branch mu [B, bottleneck_dim, h, w]
    logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w]
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32."""
        return self.logsnr.to(torch.float32).sigmoid().sqrt()

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32."""
        # sigmoid(-x) == neg().sigmoid(); kept in float32 for numerical stability.
        return self.logsnr.to(torch.float32).neg().sigmoid().sqrt()

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean, computed in float32."""
        mean_fp32 = self.mean.to(torch.float32)
        return (self.alpha * mean_fp32).to(dtype=self.mean.dtype)

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Sample from posterior: alpha * mean + sigma * eps, computed in float32."""
        mean_fp32 = self.mean.to(torch.float32)
        # Draw noise explicitly in float32 on the mean's device so mixed-precision
        # inputs still get a full-precision reparameterized sample.
        noise = torch.randn(
            mean_fp32.shape,
            device=mean_fp32.device,
            dtype=torch.float32,
            generator=generator,
        )
        return (self.alpha * mean_fp32 + self.sigma * noise).to(dtype=self.mean.dtype)
class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        """Build the patchify stem, FCDM trunk, and bottleneck projection.

        Args:
            in_channels: Input image channels (3 for RGB).
            patch_size: Spatial patch size for the patchify stem.
            model_dim: Trunk channel width.
            depth: Number of FCDMBlocks in the trunk.
            bottleneck_dim: Latent channel count.
            mlp_ratio: MLP expansion ratio inside each FCDMBlock.
            depthwise_kernel_size: Depthwise conv kernel size per block.
            bottleneck_posterior_kind: "diagonal_gaussian" doubles the
                projection width (mean + logsnr); any other value yields a
                deterministic bottleneck of bottleneck_dim channels.
            bottleneck_norm_mode: "channel_wise" applies a non-affine RMSNorm
                to the mean branch; anything else disables normalization.
        """
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    # Encoder blocks run without external AdaLN conditioning.
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        # Gaussian posterior needs two heads (mean, logsnr) from one conv.
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            # affine=False: pure normalization; scale is carried by alpha.
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def _project(self, images: Tensor) -> Tensor:
        """Shared trunk: patchify, run all FCDMBlocks, apply the 1x1 projection."""
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        return self.to_bottleneck(z)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Only valid when bottleneck_posterior_kind == "diagonal_gaussian".
        """
        # Channel-wise split: first half is mean, second half is logsnr.
        mean, logsnr = self._project(images).chunk(2, dim=1)
        return EncoderPosterior(mean=self.norm_out(mean), logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            # Delegate to the posterior path so the alpha * RMSNorm(mean)
            # math lives in exactly one place (EncoderPosterior.mode).
            return self.encode_posterior(images).mode()
        return self.norm_out(self._project(images))