import os
import unittest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from accelerate.utils import set_seed
from specforge.distributed import gather_tensor, get_tp_group, init_distributed
from specforge.layers import ColumnParallelLinear, RowParallelLinear
from tests.utils import get_available_port
def run_column_parallel_linear(rank: int, world_size: int, port: int):
    """Spawned worker: check ColumnParallelLinear against torch.nn.Linear.

    Runs one rank of a tensor-parallel group of size ``world_size`` and
    verifies all three weight layouts ("normal", "merged_qkv", "gate_up").
    In each case the sharded output is gathered across the TP group and
    compared elementwise to the unsharded reference layer.

    NOTE(review): requires a CUDA device per rank; `gather_tensor` is
    assumed to concatenate shards along the output-feature dim — confirm
    against specforge.distributed.
    """
    # Rendezvous configuration for torch.distributed before init.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    init_distributed(tp_size=world_size)
    # Seed before any weight/data creation so every rank (and the reference
    # layer) draws from the same random stream.
    set_seed(42)
    # ===============================
    # Case 1: normal layout
    # ===============================
    # create data
    data = torch.rand(1, 256).cuda()
    # create layers
    native_linear = torch.nn.Linear(256, 512).cuda()
    sf_linear = ColumnParallelLinear(256, 512, layout_type="normal").cuda()
    # Load the full (unsharded) weights; the parallel layer is expected to
    # slice out its own shard.
    sf_linear.load_state_dict(native_linear.state_dict())
    # forward
    native_output = native_linear(data)
    sf_output = sf_linear(data)
    # Reassemble the full output from the per-rank shards.
    full_sf_output = gather_tensor(sf_output, get_tp_group())
    # check
    assert torch.allclose(
        native_output, full_sf_output, rtol=1e-5, atol=1e-5
    ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}"
    # ===============================
    # Case 2: merged QKV layout
    # ===============================
    # create data
    data = torch.rand(1, 256 * 3).cuda()
    # create layers
    native_linear = torch.nn.Linear(256 * 3, 512).cuda()
    sf_linear = ColumnParallelLinear(256 * 3, 512, layout_type="merged_qkv").cuda()
    sf_linear.load_state_dict(native_linear.state_dict())
    # forward: compare each of the Q/K/V thirds independently, since the
    # merged layout interleaves the shards per projection.
    q, k, v = native_linear(data).chunk(3, dim=1)
    sf_q, sf_k, sf_v = sf_linear(data).chunk(3, dim=1)
    full_sf_q = gather_tensor(sf_q, get_tp_group())
    full_sf_k = gather_tensor(sf_k, get_tp_group())
    full_sf_v = gather_tensor(sf_v, get_tp_group())
    # check
    assert torch.allclose(
        q, full_sf_q, rtol=1e-5, atol=1e-5
    ), f"q: \n{q}, \nfull_sf_q: \n{full_sf_q}"
    assert torch.allclose(
        k, full_sf_k, rtol=1e-5, atol=1e-5
    ), f"k: \n{k}, \nfull_sf_k: \n{full_sf_k}"
    assert torch.allclose(
        v, full_sf_v, rtol=1e-5, atol=1e-5
    ), f"v: \n{v}, \nfull_sf_v: \n{full_sf_v}"
    # ===============================
    # Case 3: gate_up layout
    # ===============================
    # create data
    data = torch.rand(1, 256 * 2).cuda()
    # create layers
    native_linear = torch.nn.Linear(256 * 2, 512).cuda()
    sf_linear = ColumnParallelLinear(256 * 2, 512, layout_type="gate_up").cuda()
    sf_linear.load_state_dict(native_linear.state_dict())
    # forward: gate/up halves checked separately, analogous to Case 2.
    gate, up = native_linear(data).chunk(2, dim=1)
    sf_gate, sf_up = sf_linear(data).chunk(2, dim=1)
    full_sf_gate = gather_tensor(sf_gate, get_tp_group())
    full_sf_up = gather_tensor(sf_up, get_tp_group())
    # check
    assert torch.allclose(
        gate, full_sf_gate, rtol=1e-5, atol=1e-5
    ), f"gate: \n{gate}, \nfull_sf_gate: \n{full_sf_gate}"
    assert torch.allclose(
        up, full_sf_up, rtol=1e-5, atol=1e-5
    ), f"up: \n{up}, \nfull_sf_up: \n{full_sf_up}"
    # Tear down the process group so the spawned worker exits cleanly.
    dist.destroy_process_group()
def run_row_parallel_linear(rank: int, world_size: int, port: int):
    """Spawned worker: check RowParallelLinear against torch.nn.Linear.

    Runs one rank of a tensor-parallel group of size ``world_size``.
    Each rank feeds its shard of the input features through its weight
    shard and the partial outputs are summed with an all-reduce,
    i.e. Y = AllReduce(X_i @ W_i).

    NOTE(review): requires a CUDA device per rank.
    """
    # Rendezvous configuration for torch.distributed before init.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    init_distributed(tp_size=world_size)
    # Seed before any weight/data creation so every rank (and the reference
    # layer) draws from the same random stream.
    set_seed(42)
    # ===============================
    # Case 1: normal layout
    # the data is a parallel input, e.g.
    # Y = AllReduce(X_i * W_i)
    # ===============================
    # create data
    data = torch.rand(1, 256).cuda()
    # create layers
    native_linear = torch.nn.Linear(256, 512).cuda()
    sf_linear = RowParallelLinear(256, 512, layout_type="normal").cuda()
    sf_linear.load_state_dict(native_linear.state_dict())
    # forward
    native_output = native_linear(data)
    # Row-parallel input is sharded along the feature dim (dim=1).
    # Chunking dim=0 (a batch of size 1) would yield a single chunk and
    # raise IndexError on every rank > 0.
    sf_output = sf_linear(data.chunk(world_size, dim=1)[rank])
    # Sum the per-rank partial products to recover the full output.
    dist.all_reduce(sf_output, op=dist.ReduceOp.SUM, group=get_tp_group())
    # check
    assert torch.allclose(
        native_output, sf_output, rtol=1e-5, atol=1e-5
    ), f"native_output: \n{native_output}, \nsf_output: \n{sf_output}"
    # Tear down the process group, mirroring run_column_parallel_linear.
    dist.destroy_process_group()
class TestLinear(unittest.TestCase):
    """Spawns multi-process workers to test the tensor-parallel linears."""

    def test_column_parallel_linear(self):
        # Two ranks -> tp_size=2, exercising real weight sharding.
        port = get_available_port()
        mp.spawn(run_column_parallel_linear, nprocs=2, args=(2, port))

    def test_row_parallel_linear(self):
        # Was a duplicate `def test_column_parallel_linear`, which shadowed
        # the 2-rank column test above and left run_row_parallel_linear
        # unused; renamed so both code paths actually run.
        port = get_available_port()
        mp.spawn(run_row_parallel_linear, nprocs=2, args=(2, port))
if __name__ == "__main__":
    # unittest.makeSuite was deprecated in 3.11 and removed in Python 3.13;
    # TestLoader.loadTestsFromTestCase is the supported equivalent.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestLinear)
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)