#!/usr/bin/env python3
# coding=utf-8
"""DFlash LoRA Training Script.

Trains Qwen3-8B with LoRA adapters to learn 1-step parallel block generation
(dLLM capability). No separate target model is needed for hidden state
extraction — the LoRA model uses its own representations.

Key differences from train_dflash.py:
- No DFlashTargetModel (no hidden state extraction)
- No TargetEmbeddingsAndHead (model uses its own embed/lm_head)
- DFlashLoRADraftModel: Qwen3-8B + PEFT LoRA
- OnlineDFlashLoRAModel: full-sequence DFlash attention mask
- Only LoRA parameters are trained; base model is frozen
- Saves LoRA adapter weights only
"""

import argparse
import hashlib
import json
import logging
import math
import os
import time
import warnings
from contextlib import nullcontext
from typing import Optional, Tuple

import torch
import torch.distributed as dist
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from specforge.args import TrackerArgs
from specforge.core.dflash_lora import OnlineDFlashLoRAModel
from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders
from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
from specforge.modeling.draft.dflash_lora import DFlashLoRADraftModel
from specforge.optimizer import BF16Optimizer
from specforge.tracker import create_tracker
from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank

# Placeholder mask id used only if build_model() is called before a mask id has been
# resolved. NOTE(review): presumably the first id past the Qwen3 tokenizer vocab — confirm.
_FALLBACK_MASK_TOKEN_ID = 151669


def parse_args():
    """Parse and validate command-line arguments for DFlash LoRA training."""
    parser = argparse.ArgumentParser(description="Train DFlash LoRA (Qwen3-8B + LoRA)")

    model_group = parser.add_argument_group("model")
    model_group.add_argument("--model-path", type=str, required=True,
                             help="Path to Qwen3-8B (or any CausalLM) base model")
    model_group.add_argument("--block-size", type=int, default=16)
    model_group.add_argument("--mask-token-id", type=int, default=None,
                             help="MASK token ID. Auto-detected from tokenizer if not set.")
    model_group.add_argument("--context-len", type=int, default=0,
                             help="Fixed context length before blocks. 0 = treat whole seq as blocks.")
    model_group.add_argument("--trust-remote-code", action="store_true")
    model_group.add_argument("--attn-implementation", type=str, default="sdpa",
                             choices=["sdpa", "eager"],
                             help="Attention backend for additive mask path. "
                                  "Ignored when --attention-backend=flex_attention.")
    model_group.add_argument("--attention-backend", type=str, default="flex_attention",
                             choices=["flex_attention", "additive"],
                             help="flex_attention: use BlockMask (zero extra memory). "
                                  "additive: use 4D additive mask with SDPA/eager.")
    model_group.add_argument("--lm-head-chunk-size", type=int, default=0,
                             help="Chunk size for chunked cross-entropy loss. "
                                  "0 = full logits (default). 256-512 recommended to reduce VRAM.")
    model_group.add_argument("--random-anchor", action="store_true",
                             help="Randomly sample anchor positions each step (like non-LoRA dflash).")
    model_group.add_argument("--num-anchors", type=int, default=512,
                             help="Max number of random anchor positions per sample (default: 512).")

    lora_group = parser.add_argument_group("lora")
    lora_group.add_argument("--lora-rank", type=int, default=16)
    lora_group.add_argument("--lora-alpha", type=int, default=32)
    lora_group.add_argument("--lora-dropout", type=float, default=0.05)
    lora_group.add_argument("--lora-target-modules", type=str, nargs="+",
                            default=["q_proj", "k_proj", "v_proj", "o_proj"],
                            help="Which modules to apply LoRA to")
    lora_group.add_argument("--lora-config", type=str, default=None,
                            help="Path to JSON file with LoRA config (overrides individual args)")

    dataset_group = parser.add_argument_group("dataset")
    dataset_group.add_argument("--train-data-path", type=str, required=True)
    dataset_group.add_argument("--eval-data-path", type=str, default=None)
    dataset_group.add_argument("--chat-template", type=str, default="qwen")
    dataset_group.add_argument("--is-preformatted", action="store_true")
    dataset_group.add_argument("--dataloader-num-workers", type=int, default=8)
    dataset_group.add_argument("--build-dataset-num-proc", type=int,
                               default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)))

    training_group = parser.add_argument_group("training")
    training_group.add_argument("--num-epochs", type=int, default=3)
    training_group.add_argument("--batch-size", type=int, default=1)
    training_group.add_argument("--learning-rate", type=float, default=2e-4)
    training_group.add_argument("--max-length", type=int, default=2048)
    training_group.add_argument("--warmup-ratio", type=float, default=0.04)
    training_group.add_argument("--max-grad-norm", type=float, default=1.0)
    training_group.add_argument("--accumulation-steps", type=int, default=1)
    training_group.add_argument("--loss-decay-gamma", type=float, default=None)
    training_group.add_argument("--optimizer-type", type=str, default="adamw",
                                choices=["adamw", "adamw_8bit"])
    training_group.add_argument("--no-fp32-params", action="store_true")
    training_group.add_argument("--gradient-checkpointing", action="store_true")
    training_group.add_argument("--seed", type=int, default=42)
    training_group.add_argument("--resume", action="store_true")
    training_group.add_argument("--ckpt-dir", type=str, default=None)

    output_group = parser.add_argument_group("output")
    output_group.add_argument("--output-dir", type=str, required=True)
    output_group.add_argument("--cache-dir", type=str, default="./cache")
    output_group.add_argument("--log-interval", type=int, default=50)
    output_group.add_argument("--eval-interval", type=int, default=1000)
    output_group.add_argument("--save-interval", type=int, default=1000)

    dist_group = parser.add_argument_group("distributed")
    dist_group.add_argument("--dist-timeout", type=int, default=30)

    tracker_group = parser.add_argument_group("tracker")
    TrackerArgs.add_args(tracker_group)

    args = parser.parse_args()
    # These are used as modulo divisors in the training loop; 0 would raise
    # ZeroDivisionError mid-run, so reject non-positive values up front.
    # (--eval-interval <= 0 is allowed and simply disables periodic evaluation.)
    for name in ("accumulation_steps", "log_interval", "save_interval"):
        value = getattr(args, name)
        if value < 1:
            parser.error(f"--{name.replace('_', '-')} must be >= 1, got {value}")
    return args


