"""
Pure PyTorch implementation of SoftPool.

This is a fallback that doesn't require CUDA kernel compilation.
SoftPool: https://arxiv.org/abs/2101.00440
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def _to_pair(value):
    """Return ``value`` as a 2-tuple; a scalar is repeated for both spatial dims."""
    try:
        pair = tuple(value)
    except TypeError:  # not iterable, i.e. a scalar such as an int
        return (value, value)
    if len(pair) != 2:
        raise ValueError(f"expected an int or a pair of ints, got {value!r}")
    return pair


def soft_pool2d(x, kernel_size=(2, 2), stride=None, force_inplace=False):
    """
    Apply SoftPool over a 2D input.

    Inside each pooling window the activations are weighted by their softmax,
    so the result lies between the window mean and the window maximum and
    approaches max pooling as the spread of values grows, while remaining
    differentiable:

        y = sum(x * exp(x)) / sum(exp(x))

    Weights are computed independently for every channel.

    Args:
        x: Input tensor of shape (N, C, H, W), or unbatched (C, H, W).
        kernel_size: Pooling window size, an int or a pair (kh, kw).
        stride: Window stride, an int or a pair (sh, sw); defaults to
            ``kernel_size``.
        force_inplace: Unused, for API compatibility.

    Returns:
        Tensor of shape (N, C, out_h, out_w), or (C, out_h, out_w) for
        unbatched input, where out_h = (H - kh) // sh + 1 and
        out_w = (W - kw) // sw + 1.

    Raises:
        ValueError: If ``x`` is not 3- or 4-dimensional, or if
            ``kernel_size`` / ``stride`` is a sequence whose length is not 2.
    """
    kernel_size = _to_pair(kernel_size)
    stride = kernel_size if stride is None else _to_pair(stride)
    kh, kw = kernel_size
    sh, sw = stride

    unbatched = x.dim() == 3
    if unbatched:
        x = x.unsqueeze(0)
    elif x.dim() != 4:
        raise ValueError(
            "soft_pool2d expects a 3D (C, H, W) or 4D (N, C, H, W) tensor, "
            f"got shape {tuple(x.shape)}"
        )

    batch, channels, height, width = x.shape
    out_h = (height - kh) // sh + 1
    out_w = (width - kw) // sw + 1

    # F.unfold returns (N, C*kh*kw, L) with channels as the outer factor, so it
    # splits cleanly into (N, C, kh*kw, L). reshape (rather than view) also
    # copes with a non-contiguous result.
    patches = F.unfold(x, kernel_size=kernel_size, stride=stride)
    patches = patches.reshape(batch, channels, kh * kw, -1)

    # softmax over the window is the numerically stable form of
    # exp(x) / sum(exp(x)); it subtracts the window max internally.
    weights = torch.softmax(patches, dim=2)
    pooled = (patches * weights).sum(dim=2)

    pooled = pooled.reshape(batch, channels, out_h, out_w)
    return pooled.squeeze(0) if unbatched else pooled


class SoftPool2d(nn.Module):
    """
    SoftPool 2D layer.

    Module wrapper around ``soft_pool2d``: a differentiable pooling operation
    that approximates max pooling using exponential (softmax) weighting
    inside each window.

    Args:
        kernel_size: Pooling window size, an int or a pair (kh, kw).
        stride: Window stride, an int or a pair; defaults to ``kernel_size``.
        force_inplace: Unused, for API compatibility.
    """

    def __init__(self, kernel_size=(2, 2), stride=None, force_inplace=False):
        super().__init__()
        # Normalise both to 2-tuples so list arguments (e.g. [2, 3]) work and
        # extra_repr reports kernel_size and stride in the same format.
        self.kernel_size = _to_pair(kernel_size)
        self.stride = self.kernel_size if stride is None else _to_pair(stride)
        self.force_inplace = force_inplace

    def forward(self, x):
        return soft_pool2d(x, self.kernel_size, self.stride, self.force_inplace)

    def extra_repr(self):
        return f'kernel_size={self.kernel_size}, stride={self.stride}'