File size: 1,383 Bytes
94c2704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn


# -------------------------
# Timestep embeddings
# -------------------------

class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier features for a continuous time input ``t``.

    Each time value is mapped to ``embed_dim`` features laid out as
    ``[sin(W t), cos(W t)]``, where ``W`` holds ``embed_dim // 2`` fixed
    (non-trainable) frequencies sampled once as ``randn * scale``.

    The original note states ``t`` lies in [0, 1]; that range is not
    enforced here.

    Args:
        embed_dim: Total number of output features; must be even.
        scale: Multiplier applied to the standard-normal frequencies.
    """
    def __init__(self, embed_dim, scale):
        super().__init__()
        assert embed_dim % 2 == 0, "embed_dim must be even."
        self.embed_dim = embed_dim
        # Fixed random frequencies, stored as a buffer so they follow the
        # module across .to()/.cuda() calls without being trained.
        # NOTE(review): persistent=False keeps W out of state_dict, so a model
        # rebuilt from a checkpoint samples new frequencies unless the RNG
        # state is reproduced -- confirm this is intended before relying on
        # saved weights.
        self.register_buffer("W", torch.randn(embed_dim // 2) * scale, persistent=False)

    def forward(self, t):
        """
        Compute the Fourier features of ``t``.

        Args:
            t: Tensor of time values of any shape, or a Python scalar /
                sequence (converted to a tensor on ``W``'s device).

        Returns:
            Float tensor of shape ``t.shape + (embed_dim,)``: the sines of
            ``t * W`` followed by the cosines along the last dimension.
        """
        # Accept plain numbers/lists as well as tensors.
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, device=self.W.device)
        # Cast to float and add a trailing axis so t broadcasts against W.
        t = t.float().unsqueeze(-1)
        angles = t * self.W  # (..., embed_dim // 2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class TimeEmbedding(nn.Module):
    """
    Time embedding: Gaussian Fourier features followed by a small MLP.

    Args:
        hidden_dim: Output width of both linear layers in the MLP.
        fourier_dim: Number of Fourier features fed into the MLP; must be even.
        scale: Forwarded unchanged to ``GaussianFourierProjection``.
    """
    def __init__(self, hidden_dim, fourier_dim, scale):
        super().__init__()
        assert fourier_dim % 2 == 0, "fourier_dim must be even for sine/cosine pairs."

        # The projection is created before the linear layers, matching the
        # attribute registration order used by state_dict and parameter init.
        self.fourier = GaussianFourierProjection(fourier_dim, scale)

        # Linear -> SiLU -> Linear, registered positionally in a Sequential.
        layers = [
            nn.Linear(fourier_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Return the MLP applied to the Fourier features of ``t``."""
        features = self.fourier(t)      # last dim: fourier_dim
        embedding = self.mlp(features)  # last dim: hidden_dim
        return embedding