def build_model(args) -> Tuple[DFlashLoRADraftModel, OnlineDFlashLoRAModel]:
    """Load the base model, inject LoRA, and wrap it in OnlineDFlashLoRAModel.

    LoRA hyper-parameters come from the individual CLI args unless
    ``--lora-config`` points at a JSON file, whose keys (``lora_rank``,
    ``lora_alpha``, ``lora_dropout``, ``lora_target_modules``) override them.

    Returns:
        ``(draft_model, online_model)`` where ``online_model`` wraps ``draft_model``.
    """
    print_on_rank0(f"Loading base model from {args.model_path}")

    # Load LoRA config from JSON if provided
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    lora_target_modules = args.lora_target_modules
    if args.lora_config is not None:
        with open(args.lora_config) as f:
            lora_cfg = json.load(f)
        lora_rank = lora_cfg.get("lora_rank", lora_rank)
        lora_alpha = lora_cfg.get("lora_alpha", lora_alpha)
        lora_dropout = lora_cfg.get("lora_dropout", lora_dropout)
        lora_target_modules = lora_cfg.get("lora_target_modules", lora_target_modules)
        print_on_rank0(f"Loaded LoRA config from {args.lora_config}")

    # Resolve attn_implementation: flex_attention backend uses HF flex_attention impl
    if args.attention_backend == "flex_attention":
        attn_impl = "flex_attention"
    else:
        attn_impl = args.attn_implementation  # sdpa or eager

    # Explicit None check: a legitimate mask id of 0 must not fall back to the placeholder.
    mask_token_id = (
        args.mask_token_id if args.mask_token_id is not None else _FALLBACK_MASK_TOKEN_ID
    )

    draft_model = DFlashLoRADraftModel.from_pretrained(
        pretrained_model_name_or_path=args.model_path,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_target_modules=lora_target_modules,
        block_size=args.block_size,
        mask_token_id=mask_token_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=args.trust_remote_code,
        attn_implementation=attn_impl,
    )

    online_model = OnlineDFlashLoRAModel(
        draft_model=draft_model,
        block_size=args.block_size,
        mask_token_id=mask_token_id,
        loss_decay_gamma=args.loss_decay_gamma,
        attention_backend=args.attention_backend,
        lm_head_chunk_size=args.lm_head_chunk_size,
        random_anchor=args.random_anchor,
        num_anchors=args.num_anchors,
    )

    trainable = sum(p.numel() for p in draft_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in draft_model.parameters())
    print_on_rank0(f"Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    return draft_model, online_model


def _load_raw_dataset(path):
    """Load a HF dataset from a directory (its ``train`` split) or from a JSON/JSONL file."""
    if os.path.isdir(path):
        return load_dataset(path, split="train")
    return load_dataset("json", data_files=path)["train"]


def _drop_short_samples(dataset, min_loss_tokens, label):
    """Keep only samples whose ``loss_mask`` sums to at least ``min_loss_tokens``."""
    original_size = len(dataset)
    dataset = dataset.filter(lambda x: x["loss_mask"].sum() >= min_loss_tokens)
    print_on_rank0(f"Filtered {label} dataset: {original_size} -> {len(dataset)} samples")
    return dataset


def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]:
    """Build the train and (optional) eval dataloaders.

    Train preprocessing is cached under ``<cache_dir>/processed_dataset`` with a
    key derived from the data path, max length, chat template and model path.
    Samples with fewer than ``2 * block_size`` loss tokens are dropped from both
    splits.

    Returns:
        ``(train_dataloader, eval_dataloader)``; ``eval_dataloader`` is ``None``
        when ``--eval-data-path`` is not given.
    """
    cache_params_string = (
        f"{args.train_data_path}-{args.max_length}-{args.chat_template}-{args.model_path}"
    )
    cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
    rank = dist.get_rank()

    # Support both jsonl and parquet/directory formats
    train_dataset = _load_raw_dataset(args.train_data_path)

    dataset_kwargs = dict(
        dataset=train_dataset,
        tokenizer=tokenizer,
        chat_template=args.chat_template,
        max_length=args.max_length,
        is_preformatted=args.is_preformatted,
        cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
        cache_key=cache_key,
        num_proc=args.build_dataset_num_proc,
    )
    # Only rank 0 runs the expensive .map() preprocessing.
    # Other ranks wait, then load directly from the cached result.
    # This avoids N_GPU × num_proc concurrent workers causing OOM.
    if rank == 0:
        train_eagle3_dataset = build_eagle3_dataset(**dataset_kwargs)
    dist.barrier()
    if rank != 0:
        train_eagle3_dataset = build_eagle3_dataset(**dataset_kwargs)

    min_loss_tokens = 2 * args.block_size
    train_eagle3_dataset = _drop_short_samples(train_eagle3_dataset, min_loss_tokens, "train")

    train_dataloader = prepare_dp_dataloaders(
        train_eagle3_dataset,
        args.batch_size,
        num_workers=args.dataloader_num_workers,
        shuffle=True,
        process_group=get_dp_group(),
    )

    eval_dataloader = None
    if args.eval_data_path:
        eval_dataset = _load_raw_dataset(args.eval_data_path)
        eval_eagle3_dataset = build_eagle3_dataset(
            dataset=eval_dataset,
            tokenizer=tokenizer,
            chat_template=args.chat_template,
            max_length=args.max_length,
            is_preformatted=args.is_preformatted,
        )
        # Same filter as train: samples without enough loss tokens would give a
        # degenerate (possibly NaN) loss.
        eval_eagle3_dataset = _drop_short_samples(eval_eagle3_dataset, min_loss_tokens, "eval")
        eval_dataloader = prepare_dp_dataloaders(
            eval_eagle3_dataset,
            args.batch_size,
            num_workers=args.dataloader_num_workers,
            shuffle=False,
            process_group=get_dp_group(),
        )

    return train_dataloader, eval_dataloader


