# Source: VibeVoice-Realtime-0.5B/vibevoice/modular/modular_vibevoice_diffusion_head.py
# Uploaded by akhaliq (HF Staff) in verified commit 26e0cd3 ("Upload 37 files"); file size 9.22 kB.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_utils import PreTrainedModel
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.activations import ACT2FN
from transformers.utils import logging
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
# Module-level logger from transformers' logging utilities; not referenced by the
# code visible in this section of the file.
logger = logging.get_logger(__name__)
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last dimension.

    Each vector along the last axis is multiplied by ``1 / sqrt(mean(x**2) + eps)``.
    The statistic is computed on a float32 copy of the input, and the result is
    cast back to the input dtype before the optional learned per-channel
    ``weight`` is applied.

    Args:
        dim (`int`): Size of the normalized (last) dimension.
        eps (`float`, optional): Added to the mean square for numerical stability.
        elementwise_affine (`bool`, optional): When truthy, register a learnable
            ``weight`` of shape ``(dim,)`` initialized to ones; otherwise ``weight`` is None.
        memory_efficient (`bool`, optional): Accepted for signature compatibility; unused.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # register_parameter accepts None, so `self.weight` exists in both cases.
        gain = nn.Parameter(torch.ones(dim)) if elementwise_affine else None
        self.register_parameter('weight', gain)

    def _norm(self, x):
        mean_square = x.square().mean(dim=-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
def modulate(x, shift, scale):
    """Apply adaptive modulation: scale ``x`` by ``(1 + scale)`` and add ``shift``.

    A zero ``shift`` and zero ``scale`` leave ``x`` unchanged.
    """
    gain = 1 + scale
    return x * gain + shift
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Timesteps are first mapped to a fixed sinusoidal feature vector of size
    `frequency_embedding_size`. That vector is then projected to `hidden_size`
    by a two-layer bias-free MLP (Linear -> SiLU -> Linear).

    Args:
        hidden_size (`int`): Size of the output embedding
        frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
                These may be fractional.
            dim (`int`): The dimension of the output.
            max_period (`int`, optional): Controls the minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, dim] Tensor of positional embeddings laid out as
            ``[cos(...), sin(...)]``, zero-padded by one column when `dim` is odd.
            The result has the dtype of `t` when `t` is floating point. Integer
            timesteps yield a float32 embedding.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        # Only cast back for floating-point timesteps. Casting sin/cos values to an
        # integer dtype would truncate them to -1/0/1, and the integer tensor could
        # not be consumed by the float Linear layers in `forward`.
        if t.is_floating_point():
            embedding = embedding.to(t.dtype)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
class FeedForwardNetwork(nn.Module):
    """
    Gated feed-forward block (SwiGLU): ``down(silu(gate(x)) * up(x))``.

    All three projections are bias-free linear layers.

    Args:
        embed_dim (`int`): Input and output dimension
        ffn_dim (`int`): Hidden (intermediate) dimension
    """

    def __init__(self, embed_dim, ffn_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Registration order (gate, up, down) matches the original so parameter
        # ordering and seeded initialization are unchanged.
        self.gate_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']

    def forward(self, x):
        activated = self.act_fn(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))
class HeadLayer(nn.Module):
    """
    One residual block of the diffusion head with adaptive modulation.

    The condition ``c`` is mapped (SiLU -> bias-free Linear) to ``3 * embed_dim``
    features, split along the last dimension into shift, scale and gate. The
    block returns ``x + gate * ffn(modulate(norm(x), shift, scale))``.

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
        cond_dim (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """

    def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        # Submodules are created in the original order (ffn, norm, adaLN_modulation)
        # so parameter ordering and seeded initialization are unchanged.
        self.ffn = FeedForwardNetwork(embed_dim, ffn_dim)
        self.norm = RMSNorm(embed_dim, eps=norm_eps)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_dim, 3 * embed_dim, bias=False),
        )

    def forward(self, x, c):
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        modulated = modulate(self.norm(x), shift, scale)
        return x + gate * self.ffn(modulated)
