# Lekr0's picture
# Add files using upload-large-folder tool
# 62dca4c verified
# coding=utf-8
"""DFlash Training Wrapper."""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint
from specforge.modeling.draft.dflash import DFlashDraftModel
try:
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
FLEX_ATTENTION_AVAILABLE = True
except ImportError:
FLEX_ATTENTION_AVAILABLE = False
BlockMask = None
create_block_mask = None
class OnlineDFlashModel(nn.Module):
    """DFlash online training wrapper with block-wise CE loss.

    The training sequence is split into blocks of ``block_size`` tokens. Inside
    each block only the first token is fed to the draft model as a real token;
    the remaining positions are replaced by ``mask_token_id`` and predicted in
    parallel. A query in block ``k`` attends to the target hidden states
    ("context" keys) of strictly earlier blocks and to the noise embeddings of
    its own block only.

    Blocks are either contiguous (``random_anchor=False`` or eval mode) or
    gathered around randomly sampled anchor positions (``random_anchor=True``
    in training mode).

    Args:
        draft_model: DFlash draft network, called with ``position_ids``,
            ``noise_embedding``, ``target_hidden`` and ``attention_mask``.
        target_lm_head: Module mapping draft hidden states to token logits.
        target_embed_tokens: Module embedding the noise input ids.
        mask_token_id: Token id placed at the to-be-predicted positions.
        block_size: Number of tokens per block.
        attention_backend: ``"flex_attention"`` builds a flex-attention
            ``BlockMask`` when flex attention is importable; any other value
            builds a dense additive mask.
        random_anchor: Sample block anchors at random during training.
        num_anchors: Maximum number of anchors (blocks) per sample.
        loss_decay_gamma: If set, a token at in-block offset ``p`` is weighted
            by ``exp(-(p - 1) / gamma)``.
        lm_head_chunk_size: If > 0 and smaller than the effective length, the
            lm_head + loss are computed chunk by chunk with checkpointing.
    """

    def __init__(
        self,
        draft_model: "DFlashDraftModel",
        target_lm_head: nn.Module,
        target_embed_tokens: nn.Module,
        mask_token_id: int,
        block_size: int = 16,
        attention_backend: str = "flex_attention",
        random_anchor: bool = False,
        num_anchors: int = 512,
        loss_decay_gamma: Optional[float] = None,
        lm_head_chunk_size: int = 0,
    ):
        super().__init__()
        self.draft_model = draft_model
        self.lm_head = target_lm_head
        self.embed_tokens = target_embed_tokens
        self.block_size = block_size
        self.mask_token_id = mask_token_id
        self.attention_backend = attention_backend
        self.random_anchor = random_anchor
        self.num_anchors = num_anchors
        self.loss_decay_gamma = loss_decay_gamma
        self.lm_head_chunk_size = lm_head_chunk_size
        # Cache for the contiguous-block flex-attention mask, keyed on every
        # argument that mask depends on.
        self._cached_block_mask: Optional["BlockMask"] = None
        self._cached_seq_len: Optional[int] = None
        self._cached_bsz: Optional[int] = None
        self._cached_kv_len: Optional[int] = None
        self._cached_device: Optional[torch.device] = None

    def _sample_anchor_positions(
        self, seq_len: int, loss_mask: torch.Tensor, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Randomly sample block anchor positions per sample.

        An anchor is a start position ``a`` with ``a + block_size <= seq_len``
        and ``loss_mask[:, a] > 0.5``. Each sample draws up to ``num_anchors``
        distinct anchors, sorted ascending. Rows with fewer anchors than the
        batch maximum are padded by repeating their last anchor (or zeros if
        they have none), and padded entries are False in the keep mask.

        Args:
            seq_len: Sequence length (dim 1 of the tensors being gathered).
            loss_mask: Per-token loss mask, ``[bsz, seq_len]``.
            device: Device for the returned tensors.

        Returns:
            ``(anchors, keep_mask)``: long and bool tensors of equal shape
            ``[bsz, n_anchors]``.

        Raises:
            ValueError: If ``seq_len < block_size``; no full block fits.
        """
        bs = self.block_size
        bsz = loss_mask.shape[0]
        if seq_len < bs:
            raise ValueError(
                f"seq_len ({seq_len}) is shorter than block_size ({bs}); "
                "cannot build any full block"
            )
        # Largest start position whose block [a, a + bs) still fits.
        max_anchor = seq_len - bs
        valid = loss_mask[:, : max_anchor + 1] > 0.5
        valid_counts = valid.sum(dim=1)
        max_n = min(self.num_anchors, int(valid_counts.max().item()))
        if max_n == 0:
            # No eligible anchor anywhere in the batch: fall back to contiguous,
            # non-overlapping blocks. Stopping at max_anchor keeps every
            # gathered index < seq_len even when seq_len is not a multiple of
            # block_size.
            anchors = torch.arange(0, max_anchor + 1, bs, device=device)
            anchors = anchors.unsqueeze(0).expand(bsz, -1)
            return anchors, torch.ones(
                bsz, anchors.shape[1], dtype=torch.bool, device=device
            )
        anchor_list = []
        keep_list = []
        for i in range(bsz):
            valid_indices = valid[i].nonzero(as_tuple=False).squeeze(-1)
            n_i = min(self.num_anchors, valid_indices.numel())
            if n_i == 0:
                # Sample without any eligible anchor: an all-padding row.
                anchors_i = torch.zeros(max_n, dtype=torch.long, device=device)
                keep_i = torch.zeros(max_n, dtype=torch.bool, device=device)
            else:
                perm = torch.randperm(valid_indices.numel(), device=device)[:n_i]
                anchors_i = valid_indices[perm].sort().values
                if n_i < max_n:
                    # Pad to the batch-wide width by repeating the last anchor.
                    anchors_i = torch.cat(
                        [anchors_i, anchors_i[-1:].expand(max_n - n_i)], dim=0
                    )
                keep_i = torch.zeros(max_n, dtype=torch.bool, device=device)
                keep_i[:n_i] = True
            anchor_list.append(anchors_i)
            keep_list.append(keep_i)
        return torch.stack(anchor_list, dim=0), torch.stack(keep_list, dim=0)

    def _build_blocks_from_anchors(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
        anchor_positions: torch.Tensor,
        block_keep_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Gather fixed-size blocks starting at each anchor.

        Anchor ``a`` yields the token span ``[a, a + block_size)``. The spans
        of a sample are concatenated along dim 1, giving
        ``n_anchors * block_size`` tokens per sample. Padding blocks
        (``block_keep_mask`` False) get ``block_id = -1`` and a zero loss mask.

        Returns:
            ``(block_input_ids, block_hidden, block_loss_mask, block_ids,
            gather_idx)``, where ``gather_idx`` holds the original position of
            every gathered token.
        """
        bs = self.block_size
        device = input_ids.device
        bsz = input_ids.shape[0]
        n = anchor_positions.shape[1]
        offsets = torch.arange(bs, device=device).unsqueeze(0)
        gather_idx = anchor_positions.unsqueeze(-1) + offsets
        gather_idx = gather_idx.reshape(bsz, -1)
        block_input_ids = torch.gather(input_ids, 1, gather_idx)
        block_hidden = torch.gather(
            hidden_states,
            1,
            gather_idx.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
        )
        block_loss_mask = torch.gather(loss_mask, 1, gather_idx)
        # Zero the loss on every token of a padding block.
        token_keep = block_keep_mask.repeat_interleave(bs, dim=1)
        block_loss_mask = block_loss_mask * token_keep.to(block_loss_mask.dtype)
        block_ids = torch.arange(n, device=device).repeat_interleave(bs)
        pad_token_mask = (~block_keep_mask).repeat_interleave(bs, dim=1)
        block_ids = block_ids.unsqueeze(0).expand(bsz, -1).clone()
        block_ids[pad_token_mask] = -1
        return block_input_ids, block_hidden, block_loss_mask, block_ids, gather_idx

    def prepare_noise_input(
        self, input_ids: torch.Tensor, block_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Prepare noise input: first token of each block is real, rest are MASK.

        With ``block_ids``, a block starts wherever the id changes along dim 1.
        Otherwise, blocks are contiguous runs of ``block_size`` positions.
        """
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        if block_ids is not None:
            is_block_start = torch.ones(bsz, seq_len, dtype=torch.bool, device=device)
            is_block_start[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1]
        else:
            positions = torch.arange(seq_len, device=device)
            is_block_start = (positions % self.block_size) == 0
            is_block_start = is_block_start.unsqueeze(0).expand(bsz, -1)
        noise_input_ids = torch.full_like(input_ids, self.mask_token_id)
        noise_input_ids[is_block_start] = input_ids[is_block_start]
        return noise_input_ids

    def _get_or_create_block_mask(
        self,
        bsz: int,
        q_len: int,
        kv_len: int,
        device: torch.device,
        block_ids: Optional[torch.Tensor] = None,
    ) -> "BlockMask":
        """Return a flex-attention BlockMask for the DFlash layout.

        Key indices below ``q_len`` are context positions. Indices at or above
        ``q_len`` are noise positions (``forward`` passes
        ``kv_len = 2 * q_len``). A query sees context keys of strictly earlier
        blocks and noise keys of its own block.

        With ``block_ids``, blocks come from per-token ids, and ``-1`` marks
        padding that neither sees nor is seen. Otherwise, the block of a
        position is ``position // block_size``.

        Only the ``block_ids is None`` mask is cached. It is reused only when
        ``bsz``, ``q_len``, ``kv_len`` and ``device`` all match.
        """
        if block_ids is None:
            if (
                self._cached_block_mask is not None
                and self._cached_seq_len == q_len
                and self._cached_bsz == bsz
                and self._cached_kv_len == kv_len
                and self._cached_device == device
            ):
                return self._cached_block_mask
        block_size = self.block_size
        if block_ids is not None:
            _block_ids = block_ids

            def dflash_mask_fn(b, h, q_idx, kv_idx):
                L = q_len
                is_ctx = kv_idx < L
                q_b = _block_ids[b, q_idx]
                # Clamp so both lookups stay in range; the unused one is masked
                # out by is_ctx below.
                k_ctx = _block_ids[b, kv_idx.clamp(max=L - 1)]
                k_noise = _block_ids[b, (kv_idx - L).clamp(min=0, max=L - 1)]
                q_valid = q_b >= 0
                k_ctx_valid = k_ctx >= 0
                k_noise_valid = k_noise >= 0
                ctx_visible = is_ctx & q_valid & k_ctx_valid & (k_ctx < q_b)
                noise_visible = (~is_ctx) & q_valid & k_noise_valid & (k_noise == q_b)
                return ctx_visible | noise_visible

        else:

            def dflash_mask_fn(b, h, q_idx, kv_idx):
                L = q_len
                is_ctx = kv_idx < L
                q_block = q_idx // block_size
                k_block_ctx = kv_idx // block_size
                k_block_noise = (kv_idx - L) // block_size
                ctx_visible = is_ctx & (k_block_ctx < q_block)
                noise_visible = (~is_ctx) & (k_block_noise == q_block)
                return ctx_visible | noise_visible

        block_mask = create_block_mask(
            dflash_mask_fn,
            B=bsz,
            H=1,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            device=device,
        )
        if block_ids is None:
            self._cached_block_mask = block_mask
            self._cached_seq_len = q_len
            self._cached_bsz = bsz
            self._cached_kv_len = kv_len
            self._cached_device = device
        return block_mask

    def _create_parallel_attention_mask(
        self,
        bsz: int,
        seq_len: int,
        device: torch.device,
        block_ids: Optional[torch.Tensor] = None,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Create a ``[bsz, L, 2L]`` additive attention mask for parallel training.

        The first ``L`` key columns are context positions, visible only from
        strictly later blocks. The last ``L`` columns are noise positions,
        visible only within the same block. Allowed entries are 0 and blocked
        entries are ``torch.finfo(dtype).min``. Building the mask directly in
        ``dtype`` keeps that value finite; casting float32's min to bf16/fp16
        overflows to ``-inf``.

        With ``block_ids``, tokens with a negative id (padding) neither see
        nor are seen.
        """
        min_value = torch.finfo(dtype).min
        if block_ids is None:
            ids = torch.arange(seq_len, device=device) // self.block_size
            q_ids = ids.unsqueeze(1)
            k_ids = ids.unsqueeze(0)
            ctx_mask = k_ids < q_ids
            noise_mask = q_ids == k_ids
            full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
            full_mask = torch.zeros_like(full_mask_bool, dtype=dtype)
            full_mask.masked_fill_(~full_mask_bool, min_value)
            # Same mask for every sample: broadcast view, no per-sample copy.
            return full_mask.unsqueeze(0).expand(bsz, -1, -1)
        q_ids = block_ids.unsqueeze(2)
        k_ids = block_ids.unsqueeze(1)
        q_valid = q_ids >= 0
        k_valid = k_ids >= 0
        ctx_mask = q_valid & k_valid & (k_ids < q_ids)
        noise_mask = q_valid & k_valid & (k_ids == q_ids)
        full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=2)
        full_mask = torch.zeros_like(full_mask_bool, dtype=dtype)
        full_mask.masked_fill_(~full_mask_bool, min_value)
        return full_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parallel block-wise training forward pass.

        Args:
            input_ids: Token ids, ``[bsz, seq_len]``.
            attention_mask: Accepted for interface compatibility. It is not
                read by this method.
            hidden_states: Target hidden states used as draft context,
                ``[bsz, seq_len, hidden]``.
            loss_mask: Per-token loss mask, ``[bsz, seq_len]``.

        Returns:
            ``(loss, accuracy)``, where accuracy is the block-wise acceptance
            metric from ``_compute_acceptance_accuracy``.
        """
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        block_ids = None
        if self.random_anchor and self.training:
            anchor_positions, block_keep_mask = self._sample_anchor_positions(
                seq_len, loss_mask, device
            )
            (input_ids, hidden_states, loss_mask, block_ids, block_positions) = (
                self._build_blocks_from_anchors(
                    input_ids,
                    hidden_states,
                    loss_mask,
                    anchor_positions,
                    block_keep_mask,
                )
            )
            effective_len = input_ids.shape[1]
            # Gathered tokens keep their original positions.
            base_positions = block_positions
        else:
            # Contiguous blocks: drop the trailing partial block.
            n_blocks = seq_len // self.block_size
            effective_len = n_blocks * self.block_size
            input_ids = input_ids[:, :effective_len]
            hidden_states = hidden_states[:, :effective_len, :]
            loss_mask = loss_mask[:, :effective_len]
            base_positions = (
                torch.arange(effective_len, device=device).unsqueeze(0).expand(bsz, -1)
            )
        noise_input_ids = self.prepare_noise_input(input_ids, block_ids)
        noise_embedding = self.embed_tokens(noise_input_ids)
        # Context half and noise half of the keys share the same position ids.
        position_ids = torch.cat([base_positions, base_positions], dim=1)
        if (
            self.attention_backend == "flex_attention"
            and FLEX_ATTENTION_AVAILABLE
            and create_block_mask is not None
        ):
            dflash_attn_mask = self._get_or_create_block_mask(
                bsz=bsz,
                q_len=effective_len,
                kv_len=effective_len * 2,
                device=device,
                block_ids=block_ids,
            )
        else:
            # Build the additive mask in the activation dtype so blocked
            # entries stay finite. A -inf fill can give NaN on fully-masked
            # (padding) rows in a softmax.
            dflash_attn_mask = self._create_parallel_attention_mask(
                bsz, effective_len, device, block_ids, dtype=hidden_states.dtype
            )
            dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
            dflash_attn_mask = dflash_attn_mask.unsqueeze(1)
        hidden = self.draft_model(
            position_ids=position_ids,
            noise_embedding=noise_embedding,
            target_hidden=hidden_states,
            attention_mask=dflash_attn_mask,
        )
        dflash_loss_weights = create_dflash_loss_mask(
            effective_len,
            self.block_size,
            device,
            gamma=self.loss_decay_gamma,
            block_ids=block_ids,
        )
        if block_ids is None:
            dflash_loss_weights = dflash_loss_weights.unsqueeze(0)
        combined_mask = loss_mask * dflash_loss_weights
        if self.lm_head_chunk_size > 0 and effective_len > self.lm_head_chunk_size:
            loss, accuracy = self._chunked_lm_loss(
                hidden, input_ids, loss_mask, combined_mask, effective_len, block_ids
            )
        else:
            loss, accuracy = self._full_lm_loss(
                hidden, input_ids, loss_mask, combined_mask, effective_len, block_ids
            )
        return loss, accuracy

    def _compute_acceptance_accuracy(
        self,
        preds_all: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        effective_len: int,
        block_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute the block-wise acceptance-rate metric.

        A block is scored only if all of its predicted positions (all but the
        first) are covered by ``loss_mask``. For each scored block, the
        acceptance length is the number of leading correct predictions. The
        metric is the mean acceptance length over scored blocks, divided by
        ``block_size - 1``. In the contiguous layout (``block_ids is None``),
        block 0 is skipped.

        Falls back to plain token accuracy over ``loss_mask > 0.5`` when the
        block view is unavailable: fewer than two contiguous blocks, or a
        length that is not a whole number of blocks. Always returns a tensor.
        """
        bsz = input_ids.shape[0]
        correct_all = (preds_all == input_ids).float()
        bs = self.block_size
        n_blocks = effective_len // bs
        whole_blocks = correct_all.shape[1] == n_blocks * bs
        if block_ids is not None and whole_blocks:
            correct_blocks = correct_all.reshape(bsz, n_blocks, bs)
            loss_mask_blocks = loss_mask.reshape(bsz, n_blocks, bs)
        elif block_ids is None and whole_blocks and n_blocks > 1:
            correct_blocks = correct_all[:, bs:].reshape(bsz, n_blocks - 1, bs)
            loss_mask_blocks = loss_mask[:, bs:].reshape(bsz, n_blocks - 1, bs)
        else:
            valid_mask = (loss_mask > 0.5).reshape(-1)
            correct_flat = correct_all.reshape(-1)[valid_mask]
            if correct_flat.numel() > 0:
                return correct_flat.mean()
            return correct_all.new_zeros(())
        # Offsets 1..bs-1 of every block are the predicted positions.
        correct_pred = correct_blocks[:, :, 1:]
        loss_mask_pred = loss_mask_blocks[:, :, 1:]
        block_valid = (loss_mask_pred.sum(dim=2) == (bs - 1)).float()
        correct_pred = correct_pred * loss_mask_pred
        # cumprod zeroes everything after the first miss, leaving the
        # accepted prefix.
        cumulative_correct = correct_pred.cumprod(dim=2)
        acceptance_lengths = cumulative_correct.sum(dim=2)
        acceptance_lengths = (acceptance_lengths * block_valid).sum(dim=1)
        total_blocks_sum = block_valid.sum(dim=1).sum().clamp_min(1)
        avg_accept_length = acceptance_lengths.sum() / total_blocks_sum
        return avg_accept_length / (bs - 1)

    def _full_lm_loss(
        self,
        hidden: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        combined_mask: torch.Tensor,
        effective_len: int,
        block_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Non-chunked lm_head + loss computation.

        Tokens with ``combined_mask > 1e-6`` are supervised. Without
        ``loss_decay_gamma``, the loss is their mean cross-entropy. With it,
        the loss is the ``combined_mask``-weighted mean. If no token is
        supervised, a zero loss still attached to the graph is returned
        instead of NaN.
        """
        logits = self.lm_head(hidden)
        with torch.no_grad():
            preds_all = logits.argmax(dim=-1)
            accuracy = self._compute_acceptance_accuracy(
                preds_all, input_ids, loss_mask, effective_len, block_ids
            )
        logits_flat = logits.reshape(-1, logits.size(-1))
        labels_flat = input_ids.reshape(-1)
        mask_flat = combined_mask.reshape(-1)
        active_indices = mask_flat > 1e-6
        active_logits = logits_flat[active_indices]
        active_labels = labels_flat[active_indices]
        active_weights = mask_flat[active_indices]
        if active_labels.numel() == 0:
            # Nothing to supervise: a mean over zero tokens would be NaN.
            # The empty sum is 0 but keeps the graph, as _chunked_lm_loss does.
            return active_logits.sum() * 0.0, accuracy
        if self.loss_decay_gamma is not None:
            per_token_loss = F.cross_entropy(
                active_logits, active_labels, reduction="none"
            )
            loss = (per_token_loss * active_weights).sum() / active_weights.sum()
        else:
            loss = F.cross_entropy(active_logits, active_labels)
        return loss, accuracy

    def _chunked_lm_loss(
        self,
        hidden: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        combined_mask: torch.Tensor,
        effective_len: int,
        block_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Chunked lm_head + loss that avoids materializing full ``[bsz, seq, vocab]`` logits.

        Processes the sequence in chunks of ``lm_head_chunk_size``. Each chunk
        uses gradient checkpointing, so logits are recomputed (not stored)
        during backward. The loss is the ``combined_mask``-weighted mean
        cross-entropy over all tokens with ``combined_mask > 1e-6``. It is 0
        when there are none.
        """
        chunk_size = self.lm_head_chunk_size
        # 1. Accuracy: argmax per chunk under no_grad.
        with torch.no_grad():
            preds_chunks = []
            for start in range(0, effective_len, chunk_size):
                end = min(start + chunk_size, effective_len)
                chunk_logits = self.lm_head(hidden[:, start:end, :])
                preds_chunks.append(chunk_logits.argmax(dim=-1))
            preds_all = torch.cat(preds_chunks, dim=1)
            accuracy = self._compute_acceptance_accuracy(
                preds_all, input_ids, loss_mask, effective_len, block_ids
            )
        # 2. Loss: chunked with gradient checkpointing.
        total_loss = torch.tensor(0.0, device=hidden.device)
        total_weight = torch.tensor(0.0, device=hidden.device)

        def _chunk_ce(h_chunk, labels_chunk, weights_chunk):
            # Returns (weighted CE sum, weight sum) for one chunk.
            logits_chunk = self.lm_head(h_chunk)
            logits_flat = logits_chunk.reshape(-1, logits_chunk.size(-1))
            labels_flat = labels_chunk.reshape(-1)
            weights_flat = weights_chunk.reshape(-1)
            active = weights_flat > 1e-6
            if not active.any():
                # Zero contribution that stays connected to the graph.
                return logits_flat.sum() * 0, weights_flat.sum() * 0
            per_token = F.cross_entropy(
                logits_flat[active], labels_flat[active], reduction="none"
            )
            return (per_token * weights_flat[active]).sum(), weights_flat[active].sum()

        for start in range(0, effective_len, chunk_size):
            end = min(start + chunk_size, effective_len)
            chunk_loss, chunk_weight = grad_checkpoint(
                _chunk_ce,
                hidden[:, start:end, :],
                input_ids[:, start:end],
                combined_mask[:, start:end],
                use_reentrant=False,
            )
            total_loss = total_loss + chunk_loss
            total_weight = total_weight + chunk_weight
        loss = total_loss / total_weight.clamp_min(1e-8)
        return loss, accuracy
def create_dflash_loss_mask(
    seq_len: int,
    block_size: int,
    device: torch.device,
    gamma: Optional[float] = None,
    block_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Build the per-token DFlash loss weights.

    A token is supervised unless it opens a block, since the block's first
    token is given to the draft model as a real token.

    - Without ``block_ids``: blocks are contiguous runs of ``block_size``
      positions, and every position of the first block is also excluded.
    - With ``block_ids``: a block starts wherever the id changes along dim 1,
      and tokens with a negative id (padding) are excluded.

    If ``gamma`` is given, each kept token at in-block offset ``p`` (computed
    as ``position % block_size``) is weighted by ``exp(-(p - 1) / gamma)``.

    Returns:
        A float tensor of shape ``[seq_len]`` when ``block_ids`` is None.
        Otherwise it has the batch shape of ``block_ids`` (``[bsz, seq_len]``).
    """
    index = torch.arange(seq_len, device=device)
    offset_in_block = index % block_size

    if block_ids is None:
        # Skip block-opening tokens and the whole of block 0.
        keep = (offset_in_block != 0) & (index // block_size != 0)
    else:
        opens_block = torch.ones_like(block_ids, dtype=torch.bool)
        opens_block[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1]
        keep = (block_ids >= 0) & ~opens_block
        # Broadcast the offsets across the batch dimension.
        offset_in_block = offset_in_block.unsqueeze(0)

    weights = keep.float()
    if gamma is None:
        return weights
    return weights * torch.exp(-(offset_in_block.float() - 1.0) / gamma)