def save_checkpoint(args, epoch, step, online_model, draft_model, optimizer):
    """Save LoRA adapter weights + training state to ``<output_dir>/epoch_<epoch>_step_<step>``.

    Rank 0 creates the directory, saves the adapter via ``draft_model.save_pretrained``
    and writes ``training_state.pt`` (epoch, global step, args and the optimizer's
    ``state_dict()`` entries). All ranks synchronise before and after.
    """
    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
    if dist.get_rank() == 0:
        os.makedirs(save_dir, exist_ok=True)
    dist.barrier()

    with FSDP.state_dict_type(online_model, StateDictType.FULL_STATE_DICT):
        if dist.get_rank() == 0:
            # Save LoRA adapter only
            draft_model.save_pretrained(save_dir)
            torch.save(
                {
                    "epoch": epoch,
                    "global_step": step,
                    "args": args,
                    **optimizer.state_dict(),
                },
                os.path.join(save_dir, "training_state.pt"),
            )
            print_on_rank0(f"Saved LoRA checkpoint to {save_dir}")
    dist.barrier()


def record_metrics(args, loss, accuracy, global_step, tracker, optimizer, train_dataloader=None,
                   mode="train"):
    """Print and log ``loss``/``accuracy`` under the ``<mode>/`` prefix.

    In ``"train"`` mode the optimizer's current learning rate is logged as well.
    ``args`` and ``train_dataloader`` are currently unused (kept for call
    compatibility).
    """
    logdict = {}
    if mode == "train" and optimizer is not None:
        logdict["train/lr"] = optimizer.get_learning_rate()
    logdict[f"{mode}/loss"] = loss
    logdict[f"{mode}/accuracy"] = accuracy
    print_on_rank0(
        f"{mode.capitalize()} - Step {global_step}, Loss: {loss:.4f}, Acc: {accuracy:.4f}"
    )
    tracker.log(logdict, step=global_step)


