| """DFlash LoRA Training Script. |
| |
| Trains Qwen3-8B with LoRA adapters to learn 1-step parallel block generation |
| (dLLM capability). No separate target model is needed for hidden state extraction — |
| the LoRA model uses its own representations. |
| |
| Key differences from train_dflash.py: |
| - No DFlashTargetModel (no hidden state extraction) |
| - No TargetEmbeddingsAndHead (model uses its own embed/lm_head) |
| - DFlashLoRADraftModel: Qwen3-8B + PEFT LoRA |
| - OnlineDFlashLoRAModel: full-sequence DFlash attention mask |
| - Only LoRA parameters are trained; base model is frozen |
| - Saves LoRA adapter weights only |
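
Example launch (illustrative: the script filename, data paths, and GPU count
are placeholders; all flags shown are defined in parse_args below):

    torchrun --nproc_per_node=8 train_dflash_lora.py \
        --model-path Qwen/Qwen3-8B \
        --train-data-path ./data/train.jsonl \
        --output-dir ./outputs/dflash_lora \
        --lora-rank 16 --lora-alpha 32 \
        --accumulation-steps 8 \
        --gradient-checkpointing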
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import math |
| from contextlib import nullcontext |
| import os |
| import time |
| import warnings |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.distributed as dist |
| from accelerate.utils import set_seed |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import AutoTokenizer |
|
|
| from datasets import load_dataset |
| from specforge.args import TrackerArgs |
| from specforge.core.dflash_lora import OnlineDFlashLoRAModel |
| from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders |
| from specforge.distributed import destroy_distributed, get_dp_group, init_distributed |
| from specforge.modeling.draft.dflash_lora import DFlashLoRADraftModel |
| from specforge.optimizer import BF16Optimizer |
| from specforge.tracker import create_tracker |
| from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank |
|
|
|
|
| def parse_args(): |
| parser = argparse.ArgumentParser(description="Train DFlash LoRA (Qwen3-8B + LoRA)") |
|
|
    model_group = parser.add_argument_group("model")
    model_group.add_argument("--model-path", type=str, required=True,
                             help="Path to Qwen3-8B (or any CausalLM) base model")
    model_group.add_argument("--block-size", type=int, default=16)
    model_group.add_argument("--mask-token-id", type=int, default=None,
                             help="MASK token ID. Auto-detected from tokenizer if not set.")
    model_group.add_argument("--context-len", type=int, default=0,
                             help="Fixed context length before blocks. 0 = treat whole seq as blocks.")
    model_group.add_argument("--trust-remote-code", action="store_true")
    model_group.add_argument("--attn-implementation", type=str, default="sdpa",
                             choices=["sdpa", "eager"],
                             help="Attention backend for additive mask path. "
                                  "Ignored when --attention-backend=flex_attention.")
    model_group.add_argument("--attention-backend", type=str, default="flex_attention",
                             choices=["flex_attention", "additive"],
                             help="flex_attention: use BlockMask (zero extra memory). "
                                  "additive: use 4D additive mask with SDPA/eager.")
    model_group.add_argument("--lm-head-chunk-size", type=int, default=0,
                             help="Chunk size for chunked cross-entropy loss. "
                                  "0 = full logits (default). 256-512 recommended to reduce VRAM.")
    model_group.add_argument("--random-anchor", action="store_true",
                             help="Randomly sample anchor positions each step (as in the non-LoRA train_dflash.py).")
    model_group.add_argument("--num-anchors", type=int, default=512,
                             help="Max number of random anchor positions per sample (default: 512).")

    lora_group = parser.add_argument_group("lora")
    lora_group.add_argument("--lora-rank", type=int, default=16)
    lora_group.add_argument("--lora-alpha", type=int, default=32)
    lora_group.add_argument("--lora-dropout", type=float, default=0.05)
    lora_group.add_argument("--lora-target-modules", type=str, nargs="+",
                            default=["q_proj", "k_proj", "v_proj", "o_proj"],
                            help="Which modules to apply LoRA to")
    lora_group.add_argument("--lora-config", type=str, default=None,
                            help="Path to JSON file with LoRA config (overrides individual args)")

    dataset_group = parser.add_argument_group("dataset")
    dataset_group.add_argument("--train-data-path", type=str, required=True)
    dataset_group.add_argument("--eval-data-path", type=str, default=None)
    dataset_group.add_argument("--chat-template", type=str, default="qwen")
    dataset_group.add_argument("--is-preformatted", action="store_true")
    dataset_group.add_argument("--dataloader-num-workers", type=int, default=8)
    dataset_group.add_argument("--build-dataset-num-proc", type=int,
                               default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)))

    training_group = parser.add_argument_group("training")
    training_group.add_argument("--num-epochs", type=int, default=3)
    training_group.add_argument("--batch-size", type=int, default=1)
    training_group.add_argument("--learning-rate", type=float, default=2e-4)
    training_group.add_argument("--max-length", type=int, default=2048)
    training_group.add_argument("--warmup-ratio", type=float, default=0.04)
    training_group.add_argument("--max-grad-norm", type=float, default=1.0)
    training_group.add_argument("--accumulation-steps", type=int, default=1)
    training_group.add_argument("--loss-decay-gamma", type=float, default=None)
    training_group.add_argument("--optimizer-type", type=str, default="adamw",
                                choices=["adamw", "adamw_8bit"])
    training_group.add_argument("--no-fp32-params", action="store_true")
    training_group.add_argument("--gradient-checkpointing", action="store_true")
    training_group.add_argument("--seed", type=int, default=42)
    training_group.add_argument("--resume", action="store_true")
    training_group.add_argument("--ckpt-dir", type=str, default=None)

    output_group = parser.add_argument_group("output")
    output_group.add_argument("--output-dir", type=str, required=True)
    output_group.add_argument("--cache-dir", type=str, default="./cache")
    output_group.add_argument("--log-interval", type=int, default=50)
    output_group.add_argument("--eval-interval", type=int, default=1000)
    output_group.add_argument("--save-interval", type=int, default=1000)

    dist_group = parser.add_argument_group("distributed")
    dist_group.add_argument("--dist-timeout", type=int, default=30)

    tracker_group = parser.add_argument_group("tracker")
    TrackerArgs.add_args(tracker_group)

    return parser.parse_args()


def build_model(args) -> Tuple[DFlashLoRADraftModel, OnlineDFlashLoRAModel]:
    """Load the base model, inject LoRA, and wrap it in OnlineDFlashLoRAModel."""
    print_on_rank0(f"Loading base model from {args.model_path}")

    # Start from the CLI defaults; an optional JSON config file overrides them.
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    lora_target_modules = args.lora_target_modules

    if args.lora_config is not None:
        with open(args.lora_config) as f:
            lora_cfg = json.load(f)
        lora_rank = lora_cfg.get("lora_rank", lora_rank)
        lora_alpha = lora_cfg.get("lora_alpha", lora_alpha)
        lora_dropout = lora_cfg.get("lora_dropout", lora_dropout)
        lora_target_modules = lora_cfg.get("lora_target_modules", lora_target_modules)
        print_on_rank0(f"Loaded LoRA config from {args.lora_config}")

    # flex_attention builds its own BlockMask; the additive-mask path uses
    # whichever backend was selected via --attn-implementation.
    if args.attention_backend == "flex_attention":
        attn_impl = "flex_attention"
    else:
        attn_impl = args.attn_implementation

    # main() resolves args.mask_token_id before calling build_model; 151669 is
    # only a last-resort fallback for standalone use.
    mask_token_id = args.mask_token_id if args.mask_token_id is not None else 151669

    draft_model = DFlashLoRADraftModel.from_pretrained(
        pretrained_model_name_or_path=args.model_path,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_target_modules=lora_target_modules,
        block_size=args.block_size,
        mask_token_id=mask_token_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=args.trust_remote_code,
        attn_implementation=attn_impl,
    )

    online_model = OnlineDFlashLoRAModel(
        draft_model=draft_model,
        block_size=args.block_size,
        mask_token_id=mask_token_id,
        loss_decay_gamma=args.loss_decay_gamma,
        attention_backend=args.attention_backend,
        lm_head_chunk_size=args.lm_head_chunk_size,
        random_anchor=args.random_anchor,
        num_anchors=args.num_anchors,
    )

    trainable = sum(p.numel() for p in draft_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in draft_model.parameters())
    print_on_rank0(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

    return draft_model, online_model


def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]:
    """Build train and eval dataloaders (same as train_dflash.py)."""
    # The cache key covers everything that affects preprocessing, so changing
    # any of these inputs invalidates the processed-dataset cache.
    cache_params_string = (
        f"{args.train_data_path}-{args.max_length}-{args.chat_template}-{args.model_path}"
    )
    cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()

    rank = dist.get_rank()

    # A directory is treated as a saved HF dataset; otherwise assume JSON/JSONL.
    if os.path.isdir(args.train_data_path):
        train_dataset = load_dataset(args.train_data_path, split="train")
    else:
        train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]

    dataset_kwargs = dict(
        dataset=train_dataset,
        tokenizer=tokenizer,
        chat_template=args.chat_template,
        max_length=args.max_length,
        is_preformatted=args.is_preformatted,
        cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
        cache_key=cache_key,
        num_proc=args.build_dataset_num_proc,
    )

    # Rank 0 builds the processed dataset first to populate the cache; after
    # the barrier, the other ranks load the cached result instead of
    # re-tokenizing.
    if rank == 0:
        train_eagle3_dataset = build_eagle3_dataset(**dataset_kwargs)
    dist.barrier()
    if rank != 0:
        train_eagle3_dataset = build_eagle3_dataset(**dataset_kwargs)

    # Drop samples with too few supervised tokens to fill at least two blocks.
    min_loss_tokens = 2 * args.block_size
    original_size = len(train_eagle3_dataset)
    train_eagle3_dataset = train_eagle3_dataset.filter(
        lambda x: x["loss_mask"].sum() >= min_loss_tokens
    )
    print_on_rank0(f"Filtered train dataset: {original_size} -> {len(train_eagle3_dataset)} samples")

    train_dataloader = prepare_dp_dataloaders(
        train_eagle3_dataset,
        args.batch_size,
        num_workers=args.dataloader_num_workers,
        shuffle=True,
        process_group=get_dp_group(),
    )

    eval_dataloader = None
    if args.eval_data_path:
        if os.path.isdir(args.eval_data_path):
            eval_dataset = load_dataset(args.eval_data_path, split="train")
        else:
            eval_dataset = load_dataset("json", data_files=args.eval_data_path)["train"]
        eval_eagle3_dataset = build_eagle3_dataset(
            dataset=eval_dataset,
            tokenizer=tokenizer,
            chat_template=args.chat_template,
            max_length=args.max_length,
            is_preformatted=args.is_preformatted,
        )
        eval_dataloader = prepare_dp_dataloaders(
            eval_eagle3_dataset,
            args.batch_size,
            num_workers=args.dataloader_num_workers,
            shuffle=False,
            process_group=get_dp_group(),
        )

    return train_dataloader, eval_dataloader


def save_checkpoint(args, epoch, step, online_model, draft_model, optimizer):
    """Save LoRA adapter weights + training state."""
    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
    if dist.get_rank() == 0:
        os.makedirs(save_dir, exist_ok=True)
    dist.barrier()

    with FSDP.state_dict_type(online_model, StateDictType.FULL_STATE_DICT):
        if dist.get_rank() == 0:
            # Only the LoRA adapter weights are written; the frozen base model
            # is not duplicated into the checkpoint.
            draft_model.save_pretrained(save_dir)

            torch.save(
                {
                    "epoch": epoch,
                    "global_step": step,
                    "args": args,
                    **optimizer.state_dict(),
                },
                os.path.join(save_dir, "training_state.pt"),
            )
            print_on_rank0(f"Saved LoRA checkpoint to {save_dir}")

    dist.barrier()


def record_metrics(args, loss, accuracy, global_step, tracker, optimizer,
                   train_dataloader=None, mode="train"):
    logdict = {}
    if mode == "train" and optimizer is not None:
        logdict["train/lr"] = optimizer.get_learning_rate()
    logdict[f"{mode}/loss"] = loss
    logdict[f"{mode}/accuracy"] = accuracy
    print_on_rank0(
        f"{mode.capitalize()} - Step {global_step}, Loss: {loss:.4f}, Acc: {accuracy:.4f}"
    )
    tracker.log(logdict, step=global_step)


def main():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    warnings.filterwarnings(
        "ignore",
        "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed",
    )

    args = parse_args()
    set_seed(args.seed)

    # Distributed init: tp_size=1 means data parallelism only.
    init_distributed(timeout=args.dist_timeout, tp_size=1)
    print_with_rank("Initialized distributed")

    # Resolve the MASK token id: CLI flag > tokenizer's mask token > a newly
    # added <|MASK|> special token.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.mask_token_id is not None:
        mask_token_id = args.mask_token_id
    elif tokenizer.mask_token_id is not None:
        mask_token_id = tokenizer.mask_token_id
    else:
        tokenizer.add_special_tokens({"mask_token": "<|MASK|>"})
        mask_token_id = tokenizer.mask_token_id
    print_on_rank0(f"Using mask_token_id: {mask_token_id}")
    args.mask_token_id = mask_token_id

    draft_model, online_model = build_model(args)

    # Propagate the resolved mask token id to both wrappers.
    draft_model.mask_token_id = mask_token_id
    online_model.mask_token_id = mask_token_id

    if args.gradient_checkpointing:
        draft_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        print_on_rank0("Gradient checkpointing enabled")

    # Optionally warm-start the LoRA adapter from an existing checkpoint.
    resume_state = None
    if args.ckpt_dir is not None:
        if os.path.isdir(args.ckpt_dir):
            print_on_rank0(f"Loading LoRA weights from {args.ckpt_dir}")
            from peft import PeftModel

            draft_model.model = PeftModel.from_pretrained(
                draft_model.model.base_model.model, args.ckpt_dir
            )
        else:
            raise ValueError(f"ckpt_dir {args.ckpt_dir} is not a valid directory")

    # --resume picks up the most recent checkpoint in the output directory,
    # including its saved scheduler state.
    if args.resume and os.path.isdir(args.output_dir):
        last_ckpt = get_last_checkpoint(args.output_dir, prefix=r"epoch_\d+_step")
        if last_ckpt:
            print_on_rank0(f"Resuming from {last_ckpt}")
            from peft import PeftModel

            draft_model.model = PeftModel.from_pretrained(
                draft_model.model.base_model.model, last_ckpt
            )
            training_state_path = os.path.join(last_ckpt, "training_state.pt")
            if os.path.exists(training_state_path):
                resume_state = torch.load(training_state_path, map_location="cpu", weights_only=False)
                print_on_rank0(
                    f"Will resume from epoch {resume_state['epoch']}, step {resume_state['global_step']}"
                )

    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)

    # The scheduler horizon is counted in optimizer steps: one per
    # gradient-accumulation window.
    steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
    total_steps = args.num_epochs * steps_per_epoch
    print_on_rank0(f"Total training steps: {total_steps}")

    # NO_SHARD keeps parameters replicated across ranks (DDP-style); only the
    # small LoRA adapter is trainable, so there is little to gain from sharding.
    online_model = FSDP(
        online_model,
        use_orig_params=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        sharding_strategy=ShardingStrategy.NO_SHARD,
    )
    print_with_rank("Initialized FSDP")

    optimizer = BF16Optimizer(
        draft_model,
        lr=args.learning_rate,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        total_steps=total_steps,
        use_fp32_params=not args.no_fp32_params,
        optimizer_type=args.optimizer_type,
    )

    start_epoch = 0
    global_step = 0
    if resume_state is not None:
        optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
        start_epoch = resume_state["epoch"]
        global_step = resume_state["global_step"]
        del resume_state
        print_on_rank0(f"Restored scheduler, lr={optimizer.get_learning_rate():.6f}")

    # Number of batches already consumed in the resumed epoch (global_step
    # counts batches: one per dataloader iteration).
    skip_steps = global_step - start_epoch * len(train_dataloader)

    tracker = create_tracker(args, args.output_dir)
    last_time = time.time()
    print_on_rank0(f"Starting training from epoch {start_epoch}, step {global_step}")

    for epoch in range(start_epoch, args.num_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        draft_model.train()

        if dist.get_rank() == 0:
            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}", leave=True)
        else:
            progress_bar = train_dataloader

        for step_in_epoch, data in enumerate(progress_bar):
            # When resuming mid-epoch, fast-forward past already-seen batches.
            if epoch == start_epoch and step_in_epoch < skip_steps:
                continue
            global_step += 1

            input_ids = data["input_ids"].cuda()
            attention_mask = data["attention_mask"].cuda()
            loss_mask = data["loss_mask"].cuda()

            # Skip gradient synchronization on accumulation micro-steps; the
            # all-reduce only happens on the step that calls optimizer.step().
            is_accumulation_step = (global_step % args.accumulation_steps) != 0
            ctx = online_model.no_sync() if is_accumulation_step else nullcontext()

            with ctx:
                loss, accuracy = online_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    loss_mask=loss_mask,
                    context_len=args.context_len,
                )
                (loss / args.accumulation_steps).backward()

            if global_step % args.accumulation_steps == 0:
                optimizer.step()

            if global_step % args.log_interval == 0:
                # Average loss/accuracy across ranks before logging.
                loss_t = torch.tensor(loss.item(), device="cuda")
                acc_t = torch.tensor(accuracy.item(), device="cuda")
                dist.all_reduce(loss_t)
                dist.all_reduce(acc_t)
                record_metrics(args, loss_t.item() / dist.get_world_size(),
                               acc_t.item() / dist.get_world_size(), global_step,
                               tracker, optimizer, train_dataloader, mode="train")

            if dist.get_rank() == 0:
                elapsed = time.time() - last_time
                last_time = time.time()
                progress_bar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "acc": f"{accuracy.item():.4f}",
                    "iter_time": f"{elapsed:.2f}s",
                })

            if global_step % args.save_interval == 0:
                save_checkpoint(args, epoch, global_step, online_model, draft_model, optimizer)

    # Final checkpoint after the last epoch.
    save_checkpoint(args, args.num_epochs, global_step, online_model, draft_model, optimizer)
    tracker.close()
    destroy_distributed()


if __name__ == "__main__":
    main()