| |
|
|
| import math |
|
|
| import torch |
| from torch import nn |
|
|
| from .spherical_armonics import SH as SH_analytic |
|
|
|
|
class SphericalHarmonics(nn.Module):
    """Spherical harmonics location encoder.

    Maps (longitude, latitude) pairs to the values of the spherical harmonic
    basis functions for every degree ``l`` in ``[0, L)`` and every order ``m``
    in ``[-l, l]``, as evaluated by the selected backend function.
    """

    def __init__(self, legendre_polys: int = 10, harmonics_calculation="analytic"):
        """Configure the encoder.

        Args:
            legendre_polys: number of Legendre polynomial degrees ``L``; more
                polynomials lead to more fine-grained resolutions. The
                encoding has ``L * L`` features.
            harmonics_calculation: how the spherical harmonics are evaluated.
                ``"analytic"`` uses pre-computed equations; this is exact, but
                works only up to degree 50. ``"closed-form"`` uses one
                equation but is computationally slower (especially for high
                degrees).

        Raises:
            ValueError: if ``harmonics_calculation`` is not one of the two
                supported names.
        """
        super().__init__()
        self.L, self.M = int(legendre_polys), int(legendre_polys)
        # sum_{l=0}^{L-1} (2l + 1) == L * L features are produced by forward().
        self.embedding_dim = self.L * self.M

        if harmonics_calculation == "closed-form":
            self.SH = SH_closed_form
        elif harmonics_calculation == "analytic":
            self.SH = SH_analytic
        else:
            # Fail at construction time rather than with an AttributeError on
            # the first forward() call.
            raise ValueError(
                "harmonics_calculation must be 'analytic' or 'closed-form', "
                f"got {harmonics_calculation!r}"
            )

    def forward(self, lonlat):
        """Encode coordinates as spherical harmonic features.

        Args:
            lonlat: tensor whose last dimension holds longitude at index 0 and
                latitude at index 1 (the +180 / +90 offsets followed by
                ``deg2rad`` imply degrees).

        Returns:
            Tensor with the leading dimensions of ``lonlat`` and a trailing
            feature dimension of size ``L * L``, ordered by increasing ``l``
            and, within each ``l``, by increasing ``m``.
        """
        lon, lat = lonlat[..., 0], lonlat[..., 1]

        # Offset the coordinates by 180 / 90 degrees, then convert to radians.
        phi = torch.deg2rad(lon + 180)
        theta = torch.deg2rad(lat + 90)

        SH = self.SH

        Y = []
        for l in range(self.L):
            for m in range(-l, l + 1):
                y = SH(m, l, phi, theta)
                if not torch.is_tensor(y):
                    # The backend may return a plain Python scalar for a
                    # constant harmonic; broadcast it to one value per point.
                    y = torch.full_like(phi, float(y))
                if y.isnan().any():
                    # Diagnostic output identifying the offending harmonic.
                    print(m, l, y)
                Y.append(y)

        return torch.stack(Y, dim=-1)
|
|
|
|
| |
| |
| |
| |
def associated_legendre_polynomial(l, m, x):
    """Evaluate the associated Legendre polynomial P_l^m element-wise at ``x``.

    Starts from P_m^m(x) = (-1)^m (2m-1)!! (1 - x^2)^(m/2) and climbs in degree
    with the three-term recurrence
    (ll - m) P_ll^m = (2 ll - 1) x P_{ll-1}^m - (ll + m - 1) P_{ll-2}^m.

    Args:
        l: degree (integer).
        m: order (integer); callers in this file pass ``0 <= m <= l``.
        x: tensor of evaluation points.

    Returns:
        Tensor shaped like ``x`` holding the polynomial values.
    """
    # Seed value P_m^m, built up one odd factor at a time.
    p_prev = torch.ones_like(x)
    if m > 0:
        root = torch.sqrt((1 - x) * (1 + x))
        odd = 1.0
        for _ in range(m):
            p_prev = p_prev * (-odd) * root
            odd += 2.0
    if l == m:
        return p_prev

    # First step up in degree: P_{m+1}^m = x (2m + 1) P_m^m.
    p_curr = x * (2.0 * m + 1.0) * p_prev
    if l == m + 1:
        return p_curr

    # Remaining degrees m + 2 .. l via the recurrence.
    p_next = torch.zeros_like(x)
    for degree in range(m + 2, l + 1):
        p_next = (
            (2.0 * degree - 1.0) * x * p_curr - (degree + m - 1.0) * p_prev
        ) / (degree - m)
        p_prev, p_curr = p_curr, p_next
    return p_next
|
|
|
|
def SH_renormalization(l, m):
    """Return the normalization constant sqrt((2l+1)/(4*pi) * (l-m)!/(l+m)!).

    Args:
        l: degree (integer).
        m: order (integer) with ``m <= l``; ``math.factorial`` raises
            ``ValueError`` when ``l - m`` is negative.

    Returns:
        The constant as a Python float.
    """
    try:
        direct = math.sqrt(
            (2.0 * l + 1.0)
            * math.factorial(l - m)
            / (4 * math.pi * math.factorial(l + m))
        )
    except OverflowError:
        # A factorial was too large to convert to float (l + m > 170).
        direct = math.inf
    if math.isfinite(direct):
        # Keep the original arithmetic wherever it works so existing values
        # stay bit-identical.
        return direct

    # The true value never exceeds sqrt((2l+1)/(4*pi)), so a non-finite result
    # only means the intermediate factorials overflowed. Recompute in log
    # space, where lgamma(n + 1) == log(n!).
    log_square = (
        math.log(2.0 * l + 1.0)
        - math.log(4 * math.pi)
        + math.lgamma(l - m + 1)
        - math.lgamma(l + m + 1)
    )
    return math.exp(0.5 * log_square)
|
|
|
|
def SH_closed_form(m, l, phi, theta):
    """Evaluate one spherical harmonic from the closed-form expression.

    The value is the normalization constant times the associated Legendre
    polynomial in ``cos(theta)``; for ``m != 0`` it is additionally scaled by
    ``sqrt(2)`` and by ``cos(m * phi)`` (positive ``m``) or ``sin(|m| * phi)``
    (negative ``m``).

    Args:
        m: order; its sign selects the cosine or sine azimuthal factor.
        l: degree.
        phi: tensor of azimuthal angles in radians.
        theta: tensor of polar angles in radians.

    Returns:
        Tensor of harmonic values, one per input angle pair.
    """
    if m == 0:
        # Zonal harmonic: no azimuthal dependence.
        scale = SH_renormalization(l, m)
        return scale * associated_legendre_polynomial(l, m, torch.cos(theta))

    if m > 0:
        order, azimuthal = m, torch.cos
    else:
        order, azimuthal = -m, torch.sin

    scale = math.sqrt(2.0) * SH_renormalization(l, order)
    angular = azimuthal(order * phi)
    legendre = associated_legendre_polynomial(l, order, torch.cos(theta))
    return scale * angular * legendre
|
|