Inference Techniques for Cosmological Forward Modeling

EuCAPT Symposium, May 23rd 2022


François Lanusse










slides at eiffl.github.io/EuCAPT2022

the Rubin Observatory Legacy Survey of Space and Time

  • 1000 images each night, 15 TB/night for 10 years

  • 18,000 square degrees, observed once every few days

  • Tens of billions of objects, each one observed $\sim1000$ times

Previous generation survey: SDSS




















Image credit: Peter Melchior

Current generation survey: DES




















Image credit: Peter Melchior

LSST precursor survey: HSC




















Image credit: Peter Melchior

The limits of traditional cosmological inference

Figures: HSC cosmic shear power spectrum; HSC Y1 constraints on $(S_8, \Omega_m)$ (Hikage et al. 2018)
  • Measure the ellipticity $\epsilon = \epsilon_i + \gamma$ of all galaxies
    $\Longrightarrow$ Noisy tracer of the weak lensing shear $\gamma$

  • Compute summary statistics based on 2pt functions,
    e.g. the power spectrum

  • Run an MCMC to recover a posterior on model parameters, using an analytic likelihood $$ p(\theta | x ) \propto \underbrace{p(x | \theta)}_{\mathrm{likelihood}} \ \underbrace{p(\theta)}_{\mathrm{prior}}$$
Main limitation: the need for an explicit likelihood
We can only compute the likelihood for simple summary statistics and on large scales

$\Longrightarrow$ We are dismissing a significant fraction of the information!

A visual illustration of the impact of analytic assumptions


Jeffrey, Lanusse, et al. 2020

  • Cosmological signals exhibit significant departures from Gaussianity.


$\Longrightarrow$ This is the end of the analytic era...

A different road: forward modeling

  • Instead of trying to analytically evaluate the likelihood $p(x | \theta)$, let us build a forward model of the observables.
    $\Longrightarrow$ The simulator becomes the physical model.

  • Each component of the model is now tractable, but at the cost of a large number of latent variables.


Benefits of a forward modeling approach
  • No assumption/approximation of Gaussianity for summary statistics, no need to compute covariances.
  • Fully exploits the information content of the data (aka "full field inference").
  • Easy to incorporate systematic effects.
  • Easy to combine multiple probes by joint simulations.
(Schneider et al. 2015)

...so why is this not mainstream?

The Challenge of Simulation-Based Inference
$$ p(x|\theta) = \int p(x, z | \theta) dz = \int p(x | z, \theta) p(z | \theta) dz $$ Where $z$ are stochastic latent variables of the simulator.
$\Longrightarrow$ This marginal likelihood is intractable! Hence the phrase "Likelihood-Free Inference"

Outline for this talk



How to perform efficient inference over forward simulation models?

  • Likelihood-Free Inference: Treat the simulator as a black-box
    • Neural Density Estimation
    • Dimensionality Reduction


  • Hierarchical Bayesian Inference: Treat the simulator as a probabilistic model
    • Automatically Differentiable Physics
    • Gradient-based inference techniques

Likelihood-Free approach to
Simulation-Based Inference

Black-box Simulators Define Implicit Distributions

  • A black-box simulator defines $p(x | \theta)$ as an implicit distribution, you can sample from it but you cannot evaluate it.
  • Key Idea: Use a parametric distribution model $\mathbb{P}_\varphi$ to approximate the implicit distribution $\mathbb{P}$.

Figure: true distribution $\mathbb{P}$; samples $x_i \sim \mathbb{P}$; parametric model $\mathbb{P}_\varphi$

Why isn't it easy?


  • The curse of dimensionality puts all points far apart in high dimensions

Distance between pairs of points drawn from a Gaussian distribution.

  • Classical methods for estimating probability densities, e.g. Kernel Density Estimation (KDE), start to fail in high dimensions because of these gaps

Deep Learning Approaches to Likelihood-Free Inference

A two-step approach to Likelihood-Free Inference
  • Automatically learn an optimal low-dimensional summary statistic $$y = f_\varphi(x) $$
  • Use Neural Density Estimation to either:
    • build an estimate $p_\phi$ of the likelihood function $p(y \ | \ \theta)$ (Neural Likelihood Estimation)

    • build an estimate $p_\phi$ of the posterior distribution $p(\theta \ | \ y)$ (Neural Posterior Estimation)

Conditional Density Estimation with Neural Networks

  • I assume a forward model of the observations: \begin{equation} p( x, \theta ) = p(x | \theta) \ p(\theta) \nonumber \end{equation} All I ask is the ability to sample from the model, to obtain $\mathcal{D} = \{(x_i, \theta_i) \}_{i\in \mathbb{N}}$

  • I am going to assume a parametric conditional density $q_\phi(\theta | x)$

  • Optimize the parameters $\phi$ of $q_{\phi}$ according to \begin{equation} \min\limits_{\phi} \sum\limits_{i} - \log q_{\phi}(\theta_i | x_i) \nonumber \end{equation} In the limit of large number of samples and sufficient flexibility \begin{equation} \boxed{q_{\phi^\ast}(\theta | x) \approx p(\theta | x)} \nonumber \end{equation}
$\Longrightarrow$ One can asymptotically recover the posterior by optimizing a parametric estimator over
the Bayesian joint distribution
$\Longrightarrow$ One can asymptotically recover the posterior by optimizing a Deep Neural Network over
a simulated training set.
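
As an illustration of this objective, here is a minimal, self-contained JAX sketch (a toy under stated assumptions, not the actual pipeline): a linear-Gaussian $q_\phi(\theta | x)$ is fit by minimizing the negative log-likelihood over simulated $(x_i, \theta_i)$ pairs drawn from a placeholder simulator.

    import jax
    import jax.numpy as jnp

    def simulate(key, n):
        # Toy joint model: theta ~ U(0,1), x ~ N(theta, 0.1^2)  (placeholder assumption)
        k1, k2 = jax.random.split(key)
        theta = jax.random.uniform(k1, (n, 1))
        x = theta + 0.1 * jax.random.normal(k2, (n, 1))
        return x, theta

    def log_q(phi, theta, x):
        # Toy conditional density: Gaussian with a linear mean model
        mu = phi["w"] * x + phi["b"]
        return (-0.5 * ((theta - mu) / jnp.exp(phi["log_sigma"]))**2
                - phi["log_sigma"] - 0.5 * jnp.log(2 * jnp.pi))

    def loss(phi, x, theta):
        return -jnp.mean(log_q(phi, theta, x))     # negative log-likelihood

    phi = {"w": jnp.zeros(()), "b": jnp.zeros(()), "log_sigma": jnp.zeros(())}
    x, theta = simulate(jax.random.PRNGKey(0), 1000)
    for _ in range(500):                           # plain gradient descent
        grads = jax.grad(loss)(phi, x, theta)
        phi = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, phi, grads)
    # q_{phi*}(theta | x) now approximates the posterior p(theta | x)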

Neural Density Estimation


Bishop (1994)
  • Mixture Density Networks (MDN) \begin{equation} p(\theta | x) = \sum_i \pi_i(x) \ \mathcal{N}\left(\theta; \ \mu_i(x), \ \sigma_i(x) \right) \nonumber \end{equation} (a short evaluation sketch follows below)

  • Flourishing Machine Learning literature on density estimators
    GLOW, (Kingma & Dhariwal, 2018)
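
For concreteness, a short sketch (with assumed shapes and toy numbers) of how the mixture density above is evaluated from the amplitudes, means, and scales that an MDN would output for a given $x$:

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def mdn_log_prob(theta, log_pi, mu, sigma):
        """theta: (d,), log_pi: (k,), mu and sigma: (k, d)."""
        # per-component diagonal-Gaussian log-density
        log_norm = -0.5 * jnp.sum(((theta - mu) / sigma)**2 + 2 * jnp.log(sigma)
                                  + jnp.log(2 * jnp.pi), axis=-1)
        # log sum_i pi_i N_i, computed stably
        return logsumexp(log_pi + log_norm)

    # Example with k=2 components in d=1 dimension (toy numbers)
    print(mdn_log_prob(jnp.array([0.3]),
                       jnp.log(jnp.array([0.7, 0.3])),
                       jnp.array([[0.0], [1.0]]),
                       jnp.array([[0.5], [0.2]])))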

A variety of algorithms

Lueckmann, Boelts, Greenberg, Gonçalves, Macke (2021)


A few important points:

  • Amortized inference methods, which estimate $p(\theta | x)$, can greatly speed up posterior estimation once trained.

  • Sequential Neural Posterior/Likelihood Estimation methods can actively sample simulations needed to refine the inference.

Automated Summary Statistics Extraction

  • Introduce a parametric function $f_\varphi$ to reduce the dimensionality of the data while preserving information.
Makinen, Charnock, Alsing, Wandelt (2021)
Information-based loss functions
  • Variational Mutual Information Maximization $$ \mathcal{L} \ = \ \mathbb{E}_{x, \theta} [ \log q_\phi(\theta | f_\varphi(x)) ] \leq I(Y; \Theta) $$ (a toy training sketch follows below)
    Jeffrey, Alsing, Lanusse (2021)


  • Information Maximization Neural Network $$\mathcal{L} \ = \ - | \det \mathbf{F} | \quad \mbox{with} \quad \mathbf{F}_{\alpha, \beta} = \mathrm{tr}\left[ \mu_{\alpha}^T C^{-1} \mu_{\beta} \right] $$
    Charnock, Lavaux, Wandelt (2018)
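
A toy sketch of the first objective: the compressor $f_\varphi$ and the conditional density $q_\phi$ are trained jointly by taking gradients of the same variational bound. The linear compressor and fixed-variance Gaussian $q$ below are placeholder assumptions, not the networks used in the papers above.

    import jax
    import jax.numpy as jnp

    def f(varphi, x):
        # toy compressor standing in for a neural network
        return x @ varphi["W"]

    def log_q(phi, theta, y):
        # toy Gaussian conditional density (unit variance, up to a constant)
        mu = y @ phi["A"]
        return -0.5 * jnp.sum((theta - mu)**2, axis=-1)

    def vmim_loss(params, x, theta):
        y = f(params["varphi"], x)                         # compress the data
        return -jnp.mean(log_q(params["phi"], theta, y))   # negative variational lower bound

    params = {"varphi": {"W": 0.1 * jnp.ones((10, 2))},
              "phi":    {"A": 0.1 * jnp.ones((2, 1))}}
    x, theta = jnp.ones((100, 10)), jnp.ones((100, 1))
    # a single gradient evaluation updates compressor AND density model together
    grads = jax.grad(vmim_loss)(params, x, theta)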

Example of application: Likelihood-Free parameter inference with DES SV

Jeffrey, Alsing, Lanusse (2021)

Suite of N-body + raytracing simulations: $\mathcal{D}$

deep residual networks for lensing map compression


  • Deep Residual Network $y = f_\varphi(x)$ followed by a mixture density network $q_\phi(\theta | y)$

  • Training on weak lensing maps simulated for different cosmologies




  • Optimization of the variational lower bound: $$\mathbb{E}_{(x, \theta) \in \mathcal{D}} [ \log q_\phi(\theta | f_\varphi(x) ) ]$$

Estimating the likelihood by Neural Density Estimation


$\Longrightarrow$ We cannot assume a Gaussian likelihood for the summary $y = f_\varphi(\kappa)$, but we can learn $p(y | \theta)$: Neural Likelihood Estimation.


Dinh et al. 2016
Neural Likelihood Estimation by Normalizing Flow
  • We use a conditional Normalizing Flow to build an explicit model for the likelihood function $$ \log p_\varphi (y | \theta)$$

  • In practice we use the pyDELFI package and an ensemble of NDEs for robustness.

  • Once learned, we can use the likelihood as part of a conventional MCMC chain (sketch below)
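
As a minimal illustration of that last point, a hypothetical sketch of plugging a learned likelihood into a standard ensemble MCMC (emcee is used here only as an example; learned_log_like, y_obs, and the flat prior box are stand-in assumptions, not the pyDELFI setup):

    import numpy as np
    import emcee

    # Stand-in for the trained conditional flow log p_varphi(y | theta),
    # evaluated at the observed summary y_obs (toy placeholder)
    y_obs = np.array([0.8, 0.3])
    learned_log_like = lambda y, theta: -0.5 * np.sum((y - theta)**2)

    def log_posterior(theta):
        if not np.all((0.0 < theta) & (theta < 2.0)):   # flat prior box (assumption)
            return -np.inf
        return learned_log_like(y_obs, theta)

    ndim, nwalkers = 2, 32
    p0 = np.array([0.8, 0.3]) + 0.01 * np.random.randn(nwalkers, ndim)
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_posterior)
    sampler.run_mcmc(p0, 2000)
    samples = sampler.get_chain(flat=True, discard=500)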


Parameter constraints from DES SV data

Main takeaways



Hierarchical Bayesian Inference
approach to Simulation-Based Inference

Simulators as Hierarchical Bayesian Models

  • If we have access to all latent variables $z$ of the simulator, then the joint log likelihood $p(x | z, \theta)$ is explicit.

  • We need to infer the joint posterior $p(\theta, z | x)$ before marginalization to yield $p(\theta | x) = \int p(\theta, z | x) dz$.
    $\Longrightarrow$ Extremely difficult problem as $z$ is typically very high-dimensional.

  • Necessitates inference strategies with access to gradients of the likelihood. $$\frac{d \log p(x | z, \theta)}{d \theta} \quad ; \quad \frac{d \log p(x | z, \theta)}{d z} $$ For instance: Maximum A Posteriori estimation, Hamiltonian Monte-Carlo, Variational Inference.

the hammer behind the Deep Learning revolution: Automatic Differentiation

  • Automatic differentiation allows you to compute analytic derivatives of arbitrary expressions:
    If I form the expression $y = a * x + b$, it is decomposed into fundamental operations: $$ y = u + b \qquad u = a * x $$ then gradients can be obtained by the chain rule: $$\frac{\partial y}{\partial x} = \frac{\partial y}{\partial u} \frac{ \partial u}{\partial x} = 1 \times a = a$$

  • This is a fundamental tool in Machine Learning, and autodiff frameworks include TensorFlow and PyTorch.


Enter JAX: NumPy + Autograd + GPU
  • JAX follows the NumPy API!

    import jax.numpy as np

  • Arbitrary order derivatives
  • Accelerated execution on GPU and TPU
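
A tiny sanity check of the chain-rule example from the previous slide, using jax.grad (the numerical values are chosen arbitrarily):

    import jax

    a, b = 3.0, 1.0
    y = lambda x: a * x + b
    print(jax.grad(y)(2.0))            # dy/dx = a -> 3.0
    print(jax.grad(jax.grad(y))(2.0))  # arbitrary-order derivatives -> 0.0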

Surely this won't scale to cosmological simulations!

the Fast Particle-Mesh scheme for N-body simulations

The idea: approximate gravitational forces by estimating densities on a grid.
  • The numerical scheme:

    • Estimate the density of particles on a mesh
      $\Longrightarrow$ compute gravitational forces by FFT

    • Interpolate forces at particle positions

    • Update particle velocity and positions, and iterate

  • Fast and simple, at the cost of approximating short range interactions.
$\Longrightarrow$ Only a series of FFTs and interpolations.
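
To make this concrete, here is a small self-contained JAX sketch of the core operations: deposit particles on a periodic mesh and solve for the gravitational potential with an FFT-based Poisson solver. This is a toy illustration, not a production PM code; the nearest-grid-point assignment, unit conventions, and grid size are simplifying assumptions.

    import jax
    import jax.numpy as jnp

    def ngp_paint(positions, nc):
        """Deposit unit-mass particles on an nc^3 mesh (nearest grid point)."""
        idx = jnp.mod(jnp.round(positions).astype(jnp.int32), nc)
        mesh = jnp.zeros((nc, nc, nc))
        return mesh.at[idx[:, 0], idx[:, 1], idx[:, 2]].add(1.0)

    def gravitational_potential(density, nc):
        """Solve nabla^2 phi = delta in Fourier space (arbitrary units)."""
        delta_k = jnp.fft.rfftn(density - density.mean())
        kvec = [jnp.fft.fftfreq(nc) * 2 * jnp.pi for _ in range(2)]
        kvec.append(jnp.fft.rfftfreq(nc) * 2 * jnp.pi)
        kx, ky, kz = jnp.meshgrid(*kvec, indexing="ij")
        k2 = kx**2 + ky**2 + kz**2
        phi_k = jnp.where(k2 > 0, -delta_k / jnp.where(k2 > 0, k2, 1.0), 0.0)
        return jnp.fft.irfftn(phi_k, s=density.shape)

    nc = 32
    positions = jax.random.uniform(jax.random.PRNGKey(0), (1000, 3)) * nc
    density = ngp_paint(positions, nc)
    phi = gravitational_potential(density, nc)
    # Forces would follow from spectral gradients of phi,
    # interpolated back to the particle positions.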

introducing FlowPM: Particle-Mesh Simulations in TensorFlow

Modi, Lanusse, Seljak (2020)

    import numpy as np
    import tensorflow as tf
    import flowpm

    # Defines integration steps
    stages = np.linspace(0.1, 1.0, 10, endpoint=True)

    initial_conds = flowpm.linear_field(32,          # size of the cube
                                        100,         # Physical size
                                        ipklin,      # Initial power spectrum
                                        batch_size=16)

    # Sample particles and displace them by LPT
    state = flowpm.lpt_init(initial_conds, a0=0.1)

    # Evolve particles down to z=0
    final_state = flowpm.nbody(state, stages, 32)

    # Retrieve final density field
    final_field = flowpm.cic_paint(tf.zeros_like(initial_conds),
                                   final_state[0])
  • Seamless interfacing with deep learning components
  • Mesh TensorFlow implementation for distribution on supercomputers









Mesh FlowPM: distributed, GPU-accelerated, and automatically differentiable simulations

  • We developed a Mesh TensorFlow implementation that can scale on GPU clusters (horovod+NCCL).


  • For a $2048^3$ simulation:
    • Distributed on 256 NVIDIA V100 GPUs
    • Runtime: 3 mins


  • Don't hesitate to reach out if you have a use case for model parallelism!

Example use-case: reconstructing initial conditions by MAP optimization


Going back to simpler times...
$$\arg\max_z \ \log p(x_{dm} \,|\, f(z)) \ + \ \log p(z) $$ where:
  • $f$ is FlowPM
  • $z$ are the initial conditions (early universe)
  • $x_{dm}$ is the present-day dark matter distribution (an optimization sketch follows below)
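
A minimal sketch of how such a reconstruction could be set up with the FlowPM calls shown earlier, using plain gradient descent on the latent initial conditions. The Gaussian data term, the white-noise prior on $z$, the toy "observed" data, and the step size are illustrative assumptions, not the actual analysis configuration.

    import numpy as np
    import tensorflow as tf
    import flowpm

    nc = 32
    stages = np.linspace(0.1, 1.0, 10, endpoint=True)

    def forward(z):
        # Evolve initial conditions z to the present-day density field
        state = flowpm.lpt_init(z, a0=0.1)
        final_state = flowpm.nbody(state, stages, nc)
        return flowpm.cic_paint(tf.zeros_like(z), final_state[0])

    # Toy "observed" data generated from known initial conditions (assumption)
    z_true = 0.01 * tf.random.normal([1, nc, nc, nc])
    data = forward(z_true)

    # Gradient ascent on log p(x_dm | f(z)) + log p(z)
    z = tf.Variable(tf.zeros([1, nc, nc, nc]))
    for step in range(200):
        with tf.GradientTape() as tape:
            log_like = -0.5 * tf.reduce_sum((forward(z) - data)**2) / 0.1**2   # Gaussian noise model (assumption)
            log_prior = -0.5 * tf.reduce_sum(z**2 / 0.01**2)                   # white-noise prior on z (assumption)
            loss = -(log_like + log_prior)
        z.assign_sub(1e-3 * tape.gradient(loss, z))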

MAP optimization in action

$$\arg\max_z \ \log p(x_{dm} \,|\, f(z)) \ + \ \log p(z) $$
credit: C. Modi


Figure panels: true initial conditions $z_0$; reconstructed initial conditions $z$; reconstructed dark matter distribution $x = f(z)$; data $x_{dm} = f(z_0)$


Example use-case: Baryon Acoustic Oscillations

Among the most robust probes of dark energy from galaxy surveys like DESI

  • Pressure waves in the primordial photon-baryon fluid

  • Frozen at the time of decoupling
    $\rightarrow$ Clustering at a preferred scale

  • Damping due to non-linear evolution

  • Fisher information

    $F_{ij} = V_{\rm{eff}}\int_{k_{min}}^{k_{max}} \frac{\partial {\rm ln} P(k)}{\partial p_i} \frac{\partial {\rm ln} P(k)}{\partial p_j}\frac{4 \pi k^2}{2 (2\pi)^3} dk \quad \mathbf{\propto \, r_c^4} $
Padmanabhan et al. 2012
Seo et al. 2007
Modi et al. 2018


Example of cosmological constraints from a proof of concept: BORG-WL

Porqueres, Heavens, Mortlock, Lavaux (2021)

Main takeaways



  • This approach extends traditional inference to large-scale problems
    • It requires a simulator implemented in a framework in which you have access to gradients, and tracking of all latent variables.


  • In theory optimal, but high-dimensional inference remains very hard.



  • Some resources and links:
    • Frameworks for differentiable simulations: MADLens, BORG

Conclusion




Methodology for inference over simulators
  • A change of paradigm from analytic likelihoods to simulators as physical models.

    • State-of-the-art Machine Learning models enable Likelihood-Free Inference over black-box simulators.

    • Progress in differentiable simulators and inference methodology paves the way to full inference over the probabilistic model.

  • Ultimately, this promises optimal exploitation of survey data, although the "information gap" against analytic likelihoods in realistic settings remains uncertain.


Thank you!

Extra slides

jax-cosmo: Finally a differentiable cosmology library, and it's in JAX!


    import jax.numpy as np
    import jax_cosmo as jc

    # Defining a Cosmology
    cosmo = jc.Planck15()

    # Define a redshift distribution with smail_nz(a, b, z0)
    nz = jc.redshift.smail_nz(1., 2., 1.)

    # Build a lensing tracer with a single redshift bin
    probe = jc.probes.WeakLensing([nz])

    # Compute angular Cls for some ell
    ell = np.logspace(0.1, 3)
    cls = jc.angular_cl.angular_cl(cosmo, ell, [probe])
Current main features
  • Weak Lensing and Number counts probes
  • Eisenstein & Hu (1998) power spectrum + halofit
  • Angular $C_\ell$ under Limber approximation
$\Longrightarrow$ 3x2pt DES Y1 capable

Validating against the DESC Core Cosmology Library

let's compute a Fisher matrix


$$F = - \mathbb{E}_{p(x | \theta)}[ H_\theta(\log p(x| \theta)) ] $$

    import jax
    import jax.numpy as np
    import jax_cosmo as jc

    # .... define probes, and load a data vector

    def gaussian_likelihood(theta):
        # Build the cosmology for given parameters
        cosmo = jc.Planck15(Omega_c=theta[0], sigma8=theta[1])

        # Compute mean and covariance
        mu, cov = jc.angular_cl.gaussian_cl_covariance_and_mean(cosmo, ell, probes)

        # returns likelihood of data under model
        return jc.likelihood.gaussian_likelihood(data, mu, cov)

    # Fisher matrix in just one line:
    F = - jax.hessian(gaussian_likelihood)(theta)


  • No derivatives were harmed by finite differences in the computation of this Fisher!
  • Only a small additional compute time compared to one forward evaluation of the model

Inference becomes fast and scalable

  • Current cosmological MCMC chains take days, and typically require access to large computer clusters.

  • Gradients of the log posterior are required for modern efficient and scalable inference techniques:
    • Variational Inference
    • Hamiltonian Monte-Carlo

  • In jax-cosmo, we can trivially obtain exact gradients:
    
    def log_posterior(theta):
        return gaussian_likelihood(theta) + log_prior(theta)

    score = jax.grad(log_posterior)(theta)

  • On a DES Y1 analysis, we find convergence in 70,000 samples with vanilla HMC, 140,000 with Metropolis-Hastings
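
For illustration, a minimal, self-contained sketch of a single HMC transition built directly on jax.grad, with a toy Gaussian target standing in for the jax-cosmo DES Y1 log-posterior above; the step size and trajectory length are arbitrary assumptions.

    import jax
    import jax.numpy as jnp

    # Toy target standing in for the jax-cosmo log posterior (assumption)
    def log_posterior(theta):
        return -0.5 * jnp.sum(theta**2)

    def hmc_step(key, theta, step_size=0.1, n_leapfrog=20):
        grad_fn = jax.grad(log_posterior)
        key_p, key_u = jax.random.split(key)
        p = jax.random.normal(key_p, theta.shape)                 # sample momentum
        # Leapfrog integration of Hamiltonian dynamics
        theta_new, p_new = theta, p + 0.5 * step_size * grad_fn(theta)
        for _ in range(n_leapfrog - 1):
            theta_new = theta_new + step_size * p_new
            p_new = p_new + step_size * grad_fn(theta_new)
        theta_new = theta_new + step_size * p_new
        p_new = p_new + 0.5 * step_size * grad_fn(theta_new)
        # Metropolis accept/reject on the Hamiltonian
        h_old = -log_posterior(theta) + 0.5 * jnp.sum(p**2)
        h_new = -log_posterior(theta_new) + 0.5 * jnp.sum(p_new**2)
        accept = jnp.log(jax.random.uniform(key_u)) < h_old - h_new
        return jnp.where(accept, theta_new, theta)

    theta = jnp.zeros(2)
    for i in range(1000):
        theta = hmc_step(jax.random.PRNGKey(i), theta)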

DES Y1 posterior, jax-cosmo HMC vs Cobaya MH
(credit: Joe Zuntz)