import os
import time
import unittest
import torch
import torch.multiprocessing as mp
from accelerate.utils import set_seed
from torch import nn
from transformers import PretrainedConfig
from yunchang import EXTRACT_FUNC_DICT
from specforge.core.eagle3_adapters import SdpaLikeAdapter, UspAdapter
from specforge.data.preprocessing import build_offline_eagle3_dataset
# Project-specific imports
from specforge.distributed import destroy_distributed, init_distributed
from specforge.modeling.draft.llama3_eagle import LlamaDecoderLayer
from specforge.utils import padding
from tests.utils import get_available_port
def get_model_config():
    """Build the draft-model configuration shared by every run in this test.

    Returns:
        The object produced by ``PretrainedConfig.from_dict`` for the
        settings below.
    """
    # EAGLE-specific options live under their own nested key.
    eagle_settings = {
        "eagle_aux_hidden_state_layer_ids": [1, 29, 57],
        "use_aux_hidden_state": True,
    }
    settings = {
        "architectures": ["LlamaForCausalLMEagle3"],
        "eagle_config": eagle_settings,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 7168,
        "initializer_range": 0.02,
        "intermediate_size": 29568,
        "max_position_embeddings": 32768,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "num_hidden_layers": 1,
        "pad_token_id": 0,
        "rms_norm_eps": 1e-05,
        "tie_word_embeddings": False,
        "torch_dtype": "float16",
        "transformers_version": "4.28.1",
        "use_cache": True,
        "rope_scaling": None,
        "vocab_size": 129280,
        "draft_vocab_size": 32000,
        "pretraining_tp": 1,
    }
    return PretrainedConfig.from_dict(settings)
def setup_env(rank, world_size, port):
    """Export the rendezvous variables for this process and select its GPU.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes taking part.
        port: Port written to ``MASTER_PORT``.
    """
    rendezvous = {
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(port),
    }
    os.environ.update(rendezvous)
    # Device selection happens after the environment is populated.
    torch.cuda.set_device(rank)
def dbg(rank, msg):
    """Print ``msg`` prefixed with the rank tag, flushing stdout right away."""
    tagged = f"[rank{rank}] {msg}"
    print(tagged, flush=True)
def wait_for_file(path, timeout_s=60, poll_s=0.1):
    """Poll until ``path`` exists or the timeout elapses.

    The path is always checked at least once, so an already-present file is
    reported even when ``timeout_s`` is zero or negative.

    Args:
        path: Filesystem path to watch.
        timeout_s: Maximum number of seconds to wait.
        poll_s: Seconds to sleep between existence checks.

    Returns:
        True if ``path`` was seen to exist, False if the timeout expired first.
    """
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout_s
    while True:
        if os.path.exists(path):
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline, so the final check lands on time.
        time.sleep(min(poll_s, remaining))
def run_iterative_pass(
    decoder_layer,
    embed_tokens,
    input_ids,
    hidden_states,
    attention_mask,
    position_ids,
    ttt_length,
):
    """
    Run the decoder layer `ttt_length` times, feeding each output back in.

    The same loop drives both the reference run and the distributed run so
    that the two only differ in the decoder they are given.

    Returns the hidden states produced by the last step, or None when no step
    is executed.
    """
    # Work on copies so the caller's tensors are left untouched.
    token_ids = input_ids.clone()
    hidden = hidden_states.clone()
    # Cache container handed to the layer on every step.
    hidden_cache = [[], []]
    last_output = None
    for step in range(ttt_length):
        # Shift the token ids between consecutive steps (never after the last).
        if step > 0:
            token_ids = padding(token_ids, left=False)
        # Embed the current ids in the dtype of the running hidden states.
        step_embeds = embed_tokens(token_ids).to(hidden.dtype)
        last_output = decoder_layer(
            input_emb=step_embeds,
            hidden_states=hidden,
            cache_hidden=hidden_cache,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            output_attentions=False,
            use_cache=False,
        )
        # The layer output becomes the hidden-state input of the next step.
        hidden = last_output
    return last_output
