import unittest

import torch
import torch._dynamo as dynamo
from transformers import LlamaConfig
from transformers.cache_utils import DynamicCache

from specforge.modeling.draft.flex_attention import (
    compile_friendly_create_block_mask,
    compile_friendly_flex_attention,
    generate_eagle3_mask,
)
from specforge.modeling.draft.llama3_eagle import (
    LlamaAttention,
    LlamaFlexAttention,
    prepare_decoder_attention_mask,
)
from specforge.utils import padding

from .utils import norm_tensor

dynamo.config.recompile_limit = 64
TTT_LENGTH = 7
torch.manual_seed(0)


class TestFlexAttention(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(0)
        self.config_dict = {
            "hidden_size": 128,
            "num_attention_heads": 8,
            "num_key_value_heads": 2,
            "max_position_embeddings": 4096,
            "rms_norm_eps": 1e-05,
            "vocab_size": 32000,
            "intermediate_size": 688,
            "hidden_act": "silu",
            "num_hidden_layers": 1,
            "torch_dtype": "float32",
        }
        self.config = LlamaConfig(**self.config_dict)
        self.seq_lengths = [128, 200, 256, 300, 512, 800, 1024, 2048]
        self.dtype = torch.float32

    def test_forward_pass_comparison(self):
        """Compare forward passes of LlamaAttention and LlamaFlexAttention."""
        for seq_len in self.seq_lengths:
            with self.subTest(seq_len=seq_len):
                self._test_forward_pass_comparison_for_seq_len(seq_len)

    def _test_forward_pass_comparison_for_seq_len(self, seq_len):
        """Compare forward passes for a specific sequence length."""
        attention = LlamaAttention(self.config).to("cuda").to(self.dtype)
        flex_attention = LlamaFlexAttention(self.config).to("cuda").to(self.dtype)

        # Copy the projection weights so both modules compute with identical parameters.
        with torch.no_grad():
            flex_attention.q_proj.weight.copy_(attention.q_proj.weight)
            flex_attention.k_proj.weight.copy_(attention.k_proj.weight)
            flex_attention.v_proj.weight.copy_(attention.v_proj.weight)
            flex_attention.o_proj.weight.copy_(attention.o_proj.weight)

        attention.eval()
        flex_attention.eval()

        batch_size = 2
        hidden_size = self.config.hidden_size * 2

        ############### Attention Inputs ##############
        position_ids = (
            torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to("cuda")
        )
        cache_hidden = [[], []]  # [cache_k, cache_v]
        attention_mask = torch.ones(batch_size, seq_len, dtype=self.dtype).to("cuda")
        # Simulate one batch item being masked (right-padded) so that it does
        # not fill a full block.
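        # (The masked tail below starts at an index that is generally not a
        # multiple of the flex-attention block size, which is precisely the
        # partial-block case the block mask has to handle.)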
        padding_start_index = seq_len - min(
            200, seq_len // 3
        )  # Scale the padded region with seq_len so short sequences keep valid tokens.
        attention_mask[1, padding_start_index:] = False
        input_embeds = norm_tensor(
            (batch_size, seq_len, self.config.hidden_size),
            device="cuda",
            dtype=self.dtype,
        )
        decoder_attention_mask = prepare_decoder_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_len),
            inputs_embeds=input_embeds,
            past_key_values_length=0,
        )
        hidden_states_list = []
        flex_hidden_states_list = []
        for idx in range(TTT_LENGTH):
            hidden_states = norm_tensor(
                (batch_size, seq_len, hidden_size), device="cuda", dtype=self.dtype
            )
            flex_hidden_states = hidden_states.clone().detach()
            hidden_states_list.append(hidden_states)
            flex_hidden_states_list.append(flex_hidden_states)

        ############### Flex Attention Inputs ##############
        flex_position_ids = position_ids.clone()
        past_key_values = DynamicCache()

        for idx in range(TTT_LENGTH):
            with torch.no_grad():
                output = attention(
                    hidden_states=hidden_states_list[idx],
                    attention_mask=decoder_attention_mask,
                    position_ids=position_ids,
                    cache_hidden=cache_hidden,
                    output_attentions=False,
                    use_cache=True,
                )
            with torch.no_grad():
                output_flex = flex_attention(
                    hidden_states=flex_hidden_states_list[idx],
                    attention_mask=attention_mask,
                    position_ids=flex_position_ids,
                    past_key_values=past_key_values,
                )
            # Compare only the positions that are still valid at this TTT step:
            # one more trailing position is excluded per step, and the padded
            # positions of batch item 1 are skipped.
            torch.testing.assert_close(
                output[0][: -1 - idx], output_flex[0][: -1 - idx], atol=1e-2, rtol=1e-2
            )
            torch.testing.assert_close(
                output[1][: padding_start_index - idx],
                output_flex[1][: padding_start_index - idx],
                atol=1e-2,
                rtol=1e-2,
            )

            # Check the output shape.
            expected_output_shape = (batch_size, seq_len, self.config.hidden_size)
            self.assertEqual(output_flex.shape, expected_output_shape)

            # Check that the output contains no NaN or Inf values.
            self.assertFalse(torch.isnan(output_flex).any())
            self.assertFalse(torch.isinf(output_flex).any())

    def test_backward_pass_gradient_comparison(self):
        """Compare backward-pass gradients of LlamaAttention and LlamaFlexAttention."""
        for seq_len in self.seq_lengths:
            with self.subTest(seq_len=seq_len):
                self._test_backward_pass_gradient_comparison_for_seq_len(seq_len)

    def _test_backward_pass_gradient_comparison_for_seq_len(self, seq_len):
        """Compare backward-pass gradients for a specific sequence length."""
        attention = LlamaAttention(self.config).to("cuda").to(self.dtype)
        flex_attention = LlamaFlexAttention(self.config).to("cuda").to(self.dtype)

        # Copy the projection weights so both modules compute with identical parameters.
        with torch.no_grad():
            flex_attention.q_proj.weight.copy_(attention.q_proj.weight)
            flex_attention.k_proj.weight.copy_(attention.k_proj.weight)
            flex_attention.v_proj.weight.copy_(attention.v_proj.weight)
            flex_attention.o_proj.weight.copy_(attention.o_proj.weight)

        batch_size = 2
        hidden_size = self.config.hidden_size * 2

        ############### Attention Inputs ##############
        position_ids = (
            torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1).to("cuda")
        )
        cache_hidden = [[], []]  # [cache_k, cache_v]
        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool).to("cuda")
        # Simulate one batch item being masked (right-padded) and not filling a
        # full block.
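        # NOTE: the padding simulation below is left commented out here, so the
        # gradient comparison runs with a full (all-ones) attention mask.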
        # padding_start_index = seq_len - 50
        # attention_mask[1, padding_start_index:] = False
        input_embeds = norm_tensor(
            (batch_size, seq_len, self.config.hidden_size),
            device="cuda",
            dtype=self.dtype,
        )
        decoder_attention_mask = prepare_decoder_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_len),
            inputs_embeds=input_embeds,
            past_key_values_length=0,
        )

        ############### Flex Attention Inputs ##############
        flex_position_ids = position_ids.clone()
        past_key_values = DynamicCache()
        loss_mask = torch.ones(
            batch_size, seq_len, dtype=self.dtype, requires_grad=False
        ).to("cuda")

        # Pre-generate identical hidden-state inputs for both implementations.
        loss_list = []
        loss_flex_list = []
        hidden_states_list = []
        flex_hidden_states_list = []
        for idx in range(TTT_LENGTH):
            hidden_states = norm_tensor(
                (batch_size, seq_len, hidden_size), device="cuda", dtype=self.dtype
            )
            flex_hidden_states = hidden_states.clone().detach()
            hidden_states_list.append(hidden_states)
            flex_hidden_states_list.append(flex_hidden_states)

        for idx in range(TTT_LENGTH):
            is_last = idx == TTT_LENGTH - 1
            output = attention(
                hidden_states=hidden_states_list[idx],
                attention_mask=decoder_attention_mask,
                position_ids=position_ids,
                cache_hidden=cache_hidden,
                output_attentions=False,
                use_cache=True,
            )
            output_flex = flex_attention(
                hidden_states=flex_hidden_states_list[idx],
                attention_mask=attention_mask,
                position_ids=flex_position_ids,
                past_key_values=past_key_values,
            )
            # Apply the loss mask before reducing over the batch.
            loss = (output * loss_mask[..., None]).sum().mean()
            loss_flex = (output_flex * loss_mask[..., None]).sum().mean()
            torch.testing.assert_close(loss, loss_flex, atol=1e-2, rtol=1e-2)
            loss_list.append(loss)
            loss_flex_list.append(loss_flex)

            if not is_last:
                # Update the loss mask for the next TTT step, mirroring the
                # training loop.
                loss_mask = padding(loss_mask, left=False)

        mean_loss = sum(loss_list) / len(loss_list)
        mean_loss_flex = sum(loss_flex_list) / len(loss_flex_list)
        mean_loss.backward()
        mean_loss_flex.backward()

        # The weight gradients of all four projections should match.
        projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
        for proj_name in projections:
            torch.testing.assert_close(
                getattr(attention, proj_name).weight.grad,
                getattr(flex_attention, proj_name).weight.grad,
                atol=1e-2,
                rtol=1e-2,
            )


class TestEagle3FlexMask(unittest.TestCase):

    def test_eagle3_flex_mask(self):
        B = 1
        H = 1
        S = 128 * 8
        D = 128
        Q_LEN = S
        KV_LEN = S * 3
        lck = 128 * 2
        data_type = torch.bfloat16
        query = norm_tensor((B, H, S, D), device="cuda", dtype=data_type)
        key_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type)
        value_cache = norm_tensor((B, H, KV_LEN, D), device="cuda", dtype=data_type)
        seq_lengths = torch.tensor([S], device="cuda", dtype=torch.int32)
        seq_lengths -= lck  # Exclude the lck prefix from the per-sample lengths.
        block_mask = compile_friendly_create_block_mask(
            mask_mod=generate_eagle3_mask(
                seq_lengths=seq_lengths, Q_LEN=Q_LEN, KV_LEN=KV_LEN, lck=lck
            ),
            B=1,
            H=1,
            Q_LEN=Q_LEN,
            KV_LEN=KV_LEN,
            device=query.device,
        )
        # Expected block-level mask (one entry per 128x128 block): a causal
        # lower triangle over the first group of 8 block-columns, a shifted
        # diagonal in each of the two remaining groups, and fully masked rows
        # once the (lck-reduced) sequence length is exceeded.
        # fmt: off
        expected_mask = torch.tensor([[[
            [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]]], dtype=torch.int32).to(query.device)
        # fmt: on
        dense_mask = block_mask.to_dense()
        assert torch.allclose(dense_mask, expected_mask)
        # Smoke test: the attention kernel itself should run under this mask.
        output = compile_friendly_flex_attention(
            query, key_cache, value_cache, block_mask=block_mask
        )


if __name__ == "__main__":
    unittest.main(verbosity=2)
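# NOTE: this module uses a relative import (`from .utils import norm_tensor`),
# so it cannot be run as a standalone script; run it as a module inside its
# package instead, e.g. (the module path here is illustrative, not from the
# source):
#   python -m unittest <tests_package>.test_flex_attention -v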