# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in HuggingFace Transformers.
# Portions of this code are adapted from:
# - https://github.com/EleutherAI/gpt-neox (Apache License 2.0)
# - https://github.com/huggingface/transformers (Apache License 2.0)
# - https://github.com/SafeAILab/EAGLE (Apache License 2.0)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import json
import os
from abc import ABC, abstractmethod
from typing import Optional

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers.cache_utils import Cache
from transformers.modeling_utils import PreTrainedModel

from specforge.modeling._mask_utils import _expand_mask, _make_causal_mask


class Eagle3DraftModel(PreTrainedModel, ABC):
    """
    This is the base class for the Eagle3 draft model implementation. The child
    class needs to implement the abstract methods to support training with TTT.
    """

    @abstractmethod
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Embed the input ids.
        """

    @abstractmethod
    def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Project the concatenated hidden states from the high, medium and low
        layers to the target hidden size.
        """

    @abstractmethod
    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute the logits of the draft model.
        """

    def prepare_decoder_attention_mask(
        self,
        attention_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        batch_size: int,
        seq_length: int,
        past_key_values_length: int,
    ) -> torch.Tensor:
        """
        Prepare the attention mask of the draft model.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if seq_length > 1:
            combined_attention_mask = _make_causal_mask(
                (batch_size, seq_length),
                hidden_states.dtype,
                device=hidden_states.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, hidden_states.dtype, tgt_len=seq_length
            ).to(hidden_states.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @abstractmethod
    def backbone(
        self,
        input_embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_hidden: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = True,
    ) -> torch.Tensor:
        """
        The backbone of the draft model.
        """

    def freeze_embedding(self) -> None:
        """
        Freeze the embeddings of the draft model so that they are not updated
        during training.
        """
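        # The embedding weights are typically copied from the target model via
        # `load_embedding`, so they are kept fixed while the draft layers train.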
""" self.embed_tokens.weight.requires_grad = False @torch.no_grad() def load_embedding( self, model_path: str, embedding_key: str = "model.embed_tokens.weight" ) -> None: """ Load the embedding of the draft model. Args: model_path (str): Path to the target model. Can be either a Hugging Face repository ID or a local directory path containing the model files. """ if os.path.exists(model_path): # model_path is a local directory # check if there is file ending with index.json glob_path = os.path.join(model_path, "*.index.json") index_json_path = glob.glob(glob_path) if len(index_json_path) == 0: # No index.json found, look for single model file safetensors_path = os.path.join(model_path, "model.safetensors") if os.path.exists(safetensors_path): with safe_open(safetensors_path, framework="pt") as f: self.embed_tokens.weight.copy_(f.get_tensor(embedding_key)) return pytorch_model_path = os.path.join(model_path, "pytorch_model.bin") if os.path.exists(pytorch_model_path): state_dict = torch.load(pytorch_model_path, map_location="cpu") self.embed_tokens.weight.copy_(state_dict[embedding_key]) return raise FileNotFoundError( f"No index.json, model.safetensors or pytorch_model.bin found in {model_path}" ) if len(index_json_path) > 1: raise FileNotFoundError( f"Multiple index.json files found in {model_path}" ) index_json_path = index_json_path[0] with open(index_json_path, "r") as f: index_json = json.load(f) ckpt_file = index_json["weight_map"][embedding_key] if ckpt_file.endswith(".safetensors"): with safe_open( os.path.join(model_path, ckpt_file), framework="pt" ) as f: emb_tokens = f.get_tensor(embedding_key) else: state_dict = torch.load(os.path.join(model_path, ckpt_file)) emb_tokens = state_dict[embedding_key] self.embed_tokens.weight.copy_(emb_tokens) else: # this is the case where model_path is a huggingface repository # we first need to locate its local cache local_cache_path = snapshot_download(repo_id=model_path) self.load_embedding(local_cache_path, embedding_key) def load_vocab_mapping(self, file_path: str) -> None: """ Load the vocab buffers of the draft model. Args: file_path (str): The path to the vocab mapping file. """ assert hasattr(self, "t2d") and hasattr( self, "d2t" ), "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" vocab_mapping = torch.load(file_path) self.t2d.copy_(vocab_mapping["t2d"]) self.d2t.copy_(vocab_mapping["d2t"]) self.vocab_mapping_loaded = True