| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .tools.wan_vae_1d import WanVAE_ |
| |
|
| |
|
class VAEWanModel(nn.Module):
    """VAE over padded feature sequences, built on the ``WanVAE_`` backbone.

    Inputs are (batch, time, input_dim) tensors. They are standardized with the
    ``mean``/``std`` buffers, transposed to channel-first layout for
    ``WanVAE_``, and decoder outputs are transposed back and de-standardized.

    Args:
        input_dim: Number of feature channels per frame.
        mean_path: Optional ``.npy`` file with per-channel means (zeros if None).
        std_path: Optional ``.npy`` file with per-channel standard deviations
            (ones if None).
        z_dim, dim, dec_dim, num_res_blocks, dropout: Forwarded to ``WanVAE_``.
        dim_mult: Forwarded to ``WanVAE_``; defaults to ``[1, 1, 1]``.
        temperal_downsample: Forwarded to ``WanVAE_``; defaults to
            ``[True, True]``. Each True entry doubles ``downsample_factor``.
        vel_window: ``[start, end]`` slice of feature channels used for the
            ``"velocity"`` loss; defaults to ``[0, 0]`` (an empty window, for
            which the velocity loss is 0).
        **kwargs: Optional loss weights ``LAMBDA_FEATURE`` (default 1.0),
            ``LAMBDA_VELOCITY`` (default 0.5) and ``LAMBDA_KL`` (default
            ``10e-6``, i.e. 1e-5). Other keys are ignored.
    """

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=None,
        temperal_downsample=None,
        vel_window=None,
        **kwargs,
    ):
        super().__init__()

        # None sentinels instead of list defaults: copying guarantees instances
        # never share a default list or alias a list owned by the caller.
        dim_mult = [1, 1, 1] if dim_mult is None else list(dim_mult)
        temperal_downsample = (
            [True, True]
            if temperal_downsample is None
            else list(temperal_downsample)
        )
        vel_window = [0, 0] if vel_window is None else list(vel_window)

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.dim_mult = dim_mult
        self.temperal_downsample = temperal_downsample
        self.vel_window = vel_window
        self.RECONS_LOSS = nn.SmoothL1Loss()
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_VELOCITY = kwargs.get("LAMBDA_VELOCITY", 0.5)
        # NOTE: 10e-6 == 1e-5 (kept as-is for checkpoint/config parity).
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        # Normalization statistics are buffers so they follow .to(device)
        # and are saved in the state dict.
        if mean_path is None:
            mean = torch.zeros(input_dim)
        else:
            mean = torch.from_numpy(np.load(mean_path)).float()
        self.register_buffer("mean", mean)

        if std_path is None:
            std = torch.ones(input_dim)
        else:
            std = torch.from_numpy(np.load(std_path)).float()
        self.register_buffer("std", std)

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            dropout=self.dropout,
        )

        self.downsample_factor = self._downsample_factor(self.temperal_downsample)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _downsample_factor(flags):
        """Return the total temporal stride: 2 for every truthy entry in ``flags``."""
        return 2 ** sum(1 for flag in flags if flag)

    @staticmethod
    def _lengths_to_mask(lengths, max_len, device):
        """Return a (batch, max_len) bool mask, True for the first ``lengths[i]`` steps of row i.

        ``lengths`` may be a sequence of ints or a 1-D tensor; lengths larger
        than ``max_len`` simply produce an all-True row.
        """
        lengths = torch.as_tensor(lengths, device=device)
        steps = torch.arange(max_len, device=device)
        return steps.unsqueeze(0) < lengths.unsqueeze(1)

    @staticmethod
    def _window_loss(loss_fn, pred, target, start, end):
        """Apply ``loss_fn`` to the channel slice ``[start:end]`` of the last dimension.

        An empty slice returns a zero scalar: a mean-reduced loss over zero
        elements is NaN, and NaN would poison the total loss even with a zero
        weight (0 * NaN == NaN).
        """
        pred_window = pred[..., start:end]
        if pred_window.numel() == 0:
            return pred.new_zeros(())
        return loss_fn(pred_window, target[..., start:end])

    @staticmethod
    def _valid_output_length(length, factor):
        """Return the largest ``k * factor + 1`` not exceeding ``length`` (0 for empty input)."""
        return max((length - 1) // factor * factor + 1, 0)

    def _normalize(self, x):
        """Standardize features with the per-channel ``mean``/``std`` buffers."""
        return (x - self.mean) / self.std

    def _denormalize(self, x):
        """Invert :meth:`_normalize`."""
        return x * self.std + self.mean

    # ------------------------------------------------------------ public API

    def preprocess(self, x):
        """Swap the last two axes, e.g. (batch, time, dim) -> (batch, dim, time)."""
        x = x.permute(0, 2, 1)
        return x

    def postprocess(self, x):
        """Swap the last two axes back, e.g. (batch, dim, time) -> (batch, time, dim)."""
        x = x.permute(0, 2, 1)
        return x

    def forward(self, x):
        """Compute the training losses for a padded batch.

        Args:
            x: Dict with ``"feature"`` (padded features, (batch, time,
                input_dim)) and ``"feature_length"`` (number of valid frames
                per sample; a sequence of ints or a 1-D tensor).

        Returns:
            Dict of scalar tensors: ``"total"``, ``"recons"``, ``"velocity"``
            and ``"kl"``.
        """
        features = self._normalize(x["feature"])
        feature_length = x["feature_length"]
        device = features.device
        lengths = torch.as_tensor(feature_length, device=device)

        mask = self._lengths_to_mask(lengths, features.size(1), device)

        mu, log_var = self.model.encode(
            self.preprocess(features), scale=[0, 1], return_dist=True
        )
        z = self.model.reparameterize(mu, log_var)
        x_out = self.postprocess(self.model.decode(z, scale=[0, 1]))

        # The decoder may return a different number of frames than were fed
        # in; compare only the overlapping prefix.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len, :]
            features = features[:, :min_len, :]
            mask = mask[:, :min_len]

        # NOTE: padded frames are zeroed in both tensors, so they add 0 to the
        # loss but still count in SmoothL1Loss's mean denominator.
        mask_expanded = mask.unsqueeze(-1)
        x_out_masked = x_out * mask_expanded
        features_masked = features * mask_expanded
        loss_recons = self.RECONS_LOSS(x_out_masked, features_masked)
        loss_vel = self._window_loss(
            self.RECONS_LOSS,
            x_out_masked,
            features_masked,
            self.vel_window[0],
            self.vel_window[1],
        )

        # Latent-frame mask: ceil(length / downsample_factor) valid steps.
        factor = self.downsample_factor
        latent_lengths = (lengths + factor - 1) // factor
        mask_latent = self._lengths_to_mask(latent_lengths, mu.size(2), device)

        # KL(q(z|x) || N(0, I)) per element, averaged over valid latent
        # positions and channels only.
        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        kl_sum = torch.sum(kl_per_element * mask_latent.unsqueeze(1))
        num_valid = torch.sum(mask_latent) * mu.size(1)
        # Clamp so an all-empty batch gives 0 rather than 0/0 = NaN.
        kl_loss = kl_sum / num_valid.clamp_min(1)

        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_VELOCITY * loss_vel
            + self.LAMBDA_KL * kl_loss
        )

        return {
            "total": total_loss,
            "recons": loss_recons,
            "velocity": loss_vel,
            "kl": kl_loss,
        }

    def encode(self, x):
        """Encode raw (batch, time, input_dim) features to the latent mean, time-major."""
        x_in = self.preprocess(self._normalize(x))
        mu = self.model.encode(x_in, scale=[0, 1])
        return self.postprocess(mu)

    def decode(self, mu):
        """Decode time-major latents back to de-standardized features."""
        x_decoder = self.model.decode(self.preprocess(mu), scale=[0, 1])
        return self._denormalize(self.postprocess(x_decoder))

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        """Chunked counterpart of :meth:`encode` using ``WanVAE_.stream_encode``."""
        x_in = self.preprocess(self._normalize(x))
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        return self.postprocess(mu)

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        """Chunked counterpart of :meth:`decode` using ``WanVAE_.stream_decode``."""
        x_decoder = self.model.stream_decode(
            self.preprocess(mu), first_chunk=first_chunk, scale=[0, 1]
        )
        return self._denormalize(self.postprocess(x_decoder))

    def clear_cache(self):
        """Delegate to ``WanVAE_.clear_cache``."""
        self.model.clear_cache()

    def generate(self, x):
        """Reconstruct each sample via encode -> decode and trim it.

        Each output is cut to the largest ``k * downsample_factor + 1`` frames
        not exceeding its ``feature_length`` (0 frames for an empty sample).

        Returns:
            ``{"generated": [tensor of shape (valid_len_i, input_dim), ...]}``.
        """
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        generated = [
            y_hat[
                i,
                : self._valid_output_length(
                    int(feature_length[i]), self.downsample_factor
                ),
                :,
            ]
            for i in range(y_hat.shape[0])
        ]
        return {"generated": generated}
| |
|