| import os |
| import unittest |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| from accelerate.utils import set_seed |
|
|
| from specforge.distributed import init_distributed |
| from specforge.layers import VocabParallelEmbedding |
| from tests.utils import get_available_port |
|
|
|
|
| def run_embedding(rank, world_size, port): |
| os.environ["RANK"] = str(rank) |
| os.environ["WORLD_SIZE"] = str(world_size) |
| os.environ["MASTER_ADDR"] = "localhost" |
| os.environ["MASTER_PORT"] = str(port) |
| init_distributed(tp_size=world_size) |
| set_seed(42) |
|
|
| |
| |
| |
| |
| data = torch.randint(0, 512, (1, 128)).long().cuda() |
| native_embedding = torch.nn.Embedding(512, 256, padding_idx=314).cuda() |
| sf_embedding = VocabParallelEmbedding(512, 256, padding_idx=314).cuda() |
| sf_embedding.load_state_dict(native_embedding.state_dict()) |
|
|
| |
| native_output = native_embedding(data) |
| sf_output = sf_embedding(data) |
|
|
| |
| assert torch.allclose( |
| native_output, sf_output, rtol=1e-5, atol=1e-5 |
| ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}" |
|
|
| |
| |
| |
| |
| data = torch.randint(0, 355, (1, 128)).long().cuda() |
| native_embedding = torch.nn.Embedding(355, 256, padding_idx=314).cuda() |
| sf_embedding = VocabParallelEmbedding(355, 256, padding_idx=314).cuda() |
| sf_embedding.load_state_dict(native_embedding.state_dict()) |
|
|
| |
| native_output = native_embedding(data) |
| sf_output = sf_embedding(data) |
|
|
| |
| assert torch.allclose( |
| native_output, sf_output, rtol=1e-5, atol=1e-5 |
| ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}" |
|
|
| dist.destroy_process_group() |
|
|
|
|
| class TestEmbedding(unittest.TestCase): |
|
|
| def test_embedding(self): |
| port = get_available_port() |
| mp.spawn(run_embedding, nprocs=2, args=(2, port)) |
|
|
| port = get_available_port() |
| mp.spawn(run_embedding, nprocs=1, args=(1, port)) |
|
|
|
|
| if __name__ == "__main__": |
| suite = unittest.TestSuite() |
| suite.addTest(unittest.makeSuite(TestEmbedding)) |
| runner = unittest.TextTestRunner(verbosity=2) |
| runner.run(suite) |
|
|