Troubleshooting

The most common issues reported by new users are summarised here together with suggested fixes.

JAX cannot see my GPU

  • Make sure the installed jaxlib is a CUDA or ROCm build that matches your driver version; the official JAX installation matrix lists the supported combinations.

  • Double-check the environment variables XLA_PYTHON_CLIENT_PREALLOCATE and JAX_PLATFORMS. Temporarily setting JAX_PLATFORMS=cuda (or rocm on AMD hardware) forces GPU usage, so JAX reports an error instead of silently falling back to CPU.

  • On multi-user systems, confirm that you have access to the GPU (check with nvidia-smi, or rocm-smi on ROCm).
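
The checks above can be complemented by asking JAX directly which backend it selected. The following is a minimal sketch that uses only jax.devices() and jax.default_backend(); the helper name describe_backend is made up for this example and is not part of microJAX.

```python
import os

import jax


def describe_backend():
    """Collect what JAX sees, plus the environment variables that influence it."""
    devices = jax.devices()
    return {
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in devices],
        "JAX_PLATFORMS": os.environ.get("JAX_PLATFORMS"),
        "XLA_PYTHON_CLIENT_PREALLOCATE": os.environ.get(
            "XLA_PYTHON_CLIENT_PREALLOCATE"
        ),
    }


report = describe_backend()
for key, value in report.items():
    print(f"{key}: {value}")
```

If the reported backend is cpu although a GPU is installed, the jaxlib build or the driver is the most likely culprit.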

mag_binary is slow or runs out of memory

  • mag_binary also runs on CPU, but very slowly; use a GPU whenever one is available.

  • Trim the inverse-ray grid via r_resolution / th_resolution when you do not need the default 1000×1000 sampling; smaller grids cut both runtime and memory pressure. Increase them only when accuracy demands it.

  • Adjust chunk_size to fit your device. Lower values avoid out-of-memory crashes; raise the value gradually if the GPU remains underutilised.

  • Use MAX_FULL_CALLS to cap how many samples fall back to the full image-centred ray shooting routine. Lowering it keeps runtimes bounded, but expect a trade-off in accuracy if many points revert to the hexadecapole approximation.
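
To see why a smaller chunk size lowers peak memory, it can help to look at chunked evaluation in isolation. The sketch below is a generic illustration and not microJAX's implementation: toy_magnification and evaluate_in_chunks are hypothetical stand-ins, and the only assumption is that each point requires a dense grid of intermediate values, so only chunk_size grids are alive at any one time.

```python
import jax
import jax.numpy as jnp


def toy_magnification(w, n_grid=64):
    """Hypothetical per-point kernel that allocates an (n_grid, n_grid) grid."""
    r = jnp.linspace(0.0, 1.0, n_grid)
    th = jnp.linspace(0.0, 2.0 * jnp.pi, n_grid)
    rr, tt = jnp.meshgrid(r, th, indexing="ij")
    return jnp.mean(jnp.cos(w * rr) ** 2 + jnp.sin(tt) ** 2)


def evaluate_in_chunks(points, chunk_size):
    """Evaluate the kernel chunk by chunk: vectorised within a chunk, sequential across chunks."""
    n = points.shape[0]
    pad = (-n) % chunk_size                   # pad so the length divides evenly
    padded = jnp.pad(points, (0, pad))
    chunks = padded.reshape(-1, chunk_size)   # shape (n_chunks, chunk_size)
    out = jax.lax.map(jax.vmap(toy_magnification), chunks)
    return out.reshape(-1)[:n]                # drop the padded entries


points = jnp.linspace(0.0, 1.0, 50)
chunked = evaluate_in_chunks(points, chunk_size=16)
full = jax.vmap(toy_magnification)(points)
print(jnp.max(jnp.abs(chunked - full)))
```

Both paths return the same values; the chunked one trades some parallelism for a smaller working set, which is the trade-off chunk_size controls.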

Gradient computations stall

  • Confirm that jax_enable_x64 is turned on; implicit differentiation through the polynomial solver is numerically sensitive in single precision.

  • Use jax.jit to compile the forward pass before taking gradients; this shortens trace lengths and avoids repeated recompilations.
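
Both suggestions can be combined in a few lines. This is a minimal sketch in which forward is a hypothetical stand-in for your own model function; only jax.config.update, jax.jit and jax.grad are real JAX calls.

```python
import jax

# Enable double precision before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


@jax.jit
def forward(x):
    """Stand-in for an expensive forward pass (hypothetical)."""
    return jnp.sum(jnp.sin(x) ** 2)


# Compile the gradient once; later calls with the same shapes reuse it.
grad_forward = jax.jit(jax.grad(forward))

x = jnp.linspace(0.0, 1.0, 5)
g = grad_forward(x)
print(x.dtype, g.dtype)
```

If the printed dtypes are float32, the x64 flag was set too late or overridden elsewhere in the program.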

Import errors for optional dependencies

microJAX only depends on JAX and NumPy at runtime, but some examples pull in matplotlib or seaborn. Install the plotting stack you need manually before running the demo scripts, for example with python -m pip install matplotlib seaborn.
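
A quick way to find out which optional packages are missing before launching a demo is to probe for them with the standard library. This is a small sketch; missing_modules is a helper name invented for this example.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(["matplotlib", "seaborn"])
if missing:
    print("Missing optional packages:", ", ".join(missing))
    print("Install with: python -m pip install " + " ".join(missing))
else:
    print("All optional plotting packages are available.")
```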

Still stuck?

Open an issue on GitHub with the following information:

  • microJAX version (python -c "import microjax; print(microjax.__version__)")

  • JAX and jaxlib versions and platform (CPU, CUDA, ROCm)

  • A minimal code snippet reproducing the issue
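
The version details can be gathered in one go with a short script like the one below. It is a sketch that reads each module's __version__ attribute, as in the command shown above; the helper names module_version and backend_name are invented for this example.

```python
import importlib
import platform


def module_version(name):
    """Return a module's __version__, or a placeholder if it cannot be imported."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        return "not installed"
    return getattr(module, "__version__", "unknown")


def backend_name():
    """Return the backend JAX selected, or a note if JAX cannot be initialised."""
    try:
        import jax

        return jax.default_backend()
    except Exception as exc:
        return f"unavailable ({type(exc).__name__})"


report = {
    "python": platform.python_version(),
    "microjax": module_version("microjax"),
    "jax": module_version("jax"),
    "jaxlib": module_version("jaxlib"),
    "backend": backend_name(),
}
for key, value in report.items():
    print(f"{key}: {value}")
```

Paste the printed lines into the issue together with your minimal reproducing snippet.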

We are happy to help debug problems and improve the documentation.