"""Frozen model architecture and user-tunable inference configuration."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from pathlib import Path

# Accepted values for the enumerated string fields of FCDMDiffAEConfig.
# The properties on that class branch on these strings, so an unrecognized
# value (e.g. a typo such as "patch2x2") would otherwise silently behave like
# the default and produce wrong latent shapes.
_POSTERIOR_KINDS = ("diagonal_gaussian", "deterministic")
_NORM_MODES = ("channel_wise", "disabled")
_PATCHIFY_MODES = ("off", "patch_2x2")


def _check_choice(name: str, value: str, allowed: tuple[str, ...]) -> None:
    """Raise ``ValueError`` unless *value* is one of *allowed*."""
    if value not in allowed:
        raise ValueError(f"{name} must be one of {allowed!r}, got {value!r}")


@dataclass(frozen=True)
class FCDMDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json.

    Construction validates the enumerated string fields and the ordering of
    the log-SNR endpoints, raising ``ValueError`` on invalid values.
    """

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 8
    decoder_start_blocks: int = 2
    decoder_end_blocks: int = 2
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # Encoder posterior kind: "diagonal_gaussian" or "deterministic"
    bottleneck_posterior_kind: str = "diagonal_gaussian"
    # Post-bottleneck normalization: "channel_wise" or "disabled"
    bottleneck_norm_mode: str = "disabled"
    # Bottleneck patchification: "off" or "patch_2x2"
    # When "patch_2x2", encoder latents are 2x2 patchified after the bottleneck
    # (channels * 4, spatial / 2), and decode unpatchifies before the decoder.
    bottleneck_patchify_mode: str = "off"
    # VP diffusion schedule endpoints (must satisfy logsnr_min < logsnr_max)
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558

    def __post_init__(self) -> None:
        """Validate enumerated string fields and the log-SNR endpoints.

        Raises:
            ValueError: If a mode/kind string is not a recognized value, or
                if ``logsnr_min`` is not strictly less than ``logsnr_max``.
        """
        _check_choice(
            "bottleneck_posterior_kind", self.bottleneck_posterior_kind, _POSTERIOR_KINDS
        )
        _check_choice("bottleneck_norm_mode", self.bottleneck_norm_mode, _NORM_MODES)
        _check_choice(
            "bottleneck_patchify_mode", self.bottleneck_patchify_mode, _PATCHIFY_MODES
        )
        # Written as `not (a < b)` so NaN endpoints are rejected as well.
        if not self.logsnr_min < self.logsnr_max:
            raise ValueError(
                f"logsnr_min ({self.logsnr_min}) must be less than "
                f"logsnr_max ({self.logsnr_max})"
            )

    @property
    def latent_channels(self) -> int:
        """Channel width of the exported latent space."""
        if self.bottleneck_patchify_mode == "patch_2x2":
            return self.bottleneck_dim * 4
        return self.bottleneck_dim

    @property
    def effective_patch_size(self) -> int:
        """Effective spatial stride from image to latent grid."""
        if self.bottleneck_patchify_mode == "patch_2x2":
            return self.patch_size * 2
        return self.patch_size

    def save(self, path: str | Path) -> None:
        """Save config as indented JSON, creating parent directories as needed.

        Args:
            path: Destination file path.
        """
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        # Explicit encoding so the written file does not depend on the locale.
        p.write_text(json.dumps(asdict(self), indent=2) + "\n", encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> FCDMDiffAEConfig:
        """Load config from a JSON file such as one written by :meth:`save`.

        Keys absent from the file fall back to the field defaults.

        Args:
            path: Source file path.

        Raises:
            TypeError: If the file contains a key that is not a config field.
            ValueError: If a field value fails validation.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(**data)


@dataclass
class FCDMDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG (Path-Drop Guidance) sharpens reconstructions by degrading conditioning
    in one pass and amplifying the difference. When enabled, uses 2 NFE per step.
    Recommended: ``pdg=True, pdg_strength=2.0, num_steps=10``.

    ``num_steps`` is validated at construction time only; later mutation of
    this (non-frozen) dataclass is not re-checked.
    """

    num_steps: int = 1  # number of denoising steps (NFE)
    sampler: str = "ddim"  # "ddim" or "dpmpp_2m"
    schedule: str = "linear"  # "linear" or "cosine"
    pdg: bool = False  # enable PDG for perceptual sharpening
    pdg_strength: float = 2.0  # CFG-like strength when pdg=True
    seed: int | None = None

    def __post_init__(self) -> None:
        """Reject a non-positive number of denoising steps.

        Raises:
            ValueError: If ``num_steps`` is less than 1.
        """
        if self.num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {self.num_steps}")