Hanrui / progress /github /SpecForge /tests /test_layers /test_embedding.py
Lekr0's picture
Add files using upload-large-folder tool
62dca4c verified
import os
import unittest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from accelerate.utils import set_seed
from specforge.distributed import init_distributed
from specforge.layers import VocabParallelEmbedding
from tests.utils import get_available_port
def run_embedding(rank, world_size, port):
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
init_distributed(tp_size=world_size)
set_seed(42)
# ===============================
# Case 1: vocab size is divisible by the TP size
# ===============================
# create layers
data = torch.randint(0, 512, (1, 128)).long().cuda()
native_embedding = torch.nn.Embedding(512, 256, padding_idx=314).cuda()
sf_embedding = VocabParallelEmbedding(512, 256, padding_idx=314).cuda()
sf_embedding.load_state_dict(native_embedding.state_dict())
# forward
native_output = native_embedding(data)
sf_output = sf_embedding(data)
# check
assert torch.allclose(
native_output, sf_output, rtol=1e-5, atol=1e-5
), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}"
# ===============================
# Case 2: vocab size is NOT divisible by the TP size
# ===============================
# create layers
data = torch.randint(0, 355, (1, 128)).long().cuda()
native_embedding = torch.nn.Embedding(355, 256, padding_idx=314).cuda()
sf_embedding = VocabParallelEmbedding(355, 256, padding_idx=314).cuda()
sf_embedding.load_state_dict(native_embedding.state_dict())
# forward
native_output = native_embedding(data)
sf_output = sf_embedding(data)
# check
assert torch.allclose(
native_output, sf_output, rtol=1e-5, atol=1e-5
), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}"
dist.destroy_process_group()
class TestEmbedding(unittest.TestCase):
def test_embedding(self):
port = get_available_port()
mp.spawn(run_embedding, nprocs=2, args=(2, port))
port = get_available_port()
mp.spawn(run_embedding, nprocs=1, args=(1, port))
if __name__ == "__main__":
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestEmbedding))
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite)