Usage Guide
This chapter walks through the most common workflows in microJAX and explains what each knob does. The goal is to provide copy-and-pasteable snippets along with the context required to adapt them to your own microlensing problem.
Common setup
Start every session by enabling 64-bit mode and importing the building blocks you intend to use. Keeping everything in one place makes it easier to reuse the same configuration across notebooks or scripts:
import jax
jax.config.update("jax_enable_x64", True)  # set before any arrays are created; stabilises the polynomial solver
import jax.numpy as jnp
from microjax.point_source import mag_point_source
from microjax.inverse_ray.lightcurve import mag_binary, mag_triple
The snippets below assume this cell has already been run. If you restart your Python session, rerun it before continuing.
Point-source magnification
Use mag_point_source when the source can be treated as infinitesimally small and you need fast magnifications for one to three lenses.
Step-by-step
1. Assemble the complex source coordinates. The real part is the x-position and the imaginary part is the y-position, both in Einstein radii.
2. Specify the lens configuration via nlenses and the associated parameters.
3. Call mag_point_source; the function broadcasts across any leading axes of w, so batches are handled automatically.
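For a single lens the point-source magnification has a closed form, A(u) = (u² + 2) / (u √(u² + 4)) with u = |w|, which makes a handy independent check of the nlenses=1 case and illustrates what broadcasting across leading axes means. The NumPy sketch below assumes the lens sits at the origin of the source-plane coordinates; that convention and the helper name point_lens_magnification are assumptions for illustration, not microJAX API.

```python
import numpy as np

def point_lens_magnification(w):
    """Point-source magnification of a single lens at the origin (hypothetical helper).

    w: complex source position(s) in Einstein radii, any shape. The formula is
    applied elementwise, so leading batch axes are preserved in the output.
    """
    u = np.abs(np.asarray(w))
    return (u**2 + 2.0) / (u * np.sqrt(u**2 + 4.0))

# A 2 x 2 batch of source positions; the output has the same shape.
w = np.array([[0.00 + 0.10j, 0.05 + 0.05j],
              [-0.10 + 0.02j, 1.00 + 0.00j]])
mu = point_lens_magnification(w)
print(mu.shape)  # (2, 2)
```

At u = 1 the formula gives 3/√5 ≈ 1.342, a convenient spot value to compare against.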
Example:
w = jnp.array([
0.00 + 0.10j,
0.05 + 0.05j,
-0.10 + 0.02j,
])
mu = mag_point_source(w, nlenses=2, s=1.0, q=0.01)
print("Magnification per sample:", mu)
nlenses=3 introduces a third body. Provide the additional keywords q3 (mass ratio of lens 3 to lens 1), r3 (distance between lens 1 and lens 3), and psi (position angle of lens 3, in radians). All other keyword arguments are fully broadcastable and can be supplied as arrays if you want to sweep over a grid of lens parameters.
Finite-source binary lenses
mag_binary computes finite-source light curves by combining a fast hexadecapole approximation with full inverse-ray integrations when required.
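To show what the hexadecapole approximation does, here is a NumPy sketch of the standard estimator from Gould (2008) for a uniform (not limb-darkened) source. It samples the point-source magnification at the source centre, at four points on the limb, at four points rotated by 45°, and at four points at half the radius, then combines them into A0 + A2ρ²/2 + A4ρ⁴/3. This illustrates the general technique only; it is not microJAX's internal implementation, and hexadecapole_magnification is a hypothetical name.

```python
import numpy as np

def hexadecapole_magnification(mag_fn, w, rho):
    """Hexadecapole finite-source estimate for a uniform source (sketch).

    mag_fn: point-source magnification accepting complex source positions.
    w: complex source centre; rho: source radius in Einstein units.
    """
    a0 = mag_fn(w)
    plus = np.exp(1j * np.pi / 2 * np.arange(4))                   # 0, 90, 180, 270 deg
    cross = np.exp(1j * (np.pi / 4 + np.pi / 2 * np.arange(4)))    # rotated by 45 deg
    a_rho_plus = np.mean(mag_fn(w + rho * plus))
    a_rho_cross = np.mean(mag_fn(w + rho * cross))
    a_half_plus = np.mean(mag_fn(w + 0.5 * rho * plus))
    # Second- and fourth-order coefficients (already multiplied by rho^2, rho^4).
    a2_rho2 = (16.0 * a_half_plus - a_rho_plus) / 3.0 - 5.0 * a0
    a4_rho4 = 0.5 * (a_rho_plus + a_rho_cross) - a0 - a2_rho2
    # Disc averages of r^2 and r^4 over a uniform source are rho^2/2 and rho^4/3.
    return a0 + a2_rho2 / 2.0 + a4_rho4 / 3.0
```

The estimator is exact whenever the magnification is a polynomial of degree four or less around the source centre, and degrades near caustics, where higher-order terms matter; that is where full inverse-ray integrations become necessary.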
1. Build the trajectory
The snippet below constructs a standard rectilinear trajectory. Feel free to replace it with your own sampler if you need orbital motion or parallax.
tE = 40.0 # Einstein time (days)
u0 = 0.05 # impact parameter
alpha = jnp.deg2rad(60.0) # trajectory angle in radians
t0 = 0.0 # time of closest approach
rho = 0.01 # source radius in Einstein units
t = t0 + jnp.linspace(-2 * tE, 2 * tE, 1024)
tau = (t - t0) / tE
y1 = -u0 * jnp.sin(alpha) + tau * jnp.cos(alpha)
y2 = u0 * jnp.cos(alpha) + tau * jnp.sin(alpha)
w_points = jnp.array(y1 + 1j * y2, dtype=complex) # source trajectory
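If you build many trajectories, it can help to wrap the same equations in a function. The sketch below (rectilinear_trajectory is a hypothetical name) uses NumPy so it runs anywhere; replacing np with jnp gives a JAX-traceable version. A quick consistency check follows from the geometry: with this convention |w|² = u0² + τ², so the closest approach to the lens is exactly u0 at t = t0.

```python
import numpy as np

def rectilinear_trajectory(t, t0, tE, u0, alpha):
    """Complex source positions for straight-line motion (same convention as the snippet above)."""
    tau = (np.asarray(t, dtype=float) - t0) / tE
    y1 = -u0 * np.sin(alpha) + tau * np.cos(alpha)
    y2 = u0 * np.cos(alpha) + tau * np.sin(alpha)
    return y1 + 1j * y2

t = np.linspace(-80.0, 80.0, 1025)  # includes t = t0 = 0 exactly
w_check = rectilinear_trajectory(t, t0=0.0, tE=40.0, u0=0.05, alpha=np.deg2rad(60.0))
```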
2. Evaluate the magnification
Call mag_binary with the trajectory, source radius, and lens parameters. To start, stick with the defaults for the optional arguments and only adjust them if you hit performance limits.
s = 0.95 # projected separation
q = 5e-4 # mass ratio (m2/m1)
mags = mag_binary(w_points, rho, s=s, q=q)
mag_binary returns magnifications aligned with the input trajectory. If you need fluxes, multiply by the intrinsic source flux and add blends or baselines as appropriate.
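Because the flux model F(t) = F_s · A(t) + F_b is linear in the source flux F_s and blend flux F_b, both can be solved for directly once the magnifications are known. The sketch below is a generic weighted least-squares fit with NumPy; fit_linear_fluxes is a hypothetical helper, not part of microJAX, and in practice you would pass it np.asarray(mags) from mag_binary together with your observed fluxes and their errors.

```python
import numpy as np

def fit_linear_fluxes(mags, flux, flux_err):
    """Weighted least squares for F(t) = F_s * A(t) + F_b (hypothetical helper).

    mags: model magnifications A(t); flux, flux_err: observed fluxes and
    1-sigma errors. Returns the best-fit source flux F_s and blend flux F_b.
    """
    mags = np.asarray(mags, dtype=float)
    inv_err = 1.0 / np.asarray(flux_err, dtype=float)
    # Design matrix columns: A(t) multiplies F_s, a constant column carries F_b.
    design = np.column_stack([mags, np.ones_like(mags)]) * inv_err[:, None]
    target = np.asarray(flux, dtype=float) * inv_err
    (f_source, f_blend), *_ = np.linalg.lstsq(design, target, rcond=None)
    return f_source, f_blend
```

Solving for the fluxes at every likelihood call removes two parameters from the sampler, which is the same motivation behind the likelihood tools described under best practices.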
Fine-tuning parameters
r_resolution / th_resolution
Set the number of grid divisions in the radial and angular directions for the inverse-ray shooting method. Increasing these values improves the accuracy of the magnification calculation but also raises computational and memory costs on GPUs, so adjust them according to your accuracy requirements and hardware limits.
MAX_FULL_CALLS
Sets the maximum number of magnification points computed with the image-centered ray-shooting (ICRS) method, i.e. an upper limit on the points that receive full finite-source calculations. The remaining points are evaluated with the hexadecapole approximation.
chunk_size
Controls how many points the ICRS method processes in parallel via jax.vmap. A larger value can improve GPU utilization but may exceed device memory and cause out-of-memory errors; smaller values are safer but may slow down the computation. Tune it to the memory capacity of your GPU.
Nlimb
Sets the number of source limb points used to construct the annular sectors on the lens plane in which the ray-shooting integrations are performed. In most cases you do not need to change this value; adjust it only if catastrophic errors appear in the magnification calculations.
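The accuracy/cost trade-off behind r_resolution and th_resolution can be seen in a toy version of inverse-ray shooting. The NumPy sketch below handles a single point lens only: it shoots rays from the midpoints of a polar grid in the image plane, maps them through the lens equation β = θ − 1/θ̄, and measures the image area that lands inside a uniform source disc. It is deliberately simpler than microJAX's ICRS (one fixed annulus instead of image-centered sectors built from limb points), and every name in it is hypothetical. The exact uniform-source value, the point-source magnification averaged over the disc, serves as the reference.

```python
import numpy as np

def point_lens_mag(u):
    """Point-source single-lens magnification for separations u > 0."""
    return (u**2 + 2.0) / (u * np.sqrt(u**2 + 4.0))

def disc_average_mag(w, rho, n_r=400, n_th=400):
    """Reference: point-source magnification averaged over the uniform source disc."""
    dr, dth = rho / n_r, 2.0 * np.pi / n_th
    r = (np.arange(n_r) + 0.5) * dr
    th = (np.arange(n_th) + 0.5) * dth
    rr, tt = np.meshgrid(r, th, indexing="ij")
    weights = rr * dr * dth                      # polar area element r dr dth
    beta = w + rr * np.exp(1j * tt)              # sample points inside the disc
    return np.sum(point_lens_mag(np.abs(beta)) * weights) / np.sum(weights)

def ray_shoot_single_lens(w, rho, n_r, n_th, r_min=0.3, r_max=2.5):
    """Toy inverse-ray shooting for one point lens at the origin.

    Rays sit at the midpoints of an n_r x n_th polar grid covering the annulus
    r_min < |theta| < r_max of the image plane, which must enclose every image.
    """
    dr, dth = (r_max - r_min) / n_r, 2.0 * np.pi / n_th
    r = r_min + (np.arange(n_r) + 0.5) * dr
    th = (np.arange(n_th) + 0.5) * dth
    rr, tt = np.meshgrid(r, th, indexing="ij")
    theta = rr * np.exp(1j * tt)
    beta = theta - 1.0 / np.conj(theta)          # single-lens equation
    inside = np.abs(beta - w) <= rho             # rays that hit the source disc
    image_area = np.sum(rr[inside]) * dr * dth
    return image_area / (np.pi * rho**2)

w, rho = 0.5 + 0.0j, 0.1
reference = disc_average_mag(w, rho)
for n_r, n_th in [(80, 160), (800, 1600)]:
    estimate = ray_shoot_single_lens(w, rho, n_r, n_th)
    print(n_r, n_th, "relative error:", abs(estimate / reference - 1.0))
```

The finer grid costs 100 times more rays and memory, which mirrors why larger resolution settings run into GPU limits.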
Triple lenses
Triple-lens finite-source calculations are handled by mag_triple. The inputs mirror the binary API, but you must describe the third body explicitly.
mags_triple = mag_triple(w_points, rho,
s=1.10, # separation between 1st and 2nd lenses
q=0.02, # mass ratio (m2/m1)
q3=0.50, # mass ratio (m3/m1)
r3=0.60, # separation between the center of mass of m1/m2 and m3
psi=jnp.deg2rad(210.0) # angle of 3rd lens axis in radians
)
Guidelines:
Start with the same trajectory used for the binary case; only the lens system changes.
psi is measured counter-clockwise from the lens 1–2 axis.
Autodiff and jit
All magnification routines are differentiable. Wrapping them in jax.jit gives you compiled performance, and jax.jacfwd provides forward-mode derivatives for inference.
from jax import jacfwd, jit

def forward_model(q):
    # w_points, rho and s come from the binary-lens example above
    mags = mag_binary(w_points, rho, s=s, q=q)
    return mags  # replace with instrument model if needed

forward_jit = jit(forward_model)
J = jacfwd(forward_jit)(q)  # d(mags)/dq, one entry per trajectory point
Note: Reverse-mode automatic differentiation in microJAX is still under development because of memory-handling issues, so use forward mode (jax.jacfwd) for now.
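Forward-mode derivatives can be spot-checked against central finite differences. The NumPy sketch below does this for the closed-form single-lens magnification, whose derivative dA/du = −8 / (u²(u² + 4)^(3/2)) is known exactly. The helper central_difference is a hypothetical name; pointed at a scalar-parameter model such as forward_model above, it would need a step size chosen relative to the parameter's scale (q is of order 1e-4). Finite differences also rely on 64-bit precision, another reason to keep x64 enabled.

```python
import numpy as np

def paczynski(u):
    """Point-source single-lens magnification."""
    return (u**2 + 2.0) / (u * np.sqrt(u**2 + 4.0))

def dpaczynski_du(u):
    """Exact derivative of the single-lens magnification with respect to u."""
    return -8.0 / (u**2 * (u**2 + 4.0) ** 1.5)

def central_difference(f, x, h=1e-6):
    """Second-order central finite-difference derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

u = np.array([0.05, 0.3, 1.0, 3.0])
fd = central_difference(paczynski, u)
exact = dpaczynski_du(u)
rel_err = np.max(np.abs(fd - exact) / np.abs(exact))
print("max relative error:", rel_err)
```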
Trajectory helpers
For trajectories beyond straight lines, the microjax.trajectory package provides composable pieces:
microjax.trajectory.parallax – annual parallax terms.
These components return arrays compatible with the w_points input used above, so you can drop them into mag_binary or mag_triple without further changes.
Best practices
Keep 64-bit mode enabled for production runs; it significantly improves the stability of implicit differentiation through the polynomial solver.
Use microjax.likelihood to marginalise nuisance flux parameters instead of fitting them manually; this often reduces sampler autocorrelation.