A GPU-Accelerated JAX Framework for Robust Parametric Component Separation and Clustering Optimization for CMB Polarization Satellites
Wassim Kabalan, Arianna Rizzieri, Wuhyun Sohn, Artem Basyrov, Alexandre Boucaud, Benjamin Beringue, Pierre Chanial, Ema Tsang King Sang, Josquin Errard
Abstract
We present a novel, JAX-powered implementation of a parametric component-separation method for CMB polarization data, designed to handle spatially varying foreground Spectral Energy Distributions (SEDs). The approach models this variation across the sky by grouping pixels into subsets that share common foreground spectral parameters, and scans over thousands of such configurations to evaluate the trade-off between model complexity and residual systematic contamination. Built within the FURAX framework, a JAX-powered environment for CMB data analysis, our pipeline extends the fgbuster parametric formalism. It enables fully vectorized, GPU-accelerated evaluation of the spectral likelihood, map reconstruction, and diagnostic metrics across tens of thousands of pixel-subset configurations, noise realizations, and sky regions. On GPUs, our implementation achieves a speed-up of up to $\sim 100\times$ over the scipy TNC optimizer used in fgbuster, while also yielding more robust results. When applied to LiteBIRD-like simulations with spatially varying foreground SEDs, our optimized K-means configuration reduces the 68% upper limit on the tensor-to-scalar ratio $r$ by $\approx 30\%$ relative to a fixed, previously derived multi-resolution configuration, while maintaining competitive statistical uncertainties.
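To make the core computation concrete, the following is a minimal JAX sketch of the standard parametric spectral likelihood, $-\ln L = -\tfrac{1}{2}\sum_p (A^T N^{-1} d_p)^T (A^T N^{-1} A)^{-1} (A^T N^{-1} d_p)$, with spectral parameters shared within pixel clusters. It is an illustration under stated assumptions, not the FURAX or fgbuster API: the function names `mixing_matrix` and `neg_spectral_loglike`, the frequency channels, the simplified SEDs (flat CMB column, pure power laws for dust and synchrotron), the single Stokes parameter, the white noise, and the round-robin cluster assignment are all hypothetical. It shows how cluster parameters are broadcast to pixels through an integer label array, how `jax.value_and_grad` gives exact gradients, and how `jax.vmap` batches over noise realizations.

```python
import jax
import jax.numpy as jnp

# Double precision keeps the (A^T N^-1 A)^-1 solve well behaved.
jax.config.update("jax_enable_x64", True)

# Hypothetical channels (GHz) and reference frequencies; not the LiteBIRD bands.
FREQS = jnp.array([40.0, 90.0, 140.0, 235.0, 402.0])
NU_DUST, NU_SYNC = 353.0, 23.0


def mixing_matrix(beta_d, beta_s):
    """Per-pixel mixing matrix A of shape (npix, nfreq, 3).

    Simplified SEDs (assumption): flat CMB column, pure power laws for
    dust and synchrotron (no greybody temperature term).
    """
    cmb = jnp.ones((beta_d.shape[0], FREQS.shape[0]))
    dust = (FREQS[None, :] / NU_DUST) ** beta_d[:, None]
    sync = (FREQS[None, :] / NU_SYNC) ** beta_s[:, None]
    return jnp.stack([cmb, dust, sync], axis=-1)


def neg_spectral_loglike(params, labels, d, noise_var):
    """-ln L = -1/2 sum_p (A^T N^-1 d_p)^T (A^T N^-1 A)^-1 (A^T N^-1 d_p)."""
    # Each pixel inherits the spectral parameters of the cluster it belongs to.
    A = mixing_matrix(params["beta_d"][labels], params["beta_s"][labels])
    AtN = jnp.swapaxes(A, -1, -2) / noise_var        # (npix, ncomp, nfreq)
    AtNd = jnp.einsum("pcf,pf->pc", AtN, d)          # (npix, ncomp)
    AtNA = AtN @ A                                   # (npix, ncomp, ncomp)
    x = jnp.linalg.solve(AtNA, AtNd[..., None])[..., 0]
    return -0.5 * jnp.sum(AtNd * x)


# Toy usage: 64 pixels split into 4 hypothetical clusters, noiseless data.
npix, nclust = 64, 4
labels = jnp.arange(npix) % nclust
truth = {"beta_d": jnp.full(nclust, 1.54), "beta_s": jnp.full(nclust, -3.0)}
key_s, key_n = jax.random.split(jax.random.PRNGKey(0))
s = jax.random.normal(key_s, (npix, 3))
A_true = mixing_matrix(truth["beta_d"][labels], truth["beta_s"][labels])
d = jnp.einsum("pfc,pc->pf", A_true, s)
noise_var = jnp.ones(FREQS.shape[0])

# Value and gradient w.r.t. all cluster parameters in one compiled call.
val, grads = jax.jit(jax.value_and_grad(neg_spectral_loglike))(
    truth, labels, d, noise_var)

# A wrong dust index gives a worse (larger) value than the truth.
shifted = {"beta_d": truth["beta_d"] + 1.0, "beta_s": truth["beta_s"]}
val_shifted = neg_spectral_loglike(shifted, labels, d, noise_var)

# Vectorize over 8 noise realizations with vmap.
noise = jax.random.normal(key_n, (8, npix, FREQS.shape[0])) * jnp.sqrt(noise_var)
batched = jax.vmap(neg_spectral_loglike, in_axes=(None, None, 0, None))
vals = batched(truth, labels, d[None] + noise, noise_var)
```

In this sketch, changing the pixel grouping only changes the integer `labels` array, so many candidate groupings can in principle be evaluated with the same compiled function. For noiseless data the true parameters attain the lower bound $-\tfrac{1}{2} d^T N^{-1} d$, which is why the gradient vanishes there.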