"""Merge several model checkpoints into a single safetensors file.

The script reads four sources and copies their tensors into one flat state
dict, giving each source its own key prefix:

* ``text_encoders.mt5xl.spiece_model``        - the tokenizer's serialized
  SentencePiece model, stored byte-for-byte in a uint8 tensor
* ``text_encoders.mt5xl.transformer.<key>``   - ``google/mt5-xl`` encoder
* ``text_encoders.hydit_clip.transformer.<key>`` - ``input_bert``, except
  keys starting with ``visual.``
* ``model.<key>``                             - the ``ema`` entry of
  ``input_diffusion``
* ``vae.<key>``                               - ``input_vae``

Every tensor copied from a state dict is passed through ``.half()``.
All input/output paths are relative, so run the script from the directory
that contains the input files.
"""
import torch
import safetensors.torch
from transformers import T5Tokenizer, T5EncoderModel

# Diffusion checkpoint; only its "ema" entry is used below.
# NOTE(review): the file name looks like a DeepSpeed model-states file - confirm.
input_diffusion = "mp_rank_00_model_states.pt"

# Checkpoint for the "hydit_clip" text encoder (keys under "visual." are dropped).
input_bert = "pytorch_model.bin"

# VAE weights in safetensors format.
input_vae = "sdxl_vae.safetensors"

# Path of the merged checkpoint written at the end of the script.
output = "freeway_animation_demo_hunyuan_dit.safetensors"

# Resolved by transformers from the "google/mt5-xl" identifier (this may
# trigger a download if the model is not cached locally).
mt5 = T5EncoderModel.from_pretrained("google/mt5-xl")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-xl")

# safetensors can only hold tensors, so the SentencePiece model bytes are
# wrapped in a uint8 tensor (one element per byte).
sp_model = torch.ByteTensor(list(tokenizer.sp_model.serialized_model_proto()))
t5_sd = mt5.state_dict()

# Flat mapping "prefixed key -> tensor" that is written to `output`.
out_sd = {}
out_sd["text_encoders.mt5xl.spiece_model"] = sp_model

# NOTE(review): .half() is applied to every state-dict entry here and below,
# including any non-floating-point buffers a checkpoint may contain - confirm
# the consumer of the merged file expects that.
for key, tensor in t5_sd.items():
    out_sd[f"text_encoders.mt5xl.transformer.{key}"] = tensor.half()

# Release each source as soon as it has been copied to keep peak memory down.
del mt5, t5_sd

# map_location="cpu": without it torch.load restores tensors onto the device
# they were saved from, which fails on a machine without that device and
# needlessly occupies GPU memory for a pure file conversion.
bert_sd = torch.load(input_bert, map_location="cpu", weights_only=True)
for key, tensor in bert_sd.items():
    if not key.startswith("visual."):
        out_sd[f"text_encoders.hydit_clip.transformer.{key}"] = tensor.half()

del bert_sd

# SECURITY: weights_only=False runs the full pickle loader, which can execute
# arbitrary code embedded in the file. Only run this on a checkpoint obtained
# from a trusted source.
hydit = torch.load(input_diffusion, map_location="cpu", weights_only=False)["ema"]
for key, tensor in hydit.items():
    out_sd[f"model.{key}"] = tensor.half()

del hydit

# load_file places tensors on the CPU by default, matching the loads above.
vae_sd = safetensors.torch.load_file(input_vae)
for key, tensor in vae_sd.items():
    out_sd[f"vae.{key}"] = tensor.half()

del vae_sd

safetensors.torch.save_file(out_sd, output)