class FinalLayer(nn.Module):
    """
    Output projection of the diffusion head with adaptive (shift/scale) modulation.

    The condition ``c`` is mapped (SiLU -> bias-free Linear) to ``2 * hidden_size``
    features, split along the last dimension into shift and scale. The result is
    ``linear(modulate(norm_final(x), shift, scale))``, where ``norm_final`` is an
    RMSNorm without a learned weight.

    Args:
        hidden_size (`int`): Input dimension
        output_size (`int`): Output dimension
        cond_size (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """

    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        # Creation order is preserved (norm_final, linear, adaLN_modulation) so
        # parameter ordering and seeded initialization are unchanged.
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_size, 2 * hidden_size, bias=False),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
class VibeVoiceDiffusionHead(PreTrainedModel):
    """
    Diffusion head model for vibevoice.

    Pipeline:
    - Noisy latents are projected to ``config.hidden_size``.
    - The result passes through ``config.head_layers`` adaptively-modulated
      ``HeadLayer`` blocks. Each block is conditioned on
      ``cond_proj(condition) + t_embedder(timesteps)``.
    - ``FinalLayer`` maps the result back to ``config.latent_size``.

    Args:
        config (`VibeVoiceDiffusionHeadConfig`): Model configuration. The fields read
            here are `hidden_size`, `latent_size`, `head_ffn_ratio`, `head_layers`
            and `rms_norm_eps`.
    """
    config_class = VibeVoiceDiffusionHeadConfig
    supports_gradient_checkpointing = True
    # NOTE(review): this module contains no attention layers. These flags are
    # presumably kept so attention-implementation checks in an enclosing model
    # pass; confirm before removing.
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.config = config
        self.cond_dim = config.hidden_size
        latent_size = config.latent_size

        self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
        self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)

        ffn_dim = int(config.hidden_size * config.head_ffn_ratio)

        # Create the intermediate layers
        self.layers = nn.ModuleList([
            HeadLayer(
                embed_dim=config.hidden_size,
                ffn_dim=ffn_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps
            )
            for _ in range(config.head_layers)
        ])

        # Final layer for output
        self.final_layer = FinalLayer(
            hidden_size=config.hidden_size,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps
        )

        # Backs `supports_gradient_checkpointing = True`. transformers'
        # gradient_checkpointing_enable()/disable() look for modules carrying this
        # attribute, toggle it, and attach `_gradient_checkpointing_func`.
        # Off by default, so plain forward passes are unchanged.
        self.gradient_checkpointing = False

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the weights of the model.

        The timestep MLP gets N(0, 0.02) weights. Every adaLN modulation projection
        and the final output projection are zeroed. Because these Linear layers have
        no bias, each HeadLayer then starts as the identity (gate == 0) and the head
        initially outputs zeros. Other Linear layers keep PyTorch's default init.

        NOTE(review): recent transformers versions also define
        `PreTrainedModel.initialize_weights`; this override replaces it. Confirm it
        interacts correctly with `from_pretrained`.
        """
        # Initialize timestep embedder
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(
        self,
        noisy_images,
        timesteps,
        condition,
    ):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`): Noisy latents to denoise. The last dimension
                must equal `config.latent_size`.
            timesteps (`torch.Tensor`): 1-D diffusion timesteps, one per row.
            condition (`torch.Tensor`): Conditioning features. The last dimension must
                equal `config.hidden_size`.

        Returns:
            `torch.Tensor`: The head's prediction (e.g. noise or velocity, depending on
            the training objective), with last dimension `config.latent_size`.
        """
        # NOTE(review): shapes are presumably (batch, latent_size) and
        # (batch, hidden_size); confirm against callers.
        x = self.noisy_images_proj(noisy_images)
        t = self.t_embedder(timesteps)
        c = self.cond_proj(condition) + t

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Recompute each layer's activations during backward to save memory.
                x = self._gradient_checkpointing_func(layer.__call__, x, c)
            else:
                x = layer(x, c)

        return self.final_layer(x, c)
# Register the head with transformers' AutoModel, so the auto classes can map
# VibeVoiceDiffusionHeadConfig to this model class. Executes at import time.
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)
# Public API of this module.
__all__ = [
    "VibeVoiceDiffusionHead",
]