import os
import unittest

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from accelerate.utils import set_seed

from specforge.distributed import gather_tensor, get_tp_group, init_distributed
from specforge.layers import ColumnParallelLinear, RowParallelLinear
from tests.utils import get_available_port


def run_column_parallel_linear(rank, world_size, port):
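    """Compare ColumnParallelLinear against a plain torch.nn.Linear for each layout type."""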
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    init_distributed(tp_size=world_size)
    set_seed(42)
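
    # Layout "normal": the weight is sharded along the output dimension, so each
    # TP rank computes a slice of the output; gather_tensor reassembles the full
    # output for comparison against the dense layer.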
    data = torch.rand(1, 256).cuda()

    native_linear = torch.nn.Linear(256, 512).cuda()
    sf_linear = ColumnParallelLinear(256, 512, layout_type="normal").cuda()
    sf_linear.load_state_dict(native_linear.state_dict())

    native_output = native_linear(data)
    sf_output = sf_linear(data)
    full_sf_output = gather_tensor(sf_output, get_tp_group())

    assert torch.allclose(
        native_output, full_sf_output, rtol=1e-5, atol=1e-5
    ), f"native_output: \n{native_output}, \nfull_sf_output: \n{full_sf_output}"
    data = torch.rand(1, 256 * 3).cuda()

    # out_features must be divisible by 3 (and the TP size) so that Q, K, and V
    # split evenly on every rank; 512 * 3 keeps the per-projection width at 512
    native_linear = torch.nn.Linear(256 * 3, 512 * 3).cuda()
    sf_linear = ColumnParallelLinear(256 * 3, 512 * 3, layout_type="merged_qkv").cuda()
    sf_linear.load_state_dict(native_linear.state_dict())

    q, k, v = native_linear(data).chunk(3, dim=1)
    sf_q, sf_k, sf_v = sf_linear(data).chunk(3, dim=1)
    full_sf_q = gather_tensor(sf_q, get_tp_group())
    full_sf_k = gather_tensor(sf_k, get_tp_group())
    full_sf_v = gather_tensor(sf_v, get_tp_group())

    assert torch.allclose(
        q, full_sf_q, rtol=1e-5, atol=1e-5
    ), f"q: \n{q}, \nfull_sf_q: \n{full_sf_q}"
    assert torch.allclose(
        k, full_sf_k, rtol=1e-5, atol=1e-5
    ), f"k: \n{k}, \nfull_sf_k: \n{full_sf_k}"
    assert torch.allclose(
        v, full_sf_v, rtol=1e-5, atol=1e-5
    ), f"v: \n{v}, \nfull_sf_v: \n{full_sf_v}"
    data = torch.rand(1, 256 * 2).cuda()

    native_linear = torch.nn.Linear(256 * 2, 512).cuda()
    sf_linear = ColumnParallelLinear(256 * 2, 512, layout_type="gate_up").cuda()
    sf_linear.load_state_dict(native_linear.state_dict())

    gate, up = native_linear(data).chunk(2, dim=1)
    sf_gate, sf_up = sf_linear(data).chunk(2, dim=1)
    full_sf_gate = gather_tensor(sf_gate, get_tp_group())
    full_sf_up = gather_tensor(sf_up, get_tp_group())

    assert torch.allclose(
        gate, full_sf_gate, rtol=1e-5, atol=1e-5
    ), f"gate: \n{gate}, \nfull_sf_gate: \n{full_sf_gate}"
    assert torch.allclose(
        up, full_sf_up, rtol=1e-5, atol=1e-5
    ), f"up: \n{up}, \nfull_sf_up: \n{full_sf_up}"

    dist.destroy_process_group()


def run_row_parallel_linear(rank, world_size, port):
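    """Compare RowParallelLinear against a plain torch.nn.Linear."""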
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    init_distributed(tp_size=world_size)
    set_seed(42)
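
    # RowParallelLinear shards the weight along the input dimension: each rank
    # consumes its slice of the input features and produces a partial output,
    # and the partial outputs are summed with an all-reduce.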
    data = torch.rand(1, 256).cuda()

    native_linear = torch.nn.Linear(256, 512).cuda()
    sf_linear = RowParallelLinear(256, 512, layout_type="normal").cuda()
    sf_linear.load_state_dict(native_linear.state_dict())

    native_output = native_linear(data)
    # chunk along the feature dim (dim=1), not the batch dim, so each rank
    # receives its shard of the input features
    sf_output = sf_linear(data.chunk(world_size, dim=1)[rank])
    dist.all_reduce(sf_output, op=dist.ReduceOp.SUM, group=get_tp_group())

    assert torch.allclose(
        native_output, sf_output, rtol=1e-5, atol=1e-5
    ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}"

    dist.destroy_process_group()


class TestLinear(unittest.TestCase):
    def test_column_parallel_linear(self):
        port = get_available_port()
        mp.spawn(run_column_parallel_linear, nprocs=2, args=(2, port))

    def test_row_parallel_linear(self):
        port = get_available_port()
        mp.spawn(run_row_parallel_linear, nprocs=2, args=(2, port))


if __name__ == "__main__":
    suite = unittest.TestSuite()
    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestLinear))
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)