# Uploaded by data-archetype via huggingface_hub ("Upload folder using huggingface_hub").
# Commit: 433bab6 (verified)
"""Scale+Gate AdaLN (2-way) for FCDM decoder blocks."""
from __future__ import annotations
from torch import Tensor, nn
class AdaLNScaleGateZeroProjector(nn.Module):
    """Two-way AdaLN modulation head: SiLU followed by a zero-initialized Linear.

    The last axis of the output has ``2 * d_model`` features, packed as
    ``(scale, gate)``. For a ``[B, d_cond]`` input, the result is
    ``[B, 2*d_model]``. Both the weight and the bias start at zero, so the
    module emits all zeros at initialization.

    Args:
        d_model: Width of each packed chunk (scale and gate).
        d_cond: Width of the conditioning vector.
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        packed_width = 2 * self.d_model
        self.act: nn.SiLU = nn.SiLU()
        self.proj: nn.Linear = nn.Linear(self.d_cond, packed_width)
        # Zero weight and bias alike so the packed output is exactly 0 at init.
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Project a conditioning tensor that has already been passed through SiLU.

        ``self.act`` is skipped here. Presumably this lets a caller activate the
        conditioning once and share it across several heads (verify against callers).
        """
        packed = self.proj(act_cond)
        return packed

    def forward(self, cond: Tensor) -> Tensor:
        """Apply SiLU to ``cond`` and return the packed ``(scale, gate)`` modulation."""
        activated = self.act(cond)
        return self.forward_activated(activated)
class AdaLNScaleGateZeroLowRankDelta(nn.Module):
    """Bias-free low-rank correction for the packed 2-way AdaLN modulation.

    Computes ``up(down(x))``, where ``down`` maps ``d_cond -> rank`` and ``up``
    maps ``rank -> 2*d_model``. Neither layer has a bias. ``down`` is drawn from
    a normal distribution (mean 0, std 0.02) and ``up`` starts at zero. The
    output is therefore exactly zero at initialization, while ``up`` still
    receives nonzero gradients through the random ``down`` features.

    Keyword-only args:
        d_model: Width of each packed chunk (scale and gate).
        d_cond: Width of the conditioning vector.
        rank: Bottleneck width between ``down`` and ``up``.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model: int = int(d_model)
        self.d_cond: int = int(d_cond)
        self.rank: int = int(rank)
        packed_width = 2 * self.d_model
        # Build ``down`` before ``up``, and only then re-initialize, so RNG
        # consumption and state_dict key order are unchanged.
        self.down: nn.Linear = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up: nn.Linear = nn.Linear(self.rank, packed_width, bias=False)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Return the packed delta modulation (last dim ``2 * d_model``).

        ``act_cond`` is used as given; no activation is applied in this module.
        """
        bottleneck = self.down(act_cond)
        return self.up(bottleneck)