Software Projects

  • galtab is a key Python package that enabled much of the work in my PhD thesis by pretabulating galaxy placeholders to improve the prediction efficiency of my Counts-in-Cylinders estimator.

  • JaxTabCorr is a Python package that integrates classes from TabCorr and halotools into a differentiable prediction framework, made possible by JAX's automatic differentiation.

  • mocksurvey is a Python package for constructing mock galaxy catalogs and performing mock surveys seeded from the UniverseMachine empirical model.

  • I am a contributor to halotools, which is a Python package that provides a wide array of models of the galaxy-halo connection.

Data Products

Mock Galaxy Catalogs

You can download my mock catalogs for PFS here (or here for the original May 2020 version).

My Data Science Projects

I have led the development of the following data science projects during my postdoc at Argonne National Lab. Continue scrolling to learn about the class of astrophysical models they are being applied to.

kdescent

Fitting a population model with many parameters to multi-dimensional demographics is a notoriously difficult machine learning problem. kdescent provides a flexible stochastic gradient descent solution. kdescent draws a very small, random "mini-batch" of the training data (n = 20 in this example) and constructs a kernel density estimate (KDE) of the distribution using these 20 points as kernel centers, as shown by the example below.


Figure 1: Each panel shows example training data, colored by a kernel weight. The kernel centers were chosen randomly from the training data. We adopt a Gaussian kernel with Scott's rule bandwidth (stretched according to the inverse principal component transformation).


Figure 2: The left panel is colored by the number density of the training data in small spatial bins. The right panel approximately reproduces the population distribution from knowledge of only 20 points, by averaging the twenty KDE kernels shown above.
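
The construction behind Figures 1 and 2 is simple enough to sketch in a few lines of JAX. The snippet below is a minimal illustration of the idea rather than kdescent's actual API: the training data are stand-ins, the `gaussian_kde_density` helper is hypothetical, and it uses an isotropic Scott's-rule bandwidth instead of the PCA-stretched version described in the caption.

```python
import jax
import jax.numpy as jnp

def gaussian_kde_density(x, centers, bandwidth):
    """Average of Gaussian kernels centered on the mini-batch points."""
    d = centers.shape[1]
    sq_dists = jnp.sum((x[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    norm = (2.0 * jnp.pi * bandwidth**2) ** (d / 2.0)
    return jnp.mean(jnp.exp(-0.5 * sq_dists / bandwidth**2) / norm, axis=1)

# Stand-in training data: 10,000 points in 2D
key = jax.random.PRNGKey(0)
training_data = jax.random.multivariate_normal(
    key, jnp.zeros(2), jnp.eye(2), shape=(10_000,))

# Draw the n = 20 mini-batch and use those points as the kernel centers
idx = jax.random.choice(
    jax.random.PRNGKey(1), training_data.shape[0], (20,), replace=False)
centers = training_data[idx]

# Scott's rule bandwidth factor, n**(-1 / (d + 4)); isotropic for simplicity
n, d = centers.shape
bandwidth = n ** (-1.0 / (d + 4))

density_estimate = gaussian_kde_density(training_data, centers, bandwidth)
```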

Now imagine we are modeling this population (for testing, we will simply draw samples from a 2D multivariate normal distribution, which means we are fitting five parameters: two means, two variances, and one correlation coefficient). Thanks to the JAX implementation of kdescent, we can fit these parameters using a gradient descent algorithm, such as Adam, defining the loss as the mean-squared error of the number density within each kernel. By re-drawing our kernels at each iteration, we (a) fairly probe the training data and (b) avoid getting stuck in local minima. The animation below shows this five-parameter fit in action.


Figure 3: Animation of 200 iterations of the Adam gradient descent algorithm, using the loss function described above. The left panel shows the parameters converging to their known true values. The right panel shows the model distribution moving, stretching, and rotating until it finally converges upon the correct distribution.
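
A minimal version of this training loop is sketched below, assuming optax for the Adam updates. The toy model, its parameterization, and the learning rate are illustrative choices of mine, not kdescent internals; the essential ingredients are the kernel centers re-drawn each iteration and the kernel-weighted number densities compared in the loss.

```python
import jax
import jax.numpy as jnp
import optax

def kernel_densities(points, centers, bandwidth):
    """Mean Gaussian kernel weight that each center assigns to `points`."""
    sq = jnp.sum((centers[:, None, :] - points[None, :, :]) ** 2, axis=-1)
    return jnp.mean(jnp.exp(-0.5 * sq / bandwidth**2), axis=1)

def model_samples(params, key, n=10_000):
    """Toy model: a 2D normal with five free parameters."""
    mean, log_sig, arctanh_rho = params[:2], params[2:4], params[4]
    sig = jnp.exp(log_sig)
    rho = jnp.tanh(arctanh_rho)  # keeps the correlation inside (-1, 1)
    cov = jnp.array([[sig[0]**2, rho * sig[0] * sig[1]],
                     [rho * sig[0] * sig[1], sig[1]**2]])
    return jax.random.multivariate_normal(key, mean, cov, shape=(n,))

def loss(params, key, train):
    key_centers, key_model = jax.random.split(key)
    # Re-draw the 20 kernel centers from the training data each iteration
    idx = jax.random.choice(key_centers, train.shape[0], (20,), replace=False)
    centers = train[idx]
    bandwidth = 20.0 ** (-1.0 / 6.0)  # Scott's rule for n = 20, d = 2
    target = kernel_densities(train, centers, bandwidth)
    model = kernel_densities(
        model_samples(params, key_model), centers, bandwidth)
    return jnp.mean((model - target) ** 2)  # MSE within each kernel

# Hypothetical training set with nontrivial true parameters
true_cov = jnp.array([[1.0, 0.3], [0.3, 0.5]])
train = jax.random.multivariate_normal(
    jax.random.PRNGKey(42), jnp.array([1.0, -0.5]), true_cov, shape=(10_000,))

params = jnp.zeros(5)
optimizer = optax.adam(learning_rate=0.05)
opt_state = optimizer.init(params)
key = jax.random.PRNGKey(0)
for step in range(200):
    key, subkey = jax.random.split(key)
    grads = jax.grad(loss)(params, subkey, train)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
```

Re-drawing `centers` inside `loss` is what makes the descent stochastic, providing properties (a) and (b) above.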

Particle Swarm Optimization

I am actively working on a particle swarm optimization (PSO) implementation suitable for massively parallel parameter optimization. Because PSO requires no gradient calculations, it works for non-differentiable models, and its many independent objective evaluations per iteration make it fast and massively parallelizable. The current state of this project can be viewed in the mpipso package.


Figure 4: Proof of concept of a swarm of three particles converging upon the global minimum of a modified Himmelblau objective function.
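
The core PSO update is only a few lines. Below is a minimal serial sketch, not mpipso's API: it uses the standard (unmodified) Himmelblau function and hand-picked inertia and acceleration coefficients, and it simply loops over particles that mpipso would instead evaluate concurrently across MPI ranks.

```python
import numpy as np

def himmelblau(p):
    """Standard Himmelblau function: four global minima where f = 0."""
    x, y = p[..., 0], p[..., 1]
    return (x**2 + y - 11.0) ** 2 + (x + y**2 - 7.0) ** 2

rng = np.random.default_rng(0)
n_particles, n_dim = 3, 2
w, c1, c2 = 0.7, 1.5, 1.5  # inertia, cognitive, and social coefficients

pos = rng.uniform(-5.0, 5.0, (n_particles, n_dim))
vel = np.zeros((n_particles, n_dim))
pbest, pbest_val = pos.copy(), himmelblau(pos)  # per-particle best so far
gbest = pbest[np.argmin(pbest_val)]             # swarm-wide best so far

for step in range(200):
    r1, r2 = rng.random((2, n_particles, n_dim))
    # Pull each particle toward its own best and the swarm's best position
    vel = w * vel + c1 * r1 * (pbest - pos) + c2 * r2 * (gbest - pos)
    pos = pos + vel
    val = himmelblau(pos)  # the step a parallel implementation distributes
    improved = val < pbest_val
    pbest[improved], pbest_val[improved] = pos[improved], val[improved]
    gbest = pbest[np.argmin(pbest_val)]
```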

Parallel Gradients

The JAX Python library allows us to create extremely powerful, fast, differentiable models without requiring excessive development time. However, efficiently calculating a gradient in parallel is not trivial. This is a huge hurdle for big-data problems, where our data must be distributed across several nodes, and it prevents us from taking advantage of the full processing power of each node. Therefore, I built a framework to simplify this process and implemented it in the multidiff package. In brief, it allows users to define functions that compute linear statistics summed over the data on each node. Not only is the resulting global statistic computed automatically, but the chain rule is preserved by exploiting the vector-Jacobian product, allowing us to propagate the derivatives to any desired loss function.


Figure 5: Subvolume division in which each partial gradient is computed before MPI combines everything with one simple addition reduction.
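
A minimal sketch of this pattern is below, assuming mpi4py and a hypothetical sum-type statistic and loss (this shows the underlying mechanism, not multidiff's actual interface). Each rank computes its partial statistic and records a vector-Jacobian product (VJP) closure; one sum reduction builds the global statistic, and a second one combines the per-rank partial gradients.

```python
import jax
import jax.numpy as jnp
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

def local_stat(params, shard):
    """A sum-type statistic over this rank's data shard (hypothetical)."""
    return jnp.sum(jnp.tanh(shard * params[0] + params[1]), axis=0)

def loss_from_stat(global_stat):
    """Loss defined on the *global* statistic (hypothetical target)."""
    return jnp.sum((global_stat - 1.0) ** 2)

params = jnp.array([0.5, -0.2])
shard = jax.random.normal(jax.random.PRNGKey(comm.Get_rank()), (1000, 3))

# Forward pass: each rank computes its partial statistic, then a single
# sum reduction assembles the global statistic
stat_local, vjp_fun = jax.vjp(lambda p: local_stat(p, shard), params)
stat_global = comm.allreduce(np.asarray(stat_local), op=MPI.SUM)

# Backward pass: d(loss)/d(statistic) is identical on every rank, so each
# rank pulls it back through its own shard with a vector-Jacobian product,
# and a second sum reduction yields the full gradient of the loss
dloss_dstat = jax.grad(loss_from_stat)(jnp.asarray(stat_global))
grad_local, = vjp_fun(dloss_dstat)
grad = comm.allreduce(np.asarray(grad_local), op=MPI.SUM)
```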

Astrophysics

Publications: See my papers on ADS

Postdoc at Argonne

I work closely with Andrew Hearin and others to develop models of the galaxy-halo connection, implemented with Python's JAX library to enable GPU acceleration and automatic differentiation. I am focusing on improving the scalability of our models to extremely large datasets by designing a framework that performs distributed parallel computation while seamlessly preserving the advantages of JAX. I plan to use this framework to make self-consistent mock observations of cosmological simulations, thereby minimizing biases in the joint inference of cosmology and galaxy formation physics.

PhD Thesis

Illuminating and Tabulating the Galaxy-Halo Connection

Part I: Illuminated the UniverseMachine to construct PFS mock catalogs

Using UniverseMachine as a model and UltraVISTA photometry as training data, I created a mock galaxy catalog specifically tailored to making predictions for the upcoming PFS survey. Using this mock, I published a paper demonstrating that future extensions of the PFS survey should prioritize increasing the survey area to best achieve its science goals. This mock and the methods used to create it are publicly available.

Part II: Tabulated statistical estimators to be fast, precise, and differentiable

The galaxy-halo connection is typically analyzed via Markov chain Monte Carlo (MCMC) sampling of parameter space in order to place constraints on models. However, this process is slowed down by the stochastic nature of halo occupation distribution (HOD) models. I have improved the efficiency of this process with two open-source projects:

  • JaxTabCorr, in which I have rewritten parts of the TabCorr and halotools packages to replace certain NumPy operations with equivalent JAX operations. It can be used to calculate differentiable predictions of two-point correlation functions (see the sketch after this list), which will improve the scalability of model inference as we push to larger and larger parameter spaces.
  • galtab, in which I have implemented a tabulation-accelerated statistic called Counts-in-Cylinders (CiC) that captures higher-order clustering information beyond that of the two-point correlation function. This code is also differentiable and automatically GPU-accelerated. I have published a paper presenting this code, as well as the new HOD constraints it made possible, utilizing the Early Data Release from the Dark Energy Spectroscopic Instrument (DESI).
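
To make the tabulation idea concrete, here is a minimal sketch with made-up numbers and a toy central-galaxy HOD (not TabCorr's or JaxTabCorr's actual data structures). Once the per-halo-bin clustering has been pretabulated, the model prediction is a smooth function of the HOD parameters, so JAX can differentiate straight through it:

```python
import jax
import jax.numpy as jnp

log_mhalo = jnp.linspace(11.0, 15.0, 40)  # halo-mass bin centers
r_bins = jnp.linspace(0.1, 20.0, 15)      # separation bins
# Pretabulated per-halo-bin clustering (a fake placeholder table here)
xi_table = jnp.exp(-r_bins[None, :] / (1.0 + log_mhalo[:, None] - 11.0))

def mean_occupation(params, log_m):
    """Toy central-galaxy HOD: a smooth step in log halo mass."""
    log_mmin, sigma = params
    return 0.5 * (1.0 + jax.scipy.special.erf((log_m - log_mmin) / sigma))

def predict_xi(params):
    """Occupation-weighted average of the pretabulated clustering."""
    n = mean_occupation(params, log_mhalo)
    weights = n / jnp.sum(n)
    return weights @ xi_table

params = jnp.array([12.5, 0.4])
xi = predict_xi(params)
# The full Jacobian d(xi)/d(params) comes for free from autodiff:
jac = jax.jacobian(predict_xi)(params)
```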