microJAX
microJAX is a GPU-aware, auto-differentiable microlensing toolkit built on top of JAX. The library combines a GPU-optimized image-centered inverse-ray shooting method with JAX's XLA acceleration to deliver fast and accurate magnifications and gradients for binary and triple lens systems.
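microJAX obtains gradients by differentiating through its polynomial solvers and ray shooting. The underlying mechanism is ordinary JAX autodiff. The minimal sketch below shows that mechanism on the closed-form single-lens, point-source (Paczyński) light curve, A(u) = (u² + 2) / (u √(u² + 4)), rather than on a microJAX routine. The helpers `paczynski_magnification` and `single_lens_lightcurve` are illustrative names and are not part of the microJAX API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def paczynski_magnification(u):
    """Point-source, point-lens magnification A(u) = (u^2 + 2) / (u * sqrt(u^2 + 4))."""
    return (u**2 + 2.0) / (u * jnp.sqrt(u**2 + 4.0))


def single_lens_lightcurve(params, t):
    """Magnification along a straight trajectory; params = (t0, u0, tE)."""
    t0, u0, tE = params
    tau = (t - t0) / tE
    u = jnp.sqrt(u0**2 + tau**2)  # lens-source separation in Einstein units
    return paczynski_magnification(u)


t = jnp.linspace(-30.0, 30.0, 200)
params = jnp.array([0.0, 0.1, 30.0])  # t0 [days], u0, tE [days]

# Jacobian of the whole light curve with respect to (t0, u0, tE): shape (200, 3)
jac = jax.jacfwd(single_lens_lightcurve)(params, t)

# Gradient of a scalar least-squares loss against synthetic "data"
data = single_lens_lightcurve(jnp.array([0.5, 0.12, 28.0]), t)
loss = lambda p: jnp.sum((single_lens_lightcurve(p, t) - data) ** 2)
loss_grad = jax.grad(loss)(params)
```

The same `jax.grad` / `jax.jacfwd` calls apply unchanged to any function built from JAX primitives, which is what makes gradient-based optimizers and HMC usable on top of a differentiable magnification model.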
Highlights
Accelerated finite sources – image-centered ray shooting (ICRS) with CUDA-ready batching.
Differentiable everywhere – gradients flow through polynomial solvers and ICRS for use in optimization and inference (e.g. HMC/VI) workflows.
Other utilities – helpers for higher-order microlensing effects such as orbital parallax, limb darkening, custom source motion, and more.
Composable likelihoods – analytic marginalisation utilities for inference.
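The interface of microJAX's marginalisation utilities may differ from what follows. This is a hedged sketch of the standard idea only. Observed microlensing fluxes are commonly modelled as F(t) = F_s A(t) + F_b, which is linear in the source flux F_s and the blend flux F_b. With independent Gaussian errors and a flat, unit-density prior on (F_s, F_b), the integral over both fluxes is Gaussian and has a closed form:

log L = -χ²_min / 2 - (1/2) log det M + log 2π - (1/2) Σ_i log(2π σ_i²),

where M = Xᵀ W X is the weighted normal matrix of the design matrix X = [A, 1]. The function name `marginal_loglike_linear_flux` is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def marginal_loglike_linear_flux(mag, flux, flux_err):
    """Log-likelihood of flux = F_s * mag + F_b, marginalised over (F_s, F_b).

    Assumes independent Gaussian errors and a flat, unit-density prior on the two
    flux parameters, so the integral over them is Gaussian and exact.
    """
    X = jnp.stack([mag, jnp.ones_like(mag)], axis=1)  # design matrix [A, 1], shape (N, 2)
    w = 1.0 / flux_err**2                             # inverse variances
    M = X.T @ (w[:, None] * X)                        # normal matrix X^T W X, shape (2, 2)
    b = X.T @ (w * flux)                              # X^T W y
    beta_hat = jnp.linalg.solve(M, b)                 # best-fit (F_s, F_b)
    chi2_min = jnp.sum(w * (flux - X @ beta_hat) ** 2)
    _, logdet_M = jnp.linalg.slogdet(M)
    k = X.shape[1]                                    # number of linear parameters
    return (
        -0.5 * chi2_min
        - 0.5 * logdet_M
        + 0.5 * k * jnp.log(2.0 * jnp.pi)
        - 0.5 * jnp.sum(jnp.log(2.0 * jnp.pi * flux_err**2))
    )


# Synthetic single-lens magnification curve (illustrative values)
tau = jnp.linspace(-1.5, 1.5, 300)
u = jnp.sqrt(0.1**2 + tau**2)
mag = (u**2 + 2.0) / (u * jnp.sqrt(u**2 + 4.0))
flux_err = jnp.full(mag.shape, 0.05)
key = jax.random.PRNGKey(0)
flux = 1.0 * mag + 0.3 + flux_err * jax.random.normal(key, mag.shape)

logL = marginal_loglike_linear_flux(mag, flux, flux_err)
# Gradient with respect to the magnification curve, which can be chained into a lens model
dlogL_dmag = jax.grad(marginal_loglike_linear_flux)(mag, flux, flux_err)
```

Marginalising the fluxes analytically removes two nuisance parameters from the sampler. The whole expression remains differentiable, so it composes with a differentiable magnification model.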
Quick peek
Note: mag_binary also runs on CPU, but it is much slower there than on a GPU.
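Before running the example, it can help to confirm which backend JAX has detected. The check below uses only JAX's public `jax.default_backend()` and `jax.devices()` and involves nothing microJAX-specific. The printed message is an illustrative suggestion.

```python
import jax

# Default platform JAX will use, e.g. "cpu", "gpu", or "tpu"
backend = jax.default_backend()
devices = jax.devices()
print(f"JAX backend: {backend}; devices: {devices}")

if backend == "cpu":
    print("No accelerator detected: mag_binary will run, but slowly.")
```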
import jax
import jax.numpy as jnp
from microjax.point_source import mag_point_source
from microjax.inverse_ray.lightcurve import mag_binary
from microjax.point_source import critical_and_caustic_curves
jax.config.update("jax_enable_x64", True)
# Binary-lens parameters
s, q = 1.0, 0.01 # separation and mass ratio (m2/m1)
rho = 0.02 # source radius (Einstein units)
tE, u0 = 30.0, 0.0 # Einstein time [days], impact parameter
alpha = jnp.deg2rad(10.0) # trajectory angle in radians
t0 = 0.0
# Source trajectory
N_points = 1000
t = t0 + jnp.linspace(-tE, tE, N_points)
tau = (t - t0)/tE
y1 = -u0*jnp.sin(alpha) + tau*jnp.cos(alpha)
y2 = u0*jnp.cos(alpha) + tau*jnp.sin(alpha)
w_points = jnp.array(y1 + y2 * 1j, dtype=complex)
# Point-source and Extended-source magnifications (binary lens)
mag_p = mag_point_source(w_points, s=s, q=q, nlenses=2)
mag_ext = mag_binary(w_points, rho, s=s, q=q)
# Critical and caustic curves
crit, cau = critical_and_caustic_curves(s=s, q=q, nlenses=2, npts=1000)
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].plot(t, mag_p, 'k--', label='Point Source')
ax[0].plot(t, mag_ext, 'r-', label='Extended Source')
ax[0].set_xlabel('Time (days)')
ax[0].set_ylabel('Magnification')
ax[0].set_yscale('log')
ax[0].legend()
ax[1].plot(cau.real, cau.imag, 'r.')
ax[1].plot(w_points.real, w_points.imag, 'b-')
ax[1].axis('equal')
plt.show()
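The trajectory lines in the example can be wrapped in a small reusable function. `source_trajectory` is a hypothetical helper written for this sketch and is not part of microJAX. It uses the same parametrisation as above. Because the u0 and τ terms point along orthogonal unit vectors, |w|² = u0² + τ². The closest approach is therefore u0 at t = t0, and the distance from the coordinate origin does not depend on alpha. The trajectory angle only matters once the lens is not axially symmetric, as for the binary lens in the example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def source_trajectory(t, t0, tE, u0, alpha):
    """Complex source position w = y1 + i*y2 for a rectilinear trajectory."""
    tau = (t - t0) / tE
    y1 = -u0 * jnp.sin(alpha) + tau * jnp.cos(alpha)
    y2 = u0 * jnp.cos(alpha) + tau * jnp.sin(alpha)
    return y1 + 1j * y2


t = jnp.linspace(-30.0, 30.0, 1001)
w = source_trajectory(t, 0.0, 30.0, 0.1, jnp.deg2rad(10.0))

# |w|^2 = u0^2 + tau^2, so the minimum separation is u0, reached at t = t0
u_min = jnp.min(jnp.abs(w))

# Distance from the origin is independent of alpha, so this derivative vanishes
dist_alpha_grad = jax.grad(
    lambda a: jnp.abs(source_trajectory(5.0, 0.0, 30.0, 0.1, a))
)(0.3)
```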
Use the sections below to install the package, explore worked examples, and dig into the API.
Citing microJAX
If you use microJAX in academic work, please cite the methods paper and the Zenodo software archive:
Miyazaki, S., & Kawahara, H. 2025, ApJ, 994, 144, doi:10.3847/1538-4357/ae1005
microJAX software archive (Zenodo): doi:10.5281/zenodo.17247892
BibTeX
@ARTICLE{2025ApJ...994..144M,
author = {{Miyazaki}, Shota and {Kawahara}, Hajime},
title = {microJAX: A Differentiable Framework for Microlensing Modeling with GPU-accelerated Image-centered Ray Shooting},
journal = {\apj},
year = 2025,
month = dec,
volume = {994},
number = {2},
eid = {144},
pages = {144},
doi = {10.3847/1538-4357/ae1005},
archivePrefix = {arXiv},
eprint = {2510.02639},
primaryClass = {astro-ph.EP},
adsurl = {https://ui.adsabs.harvard.edu/abs/2025ApJ...994..144M},
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
@software{microjax_zenodo_17247892,
author = {Miyazaki, Shota},
title = {microJAX},
year = {2025},
publisher = {Zenodo},
doi = {10.5281/zenodo.17247892},
url = {https://doi.org/10.5281/zenodo.17247892}
}
License & Attribution
Copyright 2025, Contributors
Shota Miyazaki (@ShotaMiyazaki94, maintainer)
Hajime Kawahara (@HajimeKawahara, co-maintainer)