@torch.no_grad()
def evaluate(args, online_model, eval_dataloader) -> Tuple[float, float]:
    """Run ``online_model`` over ``eval_dataloader`` without gradients.

    Returns the mean per-batch loss and accuracy over all batches seen by all
    ranks (a single all_reduce at the end), or ``(nan, nan)`` if no batch was
    seen. The model is put in eval mode for the pass and back in train mode
    afterwards.
    """
    online_model.eval()
    # [sum of batch losses, sum of batch accuracies, number of batches] on this rank.
    totals = torch.zeros(3, device="cuda", dtype=torch.float32)
    try:
        for data in eval_dataloader:
            loss, accuracy = online_model(
                input_ids=data["input_ids"].cuda(),
                attention_mask=data["attention_mask"].cuda(),
                loss_mask=data["loss_mask"].cuda(),
                context_len=args.context_len,
            )
            totals[0] += loss.float()
            totals[1] += accuracy.float()
            totals[2] += 1
    finally:
        online_model.train()

    dist.all_reduce(totals)
    loss_sum, acc_sum, num_batches = totals.tolist()
    if num_batches == 0:
        return float("nan"), float("nan")
    return loss_sum / num_batches, acc_sum / num_batches


def _mean_across_ranks(*values):
    """All-reduce Python scalars in one collective and return their mean over ranks."""
    stats = torch.tensor(values, device="cuda")
    dist.all_reduce(stats)
    return (stats / dist.get_world_size()).tolist()


def _resolve_mask_token_id(args, tokenizer) -> int:
    """Pick the MASK token id: CLI value, else tokenizer's mask token, else add ``<|MASK|>``."""
    if args.mask_token_id is not None:
        return args.mask_token_id
    if tokenizer.mask_token_id is not None:
        return tokenizer.mask_token_id
    # NOTE(review): the model's embedding matrix is not resized here; this relies on
    # the new id still fitting inside the (padded) embedding table — confirm for
    # non-Qwen3 models.
    tokenizer.add_special_tokens({"mask_token": "<|MASK|>"})
    return tokenizer.mask_token_id


def _load_lora_adapter(draft_model, adapter_dir):
    """Replace ``draft_model.model`` with a PeftModel loaded from ``adapter_dir``.

    ``is_trainable=True`` is required: ``PeftModel.from_pretrained`` loads adapters
    frozen (inference mode) by default, which would leave no trainable LoRA
    parameters for the optimizer after a resume.
    """
    from peft import PeftModel

    draft_model.model = PeftModel.from_pretrained(
        draft_model.model.base_model.model, adapter_dir, is_trainable=True
    )


