|
|
"""
|
|
|
Pure PyTorch implementation of SoftPool.
|
|
|
This is a fallback that doesn't require CUDA kernel compilation.
|
|
|
SoftPool: https://arxiv.org/abs/2101.00440
|
|
|
"""
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
def _as_pair(value, name):
    """Normalize an int or a 2-element sequence into an ``(h, w)`` tuple.

    Raises:
        ValueError: If ``value`` is a sequence whose length is not 2.
    """
    try:
        pair = tuple(value)
    except TypeError:
        # Not iterable: a scalar size such as 2 means (2, 2).
        return (value, value)
    if len(pair) != 2:
        raise ValueError(f"{name} must be an int or a pair of ints, got {value!r}")
    return pair


def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
    """
    Apply soft pooling on a 2D input tensor.

    SoftPool approximates max pooling while remaining differentiable by
    weighting each element of a pooling window exponentially:
    y = sum(x * exp(x)) / sum(exp(x))

    Args:
        x: Input tensor of shape (N, C, H, W), or (C, H, W) for a single
            unbatched sample.
        kernel_size: Pooling window size, an int or a (kh, kw) pair.
        stride: Window stride, an int or a (sh, sw) pair (defaults to
            kernel_size). No padding is applied.
        force_inplace: Unused, accepted for API compatibility.

    Returns:
        Pooled tensor of shape (N, C, out_h, out_w), or (C, out_h, out_w) for
        unbatched input, where out_h = (H - kh) // sh + 1 and
        out_w = (W - kw) // sw + 1.

    Raises:
        ValueError: If x is not 3D or 4D, or if kernel_size / stride is a
            sequence whose length is not 2.
    """
    kernel_size = _as_pair(kernel_size, "kernel_size")
    stride = kernel_size if stride is None else _as_pair(stride, "stride")

    # Accept unbatched (C, H, W) input the way nn.MaxPool2d does.
    unbatched = x.dim() == 3
    if unbatched:
        x = x.unsqueeze(0)
    elif x.dim() != 4:
        raise ValueError(
            "soft_pool2d expects a 3D (C, H, W) or 4D (N, C, H, W) tensor, "
            f"got shape {tuple(x.shape)}"
        )

    batch, channels, height, width = x.shape
    kh, kw = kernel_size
    sh, sw = stride
    out_h = (height - kh) // sh + 1
    out_w = (width - kw) // sw + 1

    # F.unfold yields (N, C * kh * kw, L) with each channel's window elements
    # contiguous, so it splits cleanly into (N, C, kh * kw, L).
    windows = F.unfold(x, kernel_size=kernel_size, stride=stride)
    windows = windows.view(batch, channels, kh * kw, out_h * out_w)

    # Subtract the per-window max so exp() cannot overflow. The result is
    # invariant to this shift, so the max is detached: no gradient needs to
    # flow through it (same trick as a numerically stable log-sum-exp).
    shift = windows.amax(dim=2, keepdim=True).detach()
    weights = torch.exp(windows - shift)

    # The max element contributes exp(0) == 1, so the denominator is always
    # >= 1 and needs no epsilon guard.
    pooled = (windows * weights).sum(dim=2) / weights.sum(dim=2)
    pooled = pooled.view(batch, channels, out_h, out_w)

    return pooled.squeeze(0) if unbatched else pooled
|
|
|
|
|
|
|
|
|
class SoftPool2d(nn.Module):
    """
    SoftPool 2D Layer.

    A differentiable pooling operation that approximates max pooling
    using exponential weighting. Module wrapper around ``soft_pool2d``.

    Args:
        kernel_size: Pooling window size, an int or a (kh, kw) pair.
        stride: Window stride, an int or a (sh, sw) pair; defaults to
            kernel_size.
        force_inplace: Forwarded unchanged to ``soft_pool2d``.
    """

    def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
        super().__init__()
        # Normalize both sizes to (h, w) tuples. Lists and ints then behave
        # like tuples instead of being wrapped a second time, and
        # extra_repr shows one consistent form.
        self.kernel_size = self._to_pair(kernel_size, "kernel_size")
        self.stride = (
            self.kernel_size if stride is None else self._to_pair(stride, "stride")
        )
        self.force_inplace = force_inplace

    @staticmethod
    def _to_pair(value, name):
        """Return ``value`` as an ``(h, w)`` tuple; a scalar becomes ``(v, v)``."""
        try:
            pair = tuple(value)
        except TypeError:
            # Not iterable: treat it as a square size.
            return (value, value)
        if len(pair) != 2:
            raise ValueError(f"{name} must be an int or a pair of ints, got {value!r}")
        return pair

    def forward(self, x):
        return soft_pool2d(x, self.kernel_size, self.stride, self.force_inplace)

    def extra_repr(self):
        return f'kernel_size={self.kernel_size}, stride={self.stride}'
|
|
|
|