Polynomial Solver
Root finder utilities based on the Ehrlich–Aberth method.
This module implements a JAX-friendly version of the Ehrlich–Aberth (EA)
iteration to find all complex roots of a polynomial given in polyval order (highest-degree coefficient first).
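The EA iteration for a single polynomial can be sketched in JAX as follows. This is a minimal sketch under stated assumptions, not microjax's implementation. The names `ea_solve` and `derivative_coeffs` are hypothetical, as are the circular initial guesses in the usage and the stopping rule (stop once the largest update magnitude is at most `tol`). Each step computes the Newton ratio N_k = p(z_k) / p'(z_k) and the Aberth correction w_k = N_k / (1 - N_k * S_k), where S_k = Σ_{j≠k} 1 / (z_k - z_j). It then updates z_k ← z_k - w_k.

```python
import jax
import jax.numpy as jnp

# Double precision so that a 1e-12 tolerance is attainable.
jax.config.update("jax_enable_x64", True)


def derivative_coeffs(coeffs):
    """Coefficients of p'(z) in polyval order, built once (no jnp.polyder)."""
    n = coeffs.shape[-1] - 1
    return coeffs[:-1] * jnp.arange(n, 0, -1)  # c_i * (n - i)


def ea_solve(coeffs, roots0, tol=1e-12, max_iter=100):
    """Hypothetical single-polynomial Ehrlich-Aberth loop with early stopping."""
    coeffs = jnp.asarray(coeffs, dtype=jnp.complex128)
    dcoeffs = derivative_coeffs(coeffs)  # precomputed outside the loop
    n = roots0.shape[-1]
    off_diag = ~jnp.eye(n, dtype=bool)

    def body(state):
        z, it, _ = state
        ratio = jnp.polyval(coeffs, z) / jnp.polyval(dcoeffs, z)  # N_k = p / p'
        diff = jnp.where(off_diag, z[:, None] - z[None, :], 1.0)  # dummy 1 on diagonal
        s = jnp.sum(jnp.where(off_diag, 1.0 / diff, 0.0), axis=1)  # S_k
        w = ratio / (1.0 - ratio * s)  # Aberth correction
        return z - w, it + 1, jnp.max(jnp.abs(w))

    def cond(state):
        _, it, delta = state
        # Keep iterating until every |w_k| <= tol or the budget runs out.
        return (it < max_iter) & (delta > tol)

    init = (
        jnp.asarray(roots0, dtype=jnp.complex128),
        jnp.array(0, dtype=jnp.int32),
        jnp.array(jnp.inf, dtype=jnp.float64),
    )
    z, iters, _ = jax.lax.while_loop(cond, body, init)
    return z, iters


# Usage: (z - 1)(z - 2)(z - 3) from three points on a circle of radius 4.
coeffs = jnp.array([1.0, -6.0, 11.0, -6.0])
start = 4.0 * jnp.exp(1j * (2 * jnp.pi * jnp.arange(3) / 3 + 0.4))
roots, iters = ea_solve(coeffs, start)
```

The loop carries the largest update magnitude in its state, so `cond` can stop as soon as every root has moved less than `tol` in one step.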
Design highlights
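- Iteration uses jax.lax.while_loop() so that it stops early once the update falls below the tolerance.
- Gradients are provided via implicit differentiation using custom_jvp. Derivatives are never propagated through the loop (reverse-mode AD cannot differentiate through jax.lax.while_loop() directly), so forward- and reverse-mode AD both work and stay robust regardless of the number of iterations.
- The derivative polynomial is precomputed from the coefficients rather than rebuilt at every step; jnp.polyder is not used.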
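Implicit differentiation with custom_jvp can be sketched as below. Several assumptions apply. `implicit_roots` is a hypothetical name. The forward solve uses `jnp.roots(..., strip_zeros=False)` (companion-matrix eigenvalues) as a stand-in for the EA loop. Roots are assumed simple, so p'(z_k) ≠ 0. Differentiating p(z_k(c); c) = 0 with respect to the coefficients gives dz_k = -(Σ_i dc_i z_k^(n-i)) / p'(z_k). This expression is linear in dc, which lets JAX transpose it for reverse mode.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _forward_solve(coeffs):
    # Stand-in for the EA loop (companion-matrix eigenvalues). Any solver
    # works here because its iterations are never differentiated.
    return jnp.roots(coeffs, strip_zeros=False)


@jax.custom_jvp
def implicit_roots(coeffs):
    """Roots of p(z) = sum_i c_i z**(n - i), with implicit-function derivatives."""
    return _forward_solve(coeffs)


@implicit_roots.defjvp
def _implicit_roots_jvp(primals, tangents):
    (coeffs,), (dcoeffs,) = primals, tangents
    z = _forward_solve(coeffs)
    n = coeffs.shape[-1] - 1
    vander = jnp.vander(z, n + 1)  # vander[k, i] = z_k ** (n - i)
    dp = vander[:, 1:] @ (coeffs[:-1] * jnp.arange(n, 0, -1))  # p'(z_k)
    # From p'(z_k) dz_k + sum_i dc_i z_k**(n - i) = 0:
    dz = -(vander @ dcoeffs) / dp
    return z, dz


# Usage: roots of z**2 - c are +/- sqrt(c), so the sum of squared real parts is 2c.
f = lambda c: jnp.sum(implicit_roots(jnp.array([1.0, 0.0, -c])).real ** 2)
grad_at_4 = jax.grad(f)(4.0)  # reverse mode through the custom JVP rule
```

The tangent rule never touches the loop, so its cost does not depend on how many iterations the forward solve took. Near-multiple roots (p'(z_k) ≈ 0) make dz_k large, which mirrors the ill-conditioning of that case.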
Notes on numerical behavior
- For nearly real coefficients, tiny imaginary parts in the output roots (below a threshold of about 10 * tol) are zeroed purely for presentation; the solver itself remains fully complex and differentiable.
- In ill-conditioned cases (e.g., near-multiple roots), convergence can slow down. To improve stability, the implementation falls back to a Newton step when the EA denominator becomes too small.
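Both behaviors can be sketched with `jnp.where`. The helpers `tidy_roots` and `guarded_correction` are hypothetical. The "nearly real" test on the coefficients and the `eps` cutoff for the denominator are assumed values, since the source gives only the approximate 10 * tol threshold for the imaginary parts.

```python
import jax.numpy as jnp


def tidy_roots(roots, coeffs, tol=1e-12):
    """Hypothetical presentation step: zero imaginary parts below 10 * tol
    when the coefficients are (nearly) real."""
    nearly_real = jnp.max(jnp.abs(jnp.imag(coeffs))) < 10 * tol  # assumed test
    small = nearly_real & (jnp.abs(roots.imag) < 10 * tol)
    return jnp.where(small, roots.real.astype(roots.dtype), roots)


def guarded_correction(ratio, s, eps=1e-14):
    """Aberth correction w = N / (1 - N * s), or the Newton step N when the
    denominator 1 - N * s is tiny (eps is an assumed cutoff)."""
    denom = 1.0 - ratio * s
    small = jnp.abs(denom) < eps
    # Substitute 1 in the unused branch so it stays finite and its gradient NaN-free.
    safe = jnp.where(small, 1.0, denom)
    return jnp.where(small, ratio, ratio / safe)
```

Selecting with `jnp.where` keeps both functions differentiable: gradients flow through whichever branch is chosen for each root.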
- microjax.poly_solver.poly_roots(coeffs: Array, custom_init: bool = False, roots_init: Array | None = None) → Array
Find all roots of a batch of polynomials using EA with early stopping.
- Parameters:
coeffs (jax.Array) – Coefficient array with shape [..., n+1] in polyval order. Real or complex values are supported.
custom_init (bool, optional) – Pass True to use roots_init as custom initial guesses.
roots_init (Optional[jax.Array], optional) – Initial guesses with shape [..., n], used when custom_init is True.
- Returns:
Complex roots with shape [..., n] in unspecified order.
- Return type:
jax.Array
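The source does not document the default initial guesses, so the helper below is purely hypothetical. It shows one way to build a `roots_init` array with the required [..., n] shape. It places n points on a circle whose radius is the Cauchy bound 1 + max_i |c_i / c_0|, which no root can exceed. The circle is rotated by an arbitrary offset to avoid perfectly symmetric starts.

```python
import jax.numpy as jnp


def circle_initial_guesses(coeffs, offset=0.4):
    """Hypothetical starting points with shape [..., n] for coeffs of shape [..., n+1]."""
    n = coeffs.shape[-1] - 1
    # Cauchy bound: every root satisfies |z| <= 1 + max_i |c_i / c_0|.
    radius = 1.0 + jnp.max(jnp.abs(coeffs[..., 1:] / coeffs[..., :1]), axis=-1)
    angles = 2.0 * jnp.pi * jnp.arange(n) / n + offset
    return radius[..., None] * jnp.exp(1j * angles)


# Usage: a batch of two cubics gives guesses with shape (2, 3).
coeffs = jnp.array([[1.0, -6.0, 11.0, -6.0], [1.0, 0.0, 0.0, -8.0]])
guesses = circle_initial_guesses(coeffs)
```

Following the documented signature, such an array would be passed as `poly_roots(coeffs, custom_init=True, roots_init=guesses)`.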
- microjax.poly_solver.poly_roots_EA_multi(coeffs_matrix: Array, custom_init: bool = False, initial_roots_matrix: Array | None = None, tol: float = 1e-12, max_iter: int = 100) → Array
Vectorized EA solver for batches of polynomials.
- Parameters:
coeffs_matrix (jax.Array) – Coefficient array with shape [..., n+1].
custom_init (bool, optional) – Use initial_roots_matrix as initial guesses when True.
initial_roots_matrix (Optional[jax.Array], optional) – Initial guesses with shape [..., n], used when custom_init is True.
tol (float, optional) – Absolute update tolerance per root; forwarded to poly_roots_EA(). Defaults to 1e-12.
max_iter (int, optional) – Maximum number of iterations per root; forwarded to poly_roots_EA(). Defaults to _DEFAULT_MAX_ITER (100 in the signature above).
- Returns:
Roots with shape [..., n].
- Return type:
jax.Array
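One common way to map a single-polynomial solver over arbitrary leading dimensions [..., n+1] is to flatten them, apply `jax.vmap`, and restore the shape. The source does not say how poly_roots_EA_multi vectorizes internally. The helper `batched_roots` is hypothetical, and `jnp.roots(..., strip_zeros=False)` stands in for the per-polynomial EA solve in the usage.

```python
import jax
import jax.numpy as jnp


def batched_roots(solve_one, coeffs):
    """Apply solve_one (shape [n+1] -> [n]) over coeffs of shape [..., n+1]."""
    batch_shape, m = coeffs.shape[:-1], coeffs.shape[-1]
    flat = coeffs.reshape(-1, m)  # [B, n+1]
    roots = jax.vmap(solve_one)(flat)  # [B, n]
    return roots.reshape(*batch_shape, m - 1)  # [..., n]


# Usage: a 2 x 2 batch of quadratics.
coeffs = jnp.array([[[1.0, -3.0, 2.0], [1.0, 0.0, -4.0]],
                    [[1.0, -5.0, 6.0], [1.0, 1.0, 0.0]]])
roots = batched_roots(lambda p: jnp.roots(p, strip_zeros=False), coeffs)  # shape (2, 2, 2)
```

If `solve_one` contains a `jax.lax.while_loop`, `vmap` keeps the batched loop running until every polynomial's stopping condition is met, freezing members that have already finished. Early stopping therefore saves work per batch rather than per polynomial.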