def main():
    """Entry point: set up distributed training, then train, log, evaluate and checkpoint."""
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    warnings.filterwarnings(
        "ignore",
        "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed",
    )

    args = parse_args()
    set_seed(args.seed)

    # tp_size=1: LoRA training doesn't use tensor parallelism
    init_distributed(timeout=args.dist_timeout, tp_size=1)
    print_with_rank("Initialized distributed")

    # Load tokenizer and resolve mask_token_id
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path, trust_remote_code=args.trust_remote_code
    )
    mask_token_id = _resolve_mask_token_id(args, tokenizer)
    print_on_rank0(f"Using mask_token_id: {mask_token_id}")
    args.mask_token_id = mask_token_id

    draft_model, online_model = build_model(args)

    # Update mask_token_id in models after tokenizer resolution
    draft_model.mask_token_id = mask_token_id
    online_model.mask_token_id = mask_token_id

    # Load adapters (explicit --ckpt-dir, then --resume) before enabling gradient
    # checkpointing, so checkpointing applies to the model that is actually trained.
    resume_state = None
    if args.ckpt_dir is not None:
        if not os.path.isdir(args.ckpt_dir):
            raise ValueError(f"ckpt_dir {args.ckpt_dir} is not a valid directory")
        print_on_rank0(f"Loading LoRA weights from {args.ckpt_dir}")
        _load_lora_adapter(draft_model, args.ckpt_dir)

    if args.resume and os.path.isdir(args.output_dir):
        last_ckpt = get_last_checkpoint(args.output_dir, prefix=r"epoch_\d+_step")
        if last_ckpt:
            print_on_rank0(f"Resuming from {last_ckpt}")
            _load_lora_adapter(draft_model, last_ckpt)
            training_state_path = os.path.join(last_ckpt, "training_state.pt")
            if os.path.exists(training_state_path):
                # weights_only=False: the file stores an argparse.Namespace. Only
                # resume from checkpoint directories you trust.
                resume_state = torch.load(training_state_path, map_location="cpu",
                                          weights_only=False)
                print_on_rank0(
                    f"Will resume from epoch {resume_state['epoch']}, step {resume_state['global_step']}"
                )

    if args.gradient_checkpointing:
        draft_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        print_on_rank0("Gradient checkpointing enabled")

    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)

    steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
    total_steps = args.num_epochs * steps_per_epoch
    print_on_rank0(f"Total training steps: {total_steps}")

    # Wrap with FSDP (only LoRA params will have gradients)
    online_model = FSDP(
        online_model,
        use_orig_params=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        sharding_strategy=ShardingStrategy.NO_SHARD,
    )
    print_with_rank("Initialized FSDP")

    optimizer = BF16Optimizer(
        draft_model,
        lr=args.learning_rate,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        total_steps=total_steps,
        use_fp32_params=not args.no_fp32_params,
        optimizer_type=args.optimizer_type,
    )

    start_epoch = 0
    global_step = 0  # counts micro-batches (not optimizer steps) across all epochs
    if resume_state is not None:
        # NOTE(review): only the LR schedule is restored; optimizer moments saved in
        # training_state.pt are not reloaded — confirm this is intended.
        optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
        start_epoch = resume_state["epoch"]
        global_step = resume_state["global_step"]
        del resume_state
        print_on_rank0(f"Restored scheduler, lr={optimizer.get_learning_rate():.6f}")

    # Micro-batches of start_epoch already consumed before the checkpoint.
    skip_steps = global_step - start_epoch * len(train_dataloader)

    tracker = create_tracker(args, args.output_dir)
    last_time = time.time()
    print_on_rank0(f"Starting training from epoch {start_epoch}, step {global_step}")

    for epoch in range(start_epoch, args.num_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        draft_model.train()

        if dist.get_rank() == 0:
            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}", leave=True)
        else:
            progress_bar = train_dataloader

        for step_in_epoch, data in enumerate(progress_bar):
            if epoch == start_epoch and step_in_epoch < skip_steps:
                continue

            global_step += 1

            input_ids = data["input_ids"].cuda()
            attention_mask = data["attention_mask"].cuda()
            loss_mask = data["loss_mask"].cuda()

            # Skip gradient sync during accumulation steps (only sync at optimizer step)
            is_accumulation_step = (global_step % args.accumulation_steps) != 0
            ctx = online_model.no_sync() if is_accumulation_step else nullcontext()
            with ctx:
                loss, accuracy = online_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    loss_mask=loss_mask,
                    context_len=args.context_len,
                )
                (loss / args.accumulation_steps).backward()

            if not is_accumulation_step:
                optimizer.step()

            if global_step % args.log_interval == 0:
                loss_val = loss.item()
                acc_val = accuracy.item()
                mean_loss, mean_acc = _mean_across_ranks(loss_val, acc_val)
                record_metrics(args, mean_loss, mean_acc, global_step, tracker, optimizer,
                               train_dataloader, mode="train")
                if dist.get_rank() == 0:
                    # Wall time since the previous log (spans log_interval micro-batches).
                    elapsed = time.time() - last_time
                    last_time = time.time()
                    progress_bar.set_postfix({
                        "loss": f"{loss_val:.4f}",
                        "acc": f"{acc_val:.4f}",
                        "iter_time": f"{elapsed:.2f}s",
                    })

            if (
                eval_dataloader is not None
                and args.eval_interval > 0
                and global_step % args.eval_interval == 0
            ):
                eval_loss, eval_acc = evaluate(args, online_model, eval_dataloader)
                record_metrics(args, eval_loss, eval_acc, global_step, tracker, optimizer,
                               mode="eval")

            if global_step % args.save_interval == 0:
                save_checkpoint(args, epoch, global_step, online_model, draft_model, optimizer)

    save_checkpoint(args, args.num_epochs, global_step, online_model, draft_model, optimizer)
    tracker.close()
    destroy_distributed()


if __name__ == "__main__":
    main()