| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
class ResidualRenderBlock(nn.Module):
    """Channel- and size-preserving residual block: ``x + F(x)``.

    ``F`` is Conv3x3 -> GroupNorm -> SiLU -> Conv3x3 -> GroupNorm. No
    activation is applied after the skip addition.

    Args:
        dim: Number of input (and output) channels. Must be divisible by
            ``num_groups``; ``nn.GroupNorm`` raises ``ValueError`` otherwise.
        num_groups: Group count for both GroupNorm layers. Defaults to 8, the
            value that used to be hard-coded, so existing call sites and saved
            checkpoints behave exactly as before.
    """

    def __init__(self, dim, num_groups=8):
        super().__init__()
        # Keep this layer order: it fixes the state_dict keys
        # (block.0 / block.1 / block.3 / block.4) that checkpoints rely on.
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups, dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups, dim),
        )

    def forward(self, x):
        """Return ``x + block(x)``.

        Presumably ``x`` is (N, dim, H, W); TODO confirm with callers. The
        3x3 convs use padding=1, so the output has the same shape as ``x``.
        """
        return x + self.block(x)
|
|
class RenderEncoder(nn.Module):
    """Project an ``in_channels`` feature map to ``out_channels`` through a sigmoid.

    The head architecture is chosen by ``encoder_type``:

    * ``"1d"``: a single 1x1 convolution followed by a sigmoid.
    * ``"residual"``: a ``ResidualBlockRender`` module.
    * ``"expressive"``: a 3x3 stem to 256 channels (GroupNorm + SiLU), three
      ``ResidualRenderBlock`` stages, a 1x1 projection, and a sigmoid.

    The chosen head is stored as ``self.model``.

    Raises:
        ValueError: if ``encoder_type`` is not one of the names above.
    """

    def __init__(self, encoder_type="1d", in_channels=768, out_channels=3):
        super().__init__()
        self.encoder_type = encoder_type

        if encoder_type == "residual":
            # Self-contained head; it applies its own sigmoid.
            self.model = ResidualBlockRender(in_channels, out_channels)
            return

        if encoder_type == "1d":
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=1)]
        elif encoder_type == "expressive":
            width = 256
            layers = [
                nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
                nn.GroupNorm(8, width),
                nn.SiLU(),
            ]
            layers.extend(ResidualRenderBlock(width) for _ in range(3))
            layers.append(nn.Conv2d(width, out_channels, kernel_size=1))
        else:
            raise ValueError(f"Unknown encoder_type '{encoder_type}'. Use '1d', 'residual', or 'expressive'.")

        # Both Sequential heads finish with a sigmoid. Layer order here
        # determines the "model.<index>" keys in the state_dict.
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the selected head on ``x``."""
        return self.model(x)
|
|
class ResidualBlockRender(nn.Module):
    """Convolutional head with a skip connection and a sigmoid output.

    Computes ``sigmoid(conv3(relu(conv2(relu(conv1(x))))) + skip(x))``. The
    convs map ``in_channels -> 256 -> 256 -> out_channels`` with kernel sizes
    3, 3, 1 (padding 1 on the 3x3 convs). ``skip`` is a 1x1 conv when
    ``in_channels != out_channels`` and the identity otherwise.
    """

    def __init__(self, in_channels=768, out_channels=3):
        super().__init__()
        width = 256  # hidden channel count of the two 3x3 convs

        # Attribute names and registration order determine the state_dict
        # keys and parameters() order, so keep them as they are.
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1)
        self.out = nn.Sigmoid()

        # A learned projection is needed only when the skip tensor's channel
        # count would not match the main path's output.
        self.residual_proj = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``sigmoid(main(x) + residual_proj(x))``."""
        main_path = self.conv3(self.relu2(self.conv2(self.relu1(self.conv1(x)))))
        return self.out(main_path + self.residual_proj(x))
|
|
def load_render_encoder(checkpoint_path, device='cpu'):
    """Rebuild a standalone ``RenderEncoder`` from a saved checkpoint.

    The checkpoint is indexed for two entries:

    * ``'model_config'``: provides ``encoder_type``, ``in_channels`` and
      ``out_channels``, which are passed to ``RenderEncoder``.
    * ``'model_state_dict'``: weights passed to ``load_state_dict``, which is
      strict by default.

    Args:
        checkpoint_path: Anything ``torch.load`` accepts.
        device: Used as ``map_location`` for ``torch.load`` and as the
            device the model is moved to.

    Returns:
        The ``RenderEncoder`` on ``device``, in eval mode. The loaded config
        is printed to stdout as a side effect.

    Raises:
        KeyError: if a required key is missing (assuming the checkpoint and
            its config are plain dicts; TODO confirm).
    """
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints, or pass weights_only=True on torch versions that
    # support it.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    config = checkpoint['model_config']
    ctor_kwargs = {
        key: config[key]
        for key in ('encoder_type', 'in_channels', 'out_channels')
    }
    model = RenderEncoder(**ctor_kwargs)

    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print(f"Loaded RenderEncoder: {config}")
    return model