import math
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from specforge.distributed import get_tp_group, shard_tensor


class ParallelLMHead(nn.Module):
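    """Tensor-parallel (vocab-parallel) LM head.

    The [vocab, hidden] projection is sharded row-wise across the tensor-parallel
    group, so each rank holds ceil(vocab / tp_size) rows and produces logits only
    for its slice of the vocabulary unless ``gather_output=True``.

    Minimal usage sketch (assumes torch.distributed and the specforge TP group are
    already initialised; ``hidden_states``, ``hidden_size`` and ``vocab_size`` are
    illustrative names, not part of this module):

        head = ParallelLMHead(hidden_size, vocab_size, bias=False)
        local_logits = head(hidden_states)                      # [*, vocab_per_shard]
        full_logits = head(hidden_states, gather_output=True)   # [*, vocab]
    """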

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_features = in_features
        self.out_features = out_features
        self.tp_group = get_tp_group()
        self.tp_size = dist.get_world_size(self.tp_group)
        self.tp_rank = dist.get_rank(self.tp_group)

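        # split the vocabulary with ceil-division across TP ranks;
        # padded_out_features is the number of padding rows that make it divide evenly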
        self.out_features_per_shard = math.ceil(out_features / self.tp_size)
        self.padded_out_features = (
            self.out_features_per_shard * self.tp_size - out_features
        )
        assert (
            self.out_features_per_shard * self.tp_size
            == out_features + self.padded_out_features
        )

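        # each rank materialises only its [vocab_per_shard, hidden] slice of the projection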
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_shard, self.in_features, **factory_kwargs)
        )
        self.bias = (
            nn.Parameter(torch.zeros(self.out_features_per_shard, **factory_kwargs))
            if bias
            else None
        )

        self.reset_parameters()

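        # full (unsharded) checkpoints are padded and split per rank when loaded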
        self._register_load_state_dict_pre_hook(self.shard_state_dict)

    def shard_state_dict(self, state_dict, prefix="", *args):
        # runs as a load_state_dict pre-hook, so keys carry the module prefix
        weight_key = f"{prefix}weight"
        if weight_key in state_dict:
            value = state_dict[weight_key]

            # pad the vocab dim to a multiple of tp_size, then keep this rank's shard
            if value.shape[0] % self.tp_size != 0:
                padding_size = self.tp_size - value.shape[0] % self.tp_size
                value = F.pad(value, (0, 0, 0, padding_size))
            state_dict[weight_key] = shard_tensor(value, self.tp_group, 0)

        bias_key = f"{prefix}bias"
        if bias_key in state_dict:
            value = state_dict[bias_key]

            if value.shape[0] % self.tp_size != 0:
                padding_size = self.tp_size - value.shape[0] % self.tp_size
                value = F.pad(value, (0, padding_size))
            state_dict[bias_key] = shard_tensor(value, self.tp_group, 0)

    def forward(self, hidden: torch.Tensor, gather_output: bool = False):
        """
        hidden: [B, T, H] or [N, H]
        returns:
            - if gather_output=False: this rank's local logits [*, vocab_per_shard]
            - if gather_output=True: full logits [*, vocab] via all-gather (use for inference)
        """
        orig_shape = hidden.shape
        hidden = hidden.reshape(-1, self.in_features)

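        # [N, H] @ [H, vocab_per_shard] -> logits for this rank's vocab slice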
        local_logits = hidden @ self.weight.T

        if self.bias is not None:
            local_logits = local_logits + self.bias

        if not gather_output or self.tp_size == 1:
            return local_logits.view(
                *orig_shape[:-1], self.out_features_per_shard
            ).contiguous()
        else:
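            # all-gather every rank's shard, concatenate along the vocab dim,
            # and drop the padding rows to recover the full vocabulary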
            chunks = [torch.empty_like(local_logits) for _ in range(self.tp_size)]
            dist.all_gather(chunks, local_logits, group=self.tp_group)
            full = torch.cat(chunks, dim=-1)[:, : self.out_features]
            return full.view(*orig_shape[:-1], self.out_features).contiguous()

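    # note: xavier_normal_ below is applied to the local shard, so its fan-out is
    # vocab_per_shard rather than the full vocab size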
    def reset_parameters(self):
        nn.init.xavier_normal_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def __repr__(self):
        return (
            f"ParallelLMHead(in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"out_features_per_shard={self.out_features_per_shard}, "
            f"tp_size={self.tp_size}, tp_rank={self.tp_rank})"
        )