Polynomial Solver
Root finder utilities based on the Ehrlich–Aberth method.
This module implements a JAX-friendly version of the Ehrlich–Aberth (EA)
iteration to find all complex roots of a polynomial whose coefficients are given in polyval
order (highest-degree coefficient first).
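For reference, the standard EA update for a degree-n polynomial with polyval-order coefficients a_0, ..., a_n is written out below. The symbols are our own notation, not names taken from the module. In one sweep, every estimate z_i is updated simultaneously:

```latex
p(z) = \sum_{k=0}^{n} a_k z^{n-k}, \qquad
p'(z) = \sum_{k=0}^{n-1} (n-k)\, a_k z^{n-k-1},
\\[4pt]
w_i = \frac{p(z_i)}{p'(z_i)}, \qquad
z_i \leftarrow z_i - \frac{w_i}{1 - w_i \sum_{j \ne i} \dfrac{1}{z_i - z_j}}, \qquad i = 1, \dots, n.
```

If the sum over j != i is dropped, the update reduces to Newton's step z_i - w_i. The sum acts as an implicit deflation that pushes each estimate away from the others, so distinct estimates tend to converge to distinct roots.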
Design highlights
- Iteration uses jax.lax.while_loop() so that it stops early once the updates fall below the tolerance.
- Gradients are provided via implicit differentiation using custom_jvp, which makes forward- and reverse-mode AD robust regardless of the number of iterations taken.
- The derivative polynomial is precomputed from the coefficients rather than rebuilt at every step; jnp.polyder is not used.
Notes on numerical behavior
- For nearly real coefficients, tiny imaginary parts in the output roots (below roughly 10 * tol) are zeroed purely for presentation; the solver itself remains fully complex and differentiable.
- For ill-conditioned cases (e.g., near-multiple roots), convergence can slow down. To improve stability, the implementation falls back to a plain Newton step whenever the EA denominator becomes too small.
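The Newton fallback and the presentation cleanup described above can be sketched as follows. Both helpers are hypothetical. The switching threshold eps is an assumption, because the module's actual cutoff is not documented here. clean_imag also applies the 10 * tol rule unconditionally instead of first checking that the coefficients are nearly real.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def ea_step_with_fallback(z, coeffs, deriv, eps=1e-14):
    """One EA sweep; roots whose EA denominator is tiny take a Newton step instead (hypothetical)."""
    w = jnp.polyval(coeffs, z) / jnp.polyval(deriv, z)  # Newton correction p/p'
    off = ~jnp.eye(z.shape[0], dtype=bool)
    diff = jnp.where(off, z[:, None] - z[None, :], 1.0)  # dummy 1.0 on the diagonal
    s = jnp.sum(jnp.where(off, 1.0 / diff, 0.0), axis=1)  # sum_{j != i} 1 / (z_i - z_j)
    denom = 1.0 - w * s
    use_newton = jnp.abs(denom) < eps  # assumed fallback criterion
    # Guard the division so the unused branch never divides by a tiny number.
    safe_denom = jnp.where(use_newton, 1.0, denom)
    return z - jnp.where(use_newton, w, w / safe_denom)


def clean_imag(z, tol):
    """Zero imaginary parts below 10 * tol, for presentation only (hypothetical)."""
    return jnp.where(jnp.abs(z.imag) < 10 * tol, z.real + 0j, z)


coeffs = jnp.array([1.0, -6.0, 11.0, -6.0])  # (z - 1)(z - 2)(z - 3)
deriv = coeffs[:-1] * jnp.arange(3, 0, -1)  # 3z^2 - 12z + 11
z = jnp.array([0.9 + 0.1j, 2.1 - 0.1j, 3.2 + 0.05j])
for _ in range(20):
    z = ea_step_with_fallback(z, coeffs, deriv)
roots = clean_imag(z, tol=1e-12)
```

Both branches are computed and then selected with jnp.where, so the step stays traceable under jit and vmap. The guarded denominator prevents the discarded EA branch from injecting inf or NaN values into gradients.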
- microjax.poly_solver.poly_roots(coeffs: Array, custom_init: bool = False, roots_init: Optional[Array] = None) → Array
Find all roots of a batch of polynomials using EA iteration with early stopping.
- Parameters
  - coeffs (jax.Array) – Coefficient array with shape [..., n+1] in polyval order. Real or complex values are supported.
  - custom_init (bool, optional) – Pass True to use roots_init as custom initial guesses.
  - roots_init (Optional[jax.Array], optional) – Initial guesses with shape [..., n], used when custom_init is True.
- Returns
  Complex roots with shape [..., n], in unspecified order.
- Return type
  jax.Array
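For a single polynomial, the loop structure described for poly_roots can be sketched as below. This illustration rests on assumptions and is not microjax's code. The name ea_roots_sketch is hypothetical, and the default initial guesses are our own choice: points on a circle whose radius is the Cauchy bound 1 + max_k |a_k / a_0|, with an arbitrary angular offset. The sketch handles neither batching nor the custom_jvp gradient rule. Passing roots_init plays the role of custom_init=True.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def ea_roots_sketch(coeffs, roots_init=None, tol=1e-12, max_iter=100):
    """Roots of one polyval-order polynomial via EA with early stopping (hypothetical helper)."""
    coeffs = jnp.asarray(coeffs, dtype=jnp.complex128)
    n = coeffs.shape[0] - 1
    deriv = coeffs[:-1] * jnp.arange(n, 0, -1)  # computed once, outside the loop
    off = ~jnp.eye(n, dtype=bool)

    if roots_init is None:
        # Assumed default: a circle enclosing all roots (Cauchy bound), offset angle 0.4.
        radius = 1.0 + jnp.max(jnp.abs(coeffs[1:] / coeffs[0]))
        z0 = radius * jnp.exp(1j * (2.0 * jnp.pi * jnp.arange(n) / n + 0.4))
    else:
        z0 = jnp.asarray(roots_init, dtype=jnp.complex128)

    def body(state):
        z, it, _ = state
        w = jnp.polyval(coeffs, z) / jnp.polyval(deriv, z)
        diff = jnp.where(off, z[:, None] - z[None, :], 1.0)
        s = jnp.sum(jnp.where(off, 1.0 / diff, 0.0), axis=1)
        dz = w / (1.0 - w * s)  # EA correction for every root at once
        return z - dz, it + 1, jnp.max(jnp.abs(dz))

    def cond(state):
        _, it, delta = state
        # Stop when the largest absolute update is below tol, or at max_iter.
        return (it < max_iter) & (delta > tol)

    init = (z0, jnp.array(0, dtype=jnp.int32), jnp.array(jnp.inf, dtype=jnp.float64))
    z, iters, _ = lax.while_loop(cond, body, init)
    return z, iters
```

Returning the iteration count makes the early stop visible. For a well-conditioned cubic, the loop typically exits long before max_iter.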
- microjax.poly_solver.poly_roots_EA_multi(coeffs_matrix: Array, custom_init: bool = False, initial_roots_matrix: Optional[Array] = None, tol: float = 1e-12, max_iter: int = 100) → Array
Vectorized EA solver for batches of polynomials.
- Parameters
  - coeffs_matrix (jax.Array) – Coefficient array with shape [..., n+1] in polyval order.
  - custom_init (bool, optional) – Use initial_roots_matrix as initial guesses when True.
  - initial_roots_matrix (Optional[jax.Array], optional) – Initial guesses with shape [..., n], used when custom_init is True.
  - tol (float, optional) – Absolute update tolerance per root; forwarded to poly_roots_EA(). Defaults to 1e-12.
  - max_iter (int, optional) – Maximum number of iterations per root; forwarded to poly_roots_EA(). Defaults to _DEFAULT_MAX_ITER (100 in the signature above).
- Returns
  Roots with shape [..., n].
- Return type
  jax.Array
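poly_roots_EA_multi accepts coeffs_matrix of shape [..., n+1]. Batching over arbitrary leading axes like this can be sketched by flattening those axes and mapping a single-polynomial solver with jax.vmap. All names below are hypothetical, and the initial-guess rule is an assumption. Under vmap, jax.lax.while_loop keeps iterating until every polynomial in the batch meets its stopping condition. Members that finish early are carried through unchanged.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)


def _ea_single(coeffs, z0, tol, max_iter):
    """EA iteration with early stopping for one polynomial (hypothetical helper)."""
    n = coeffs.shape[0] - 1
    deriv = coeffs[:-1] * jnp.arange(n, 0, -1)
    off = ~jnp.eye(n, dtype=bool)

    def body(state):
        z, it, _ = state
        w = jnp.polyval(coeffs, z) / jnp.polyval(deriv, z)
        diff = jnp.where(off, z[:, None] - z[None, :], 1.0)
        s = jnp.sum(jnp.where(off, 1.0 / diff, 0.0), axis=1)
        dz = w / (1.0 - w * s)
        return z - dz, it + 1, jnp.max(jnp.abs(dz))

    def cond(state):
        _, it, delta = state
        return (it < max_iter) & (delta > tol)

    init = (z0, jnp.array(0, dtype=jnp.int32), jnp.array(jnp.inf, dtype=jnp.float64))
    return lax.while_loop(cond, body, init)[0]


def ea_roots_batched(coeffs_matrix, tol=1e-12, max_iter=100):
    """Roots for coefficients of shape [..., n+1]; returns shape [..., n] (hypothetical helper)."""
    coeffs_matrix = jnp.asarray(coeffs_matrix, dtype=jnp.complex128)
    batch_shape, m = coeffs_matrix.shape[:-1], coeffs_matrix.shape[-1]
    n = m - 1
    flat = coeffs_matrix.reshape(-1, m)  # collapse all leading axes into one batch axis
    # Assumed initial guesses: a unit-circle pattern scaled by each polynomial's Cauchy bound.
    radius = 1.0 + jnp.max(jnp.abs(flat[:, 1:] / flat[:, :1]), axis=1)
    unit = jnp.exp(1j * (2.0 * jnp.pi * jnp.arange(n) / n + 0.4))
    z0 = radius[:, None] * unit[None, :]
    solve = jax.vmap(lambda c, z: _ea_single(c, z, tol, max_iter))
    return solve(flat, z0).reshape(*batch_shape, n)


coeffs = jnp.array([
    [[1.0, -6.0, 11.0, -6.0], [1.0, -3.5, -2.5, 2.0]],   # roots {1, 2, 3}, {-1, 0.5, 4}
    [[1.0, -2.0, 1.0, -2.0], [1.0, -2.0, -5.0, 6.0]],    # roots {i, -i, 2}, {1, -2, 3}
])
roots = ea_roots_batched(coeffs)  # shape (2, 2, 3)
```

Flattening to a single batch axis keeps the vmapped function simple. It also lets the same code serve a single polynomial of shape [n+1], which becomes a batch of one.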