Getting Started
Use this guide to prepare an environment, confirm that JAX detects your accelerator, and run a quick smoke test.
Prerequisites
Python 3.9 or newer with matching jax/jaxlib wheels for your platform (CPU, CUDA, ROCm). Follow the official JAX installation matrix.
Optional plotting stack: matplotlib or seaborn if you plan to run the visualization examples.
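The prerequisites above can be checked from Python before installing anything else. The sketch below uses only the standard library. The helper names installed_version and check_prerequisites are made up for illustration and are not part of microJAX or JAX. It only reports which versions are present; it does not confirm that the jax and jaxlib builds match your accelerator.

```python
import sys
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def check_prerequisites(min_python=(3, 9), required=("jax", "jaxlib"),
                        optional=("matplotlib", "seaborn")):
    """Collect a small report on the interpreter and the relevant packages."""
    return {
        "python_ok": sys.version_info[:2] >= min_python,
        "required": {name: installed_version(name) for name in required},
        "optional": {name: installed_version(name) for name in optional},
    }


if __name__ == "__main__":
    report = check_prerequisites()
    print("Python >= 3.9:", report["python_ok"])
    for group in ("required", "optional"):
        for name, version in report[group].items():
            print(f"{group:8s} {name}: {version or 'not installed'}")
```

A None entry under "required" means the package still has to be installed by following the JAX installation matrix for your platform.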
Installation
Install the latest release from PyPI:
python -m pip install microjaxx
Or work from source:
git clone https://github.com/ShotaMiyazaki94/microjax.git
cd microjax
python -m pip install -e ".[dev]"
The import name remains microjax even though the published wheel is microjaxx.
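Because the distribution name and the import name differ, it can help to look each one up under its own name. The following is a minimal sketch using the standard library; describe_install is a hypothetical helper, not something microJAX provides. It queries package metadata with the distribution name and checks that the module can be found with the import name, without actually importing it.

```python
from importlib import metadata, util


def describe_install(dist_name, import_name):
    """Report the installed distribution version and whether the module is importable."""
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None
    # find_spec locates a top-level module without importing it
    importable = util.find_spec(import_name) is not None
    return {
        "distribution": dist_name,
        "version": version,
        "module": import_name,
        "importable": importable,
    }


if __name__ == "__main__":
    # Distribution name on PyPI vs. name used in import statements
    print(describe_install("microjaxx", "microjax"))
```

If "version" is None while "importable" is True, the module is probably being picked up from somewhere other than an installed package, for example a source checkout on the path.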
Verify the environment
Run the snippet below to confirm that microJAX imports cleanly, JAX can see your devices, and 64-bit mode is enabled for better numerical stability:
import jax
import jax.numpy as jnp
import microjax
from microjax.point_source import mag_point_source
jax.config.update("jax_enable_x64", True) # recommended for microlensing
print("microJAX", microjax.__version__)
print("Devices", jax.devices())
w = jnp.linspace(-0.3, 0.3, 5) + 0.1j
print("Sample magnification", mag_point_source(w, nlenses=2, s=1.0, q=1e-3))
Up next
Usage Guide walks through binary and triple lens examples.
Troubleshooting lists common pitfalls and quick fixes.
API Reference documents every public entry point in detail.