def run_test_case(rank, world_size, port):
    """Worker function executed in each process.

    Runs a reference forward pass with the "fa" attention backend, then
    re-runs the same weights with the "usp" backend under two sequence-parallel
    layouts and asserts that each rank's output matches its slice of the
    reference output.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Number of spawned processes.
        port: Rendezvous port forwarded to ``setup_env``.

    Raises:
        AssertionError: If a shape check, a flag-file wait, or the output
            comparison fails.
    """
    setup_env(rank, world_size, port)
    device = torch.device(f"cuda:{rank}")
    # Every rank seeds identically before creating data and weights.
    set_seed(42)
    dbg(rank, "env setup complete")
    # --- Data & Config Preparation ---
    config = get_model_config()
    seq_len = 1560
    batch_size = 1
    ttt_length = 3
    # Generate dummy data on GPU
    data_input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device)
    data_hidden_states = torch.randn(
        batch_size, seq_len, config.hidden_size, device=device, dtype=torch.bfloat16
    )
    # Lower-triangular (causal) mask viewed as (1, 1, seq_len, seq_len).
    attention_mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).view(
        1, 1, seq_len, seq_len
    )
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
    # Shared embedding layer
    embed_tokens = nn.Embedding(
        config.vocab_size, config.hidden_size, config.pad_token_id
    ).to(device)
    # --- Phase 1: Golden Run (FA) ---
    # Init dist briefly for internal checks, even if running single-device logic
    init_distributed(tp_size=1, sp_ulysses_size=1, sp_ring_size=1)
    dbg(rank, "init_distributed (FA) done")
    # NOTE: the "sdpa_" prefix is historical; this decoder uses the "fa" backend.
    sdpa_decoder = (
        LlamaDecoderLayer(config, attention_backend="fa").to(device).to(torch.bfloat16)
    )
    dbg(rank, "FA decoder created")
    # Adapter smoke test for FA/SDPA-style path
    # The adapter is handed an empty placeholder object instead of a real model.
    dummy_model = type("Dummy", (), {})()
    sdpa_adapter = SdpaLikeAdapter(dummy_model)
    sdpa_target_p = torch.zeros((1, seq_len, 8), device=device, dtype=torch.float32)
    sdpa_position_mask = torch.ones((1, seq_len, 1), device=device, dtype=torch.float32)
    sdpa_state = sdpa_adapter.step_view(
        idx=0,
        ttt_length=ttt_length,
        global_input_ids=data_input_ids,
        attention_mask=attention_mask,
        loss_mask=torch.ones((1, seq_len, 1), device=device, dtype=torch.float32),
        position_ids=position_ids,
        hidden_states=data_hidden_states,
        target_p_padded=sdpa_target_p,
        position_mask=sdpa_position_mask,
        seq_length=seq_len,
    )
    # The step view of this adapter is expected to keep the full sequence length.
    assert sdpa_state.input_ids.shape[1] == seq_len
    assert sdpa_state.hidden_states.shape[1] == seq_len
    with torch.no_grad():
        sdpa_output = run_iterative_pass(
            decoder_layer=sdpa_decoder,
            embed_tokens=embed_tokens,
            input_ids=data_input_ids,
            hidden_states=data_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            ttt_length=ttt_length,
        )
    dbg(rank, "FA forward done")
    # Save weights for alignment and cleanup SDPA model
    state_dict = sdpa_decoder.state_dict()
    del sdpa_decoder
    destroy_distributed()
    dbg(rank, "destroy_distributed (FA) done")

    # --- Phase 2: Distributed Run (USP) ---
    def subtest_usp(sp_ulysses_degree, sp_ring_degree):
        """Run USP with specific topology and compare against Golden.

        Args:
            sp_ulysses_degree: Value passed as ``sp_ulysses_size``.
            sp_ring_degree: Value passed as ``sp_ring_size``.
        """
        try:
            init_distributed(
                tp_size=1,
                sp_ulysses_size=sp_ulysses_degree,
                sp_ring_size=sp_ring_degree,
            )
            dbg(
                rank,
                f"init_distributed (USP U{sp_ulysses_degree} R{sp_ring_degree}) done",
            )
            # Dataset + adapter smoke test (USP path)
            # Ranks coordinate through flag files in this shared directory:
            # rank 0 writes the sample and "ready.flag"; other ranks wait for it.
            # NOTE(review): the directory is fixed, so flag files left behind by
            # an aborted run would satisfy these waits immediately - confirm
            # whether that matters for this test.
            tmp_dir = "./tmp/usp_dataset_shared"
            try:
                if rank == 0:
                    os.makedirs(tmp_dir, exist_ok=True)
                    sample = {
                        "input_ids": data_input_ids[0].cpu(),
                        "loss_mask": torch.ones_like(data_input_ids[0].cpu()),
                        "hidden_state": data_hidden_states[0].cpu().unsqueeze(0),
                        "aux_hidden_state": data_hidden_states[0].cpu().unsqueeze(0),
                    }
                    torch.save(sample, os.path.join(tmp_dir, "data_0.ckpt"))
                    dbg(rank, "wrote sample ckpt")
                    # The flag is written only after the sample has been saved.
                    ready_flag = os.path.join(tmp_dir, "ready.flag")
                    with open(ready_flag, "w", encoding="utf-8") as f:
                        f.write("ready\n")
                if rank != 0:
                    ready_flag = os.path.join(tmp_dir, "ready.flag")
                    assert wait_for_file(
                        ready_flag, timeout_s=60
                    ), "timeout waiting for ready flag"
                dbg(rank, "dataset sync done")
                assert os.path.exists(
                    os.path.join(tmp_dir, "data_0.ckpt")
                ), f"Expected sample not found at {tmp_dir}"
                dbg(rank, "sample exists")
                ds = build_offline_eagle3_dataset(
                    tmp_dir,
                    max_len=seq_len,
                    ttt_length=ttt_length,
                    use_usp_preprocess=True,
                )
                dbg(rank, "dataset built")
                item = ds[0]
                dbg(rank, "dataset item loaded")
                assert "position_ids" in item
                dummy_model = type("Dummy", (), {})()
                adapter = UspAdapter(dummy_model)
                # Sequence length as seen in the item returned by the dataset.
                local_seq_len = item["input_ids"].shape[1]
                target_p_padded = torch.zeros(
                    (1, local_seq_len, 8), device=device, dtype=torch.float32
                )
                position_mask = torch.ones(
                    (1, local_seq_len, 1), device=device, dtype=torch.float32
                )
                state = adapter.step_view(
                    idx=0,
                    ttt_length=ttt_length,
                    global_input_ids=item["input_ids"].to(device),
                    attention_mask=item["attention_mask"].to(device),
                    loss_mask=item["loss_mask"].to(device).unsqueeze(-1),
                    position_ids=item["position_ids"].to(device),
                    hidden_states=item["hidden_state"].to(device),
                    target_p_padded=target_p_padded,
                    position_mask=position_mask,
                    seq_length=local_seq_len,
                )
                # The USP step view is expected to be ttt_length tokens shorter.
                assert state.input_ids.shape[1] == local_seq_len - ttt_length
                assert state.hidden_states.shape[1] == local_seq_len - ttt_length
                dbg(rank, "adapter step_view ok")
            finally:
                # Teardown handshake: non-zero ranks write "done.flag"; rank 0
                # waits for it before deleting the shared directory.
                if rank == 0:
                    done_flag = os.path.join(tmp_dir, "done.flag")
                    assert wait_for_file(
                        done_flag, timeout_s=60
                    ), "timeout waiting for done flag"
                    # Best-effort cleanup: filesystem errors are ignored.
                    try:
                        for root, _, files in os.walk(tmp_dir):
                            for name in files:
                                os.remove(os.path.join(root, name))
                        os.rmdir(tmp_dir)
                    except OSError:
                        pass
                else:
                    done_flag = os.path.join(tmp_dir, "done.flag")
                    with open(done_flag, "w", encoding="utf-8") as f:
                        f.write("done\n")
            # Init USP model and load golden weights
            usp_decoder = (
                LlamaDecoderLayer(config, attention_backend="usp")
                .to(device)
                .to(torch.bfloat16)
            )
            usp_decoder.load_state_dict(state_dict)
            dbg(rank, "USP decoder loaded")
            # Shard data (Split Input)
            extract_func = EXTRACT_FUNC_DICT["basic"]
            local_input_ids = (
                extract_func(
                    data_input_ids,
                    rank,
                    world_size=world_size,
                    rd=sp_ring_degree,
                    ud=sp_ulysses_degree,
                )
                .detach()
                .clone()
            )
            local_hidden_states = (
                extract_func(
                    data_hidden_states,
                    rank,
                    world_size=world_size,
                    rd=sp_ring_degree,
                    ud=sp_ulysses_degree,
                )
                .detach()
                .clone()
            )
            dbg(rank, "USP local inputs prepared")
            # Locate this rank's slice of the golden output: the sequence is
            # divided into total_degree equal chunks.
            total_degree = sp_ring_degree * sp_ulysses_degree
            chunk_size = sdpa_output.shape[1] // total_degree
            start_idx = (rank % total_degree) * chunk_size
            local_len = local_input_ids.shape[1]
            local_position_ids = (
                torch.arange(start_idx, start_idx + local_len, device=device)
                .unsqueeze(0)
                .long()
            )
            local_attention_mask = torch.tril(
                torch.ones(local_len, local_len, device=device)
            ).view(1, 1, local_len, local_len)
            # Run USP forward
            # With ring degree > 1 the layer receives the local mask/positions;
            # otherwise it receives the full-sequence ones.
            if sp_ring_degree > 1:
                usp_attention_mask = local_attention_mask
                usp_position_ids = local_position_ids
            else:
                usp_attention_mask = attention_mask
                usp_position_ids = position_ids
            with torch.no_grad():
                usp_output = run_iterative_pass(
                    decoder_layer=usp_decoder,
                    embed_tokens=embed_tokens,
                    input_ids=local_input_ids,
                    hidden_states=local_hidden_states,
                    attention_mask=usp_attention_mask,
                    position_ids=usp_position_ids,
                    ttt_length=ttt_length,
                )
            dbg(rank, "USP forward done")
            # Verify results
            # Slice the golden output to match the current rank's chunk
            end_idx = start_idx + chunk_size
            golden_chunk = sdpa_output[:, start_idx:end_idx, :]
            assert torch.allclose(usp_output, golden_chunk, rtol=2e-2, atol=2e-2), (
                f"[Rank {rank}] USP (U{sp_ulysses_degree}R{sp_ring_degree}) mismatch!\n"
                f"Max Diff: {(usp_output - golden_chunk).abs().max().item()}"
            )
            dbg(rank, "USP output verified")
        finally:
            # Always tear the process groups down so the next case can re-init.
            destroy_distributed()
            dbg(rank, "destroy_distributed (USP) done")

    # Case 1: Hybrid (Ulysses=2, Ring=1)
    subtest_usp(sp_ulysses_degree=2, sp_ring_degree=1)
    # Case 2: Hybrid (Ulysses=1, Ring=2)
    subtest_usp(sp_ulysses_degree=1, sp_ring_degree=2)
class TestTTTDistributed(unittest.TestCase):
    """Launches one worker process per rank and runs `run_test_case` in each."""

    def test_llama_usp_decoder(self):
        # Two ranks share a single rendezvous port chosen up front.
        num_ranks = 2
        rendezvous_port = get_available_port()
        worker_args = (num_ranks, rendezvous_port)
        mp.spawn(run_test_case, args=worker_args, nprocs=num_ranks)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()