import glob
import json
import os
from typing import Optional

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig

from specforge.utils import padding


class TargetHead(nn.Module):
    """Frozen copy of the target model's LM head, projecting hidden states to logits."""

    def __init__(self, model_path, trust_remote_code: bool = False):
        super().__init__()
        self.config = AutoConfig.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )
        self.fc = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        lm_head_key: str = "lm_head.weight",
        cache_dir: Optional[str] = None,
        trust_remote_code: bool = False,
    ) -> "TargetHead":
        target_head = cls(model_path, trust_remote_code=trust_remote_code)
        target_head.load_weights(
            model_path=model_path,
            lm_head_key=lm_head_key,
            cache_dir=cache_dir,
        )
        target_head.freeze_weights()
        target_head = target_head.eval().cuda().to(torch.bfloat16)
        return target_head

    @torch.no_grad()
    def load_weights(
        self,
        model_path,
        lm_head_key: str = "lm_head.weight",
        cache_dir: Optional[str] = None,
    ):
        if os.path.exists(model_path):
            self.model_path = model_path
        else:
            # model_path is a Hub repo id: download a local snapshot
            self.model_path = snapshot_download(repo_id=model_path, cache_dir=cache_dir)

        # self.model_path is now a local directory; look for exactly one
        # *.index.json file mapping weight names to checkpoint shards
        glob_path = os.path.join(self.model_path, "*.index.json")
        index_json_path = glob.glob(glob_path)
        if len(index_json_path) == 0:
            raise FileNotFoundError(f"No index.json file found in {self.model_path}")
        if len(index_json_path) > 1:
            raise FileNotFoundError(
                f"Multiple index.json files found in {self.model_path}"
            )
        index_json_path = index_json_path[0]

        with open(index_json_path, "r") as f:
            index_json = json.load(f)
        ckpt_file = index_json["weight_map"][lm_head_key]

        # load only the shard that holds the LM head weight
        if ckpt_file.endswith(".safetensors"):
            with safe_open(
                os.path.join(self.model_path, ckpt_file), framework="pt"
            ) as f:
                lm_head = f.get_tensor(lm_head_key)
        else:
            state_dict = torch.load(
                os.path.join(self.model_path, ckpt_file), map_location="cpu"
            )
            lm_head = state_dict[lm_head_key]
        self.fc.weight.copy_(lm_head)
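
    # For reference, load_weights expects a sharded-checkpoint index of the form
    # below (shard names are illustrative, not taken from any specific model):
    #
    #   {"weight_map": {"lm_head.weight": "model-00004-of-00004.safetensors", ...}}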

    def freeze_weights(self):
        for param in self.fc.parameters():
            param.requires_grad = False

    def forward(self, hidden_states):
        return self.fc(hidden_states)

    def preprocess(self, input_ids, target, loss_mask):
        # apply padding to targets and input_ids; add a trailing dim to the
        # loss mask so it broadcasts over the hidden/vocab dimension
        target = padding(target, left=False)
        input_ids = padding(input_ids, left=False)
        loss_mask = loss_mask[..., None]
        return input_ids, target, loss_mask
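

# Usage sketch (illustrative, not part of the module): the repo id and tensor
# shapes below are assumptions for demonstration; requires a CUDA device.
if __name__ == "__main__":
    head = TargetHead.from_pretrained("meta-llama/Llama-3.1-8B")  # hypothetical checkpoint
    hidden_states = torch.randn(
        1, 8, head.config.hidden_size, dtype=torch.bfloat16, device="cuda"
    )
    logits = head(hidden_states)  # shape: (1, 8, vocab_size)
    print(logits.shape)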