Polynomial Solver

Root finder utilities based on the Ehrlich–Aberth method.

This module implements a JAX-friendly version of the Ehrlich–Aberth (EA) iteration to find all complex roots of a polynomial whose coefficients are given in polyval order (highest-degree coefficient first, as in jnp.polyval).
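For reference, the textbook EA update refines all n root estimates simultaneously. With coefficients c_0, ..., c_n in polyval order (c_0 multiplies the highest power), each estimate z_k is corrected by a Newton step w_k that is deflated by the other current estimates. This is the standard form of the method; any damping or safeguards specific to microjax are not described by it.

```latex
p(z) = \sum_{i=0}^{n} c_i \, z^{\,n-i}, \qquad
w_k = \frac{p(z_k)}{p'(z_k)}, \qquad
z_k \leftarrow z_k - \frac{w_k}{1 - w_k \sum_{j \neq k} \dfrac{1}{z_k - z_j}},
\qquad k = 1, \dots, n.
```

Near simple roots the update converges cubically, which is why a small iteration budget usually suffices.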

Design highlights

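  • The iteration runs inside jax.lax.while_loop(), so it stops early once the root updates fall below the tolerance instead of always running the maximum number of iterations.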

  • Gradients are provided by implicit differentiation through a custom_jvp rule, so forward- and reverse-mode AD never differentiate through the loop itself and stay robust regardless of how many iterations were run.

  • The coefficients of the derivative polynomial are computed once from the input coefficients instead of being rebuilt at every step, without relying on jnp.polyder.
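The early-stopping loop can be sketched for a single polynomial as follows. This is a minimal illustration, not the microjax source: the name `ea_roots_sketch`, the initial guesses (a circle of radius 1 + max|c_i / c_0| rotated by 0.4 rad), and the stopping rule (largest update magnitude below `tol`) are assumptions, and the Newton fallback is omitted.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # tol=1e-12 needs float64


def ea_roots_sketch(coeffs, tol=1e-12, max_iter=100):
    """Hypothetical single-polynomial EA solver; returns (roots, iterations)."""
    coeffs = jnp.asarray(coeffs, dtype=jnp.complex128)
    n = coeffs.shape[-1] - 1
    # Derivative coefficients in polyval order, computed once outside the loop.
    dcoeffs = coeffs[:-1] * jnp.arange(n, 0, -1)
    # Assumed initial guesses: a rotated circle enclosing all roots (Cauchy bound).
    radius = 1.0 + jnp.max(jnp.abs(coeffs[1:] / coeffs[0]))
    z0 = radius * jnp.exp(1j * (2 * jnp.pi * jnp.arange(n) / n + 0.4))
    eye = jnp.eye(n, dtype=bool)

    def cond(state):
        _, step, it = state
        # Keep iterating while under budget and some root still moves more than tol.
        return (it < max_iter) & (jnp.max(jnp.abs(step)) > tol)

    def body(state):
        z, _, it = state
        w = jnp.polyval(coeffs, z) / jnp.polyval(dcoeffs, z)  # Newton corrections
        diff = jnp.where(eye, 1.0, z[:, None] - z[None, :])   # avoid 1/0 on diagonal
        s = jnp.where(eye, 0.0, 1.0 / diff).sum(axis=1)       # sum_{j != k} 1/(z_k - z_j)
        step = w / (1.0 - w * s)                              # EA correction
        return z - step, step, it + 1

    init = (z0, jnp.full(n, jnp.inf, dtype=z0.dtype), 0)
    z, _, n_iter = jax.lax.while_loop(cond, body, init)
    return z, n_iter


# (x - 1)(x - 2)(x - 3) and x^2 + 1
cubic_roots, cubic_iters = ea_roots_sketch(jnp.array([1.0, -6.0, 11.0, -6.0]))
quad_roots, _ = ea_roots_sketch(jnp.array([1.0, 0.0, 1.0]))
```

Carrying the last update in the loop state is what lets `cond` stop as soon as every root has settled.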
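The implicit-differentiation idea can be sketched with `jax.custom_jvp`. Differentiating p(z_k; c) = 0 gives p'(z_k) dz_k + sum_i dc_i z_k^(n-i) = 0, so dz_k = -p_dc(z_k) / p'(z_k), where p_dc is the polynomial with coefficients dc. In this sketch `jnp.roots` stands in for the EA forward solve and `roots_implicit` is a hypothetical name; the exact rule in microjax may differ in details.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.custom_jvp
def roots_implicit(coeffs):
    # Forward solve: any root finder works; jnp.roots stands in for EA here.
    return jnp.roots(coeffs, strip_zeros=False)


@roots_implicit.defjvp
def _roots_implicit_jvp(primals, tangents):
    (coeffs,), (dcoeffs,) = primals, tangents
    z = roots_implicit(coeffs)
    n = coeffs.shape[-1] - 1
    # p'(z) from precomputed derivative coefficients (polyval order).
    dp = jnp.polyval(coeffs[:-1] * jnp.arange(n, 0, -1), z)
    # Implicit function theorem: p'(z) dz + sum_i dc_i z^(n-i) = 0.
    dz = -jnp.polyval(dcoeffs, z) / dp
    return z, dz


# x^2 - 4 with the constant term perturbed: d/dt roots of x^2 - (4 + t) at t = 0.
c = jnp.array([1.0, 0.0, -4.0])
dc = jnp.array([0.0, 0.0, -1.0])
z, dz = jax.jvp(roots_implicit, (c,), (dc,))
```

Because the tangent rule never looks inside the solver, its cost does not depend on how many iterations the forward solve took, and since it is linear in `dcoeffs`, JAX can transpose it for reverse mode.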
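Computing the derivative coefficients directly is a one-liner in polyval order: the term c_i x^(n-i) differentiates to (n-i) c_i x^(n-i-1), and the constant term drops out. The helper name `derivative_coeffs` is hypothetical; this sketch broadcasts over any leading batch dimensions, matching the [..., n+1] layout.

```python
import jax.numpy as jnp


def derivative_coeffs(coeffs):
    """Hypothetical helper: derivative coefficients for shape [..., n+1] -> [..., n]."""
    n = coeffs.shape[-1] - 1
    # Multiply c_0..c_{n-1} by their exponents n, n-1, ..., 1; drop the constant term.
    return coeffs[..., :-1] * jnp.arange(n, 0, -1)


batch = jnp.array([[1.0, -6.0, 11.0, -6.0],   # x^3 - 6x^2 + 11x - 6
                   [2.0, 0.0, 0.0, 5.0]])     # 2x^3 + 5
d = derivative_coeffs(batch)
```

Doing this once before the loop means each iteration only evaluates two polynomials with `jnp.polyval`.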

Notes on numerical behavior

  • For nearly real coefficients, tiny imaginary parts in the output roots are zeroed (threshold about 10 * tol) purely for presentation; the solver itself remains fully complex and differentiable.

  • For ill-conditioned cases (e.g., near-multiple roots), convergence can slow down. To improve stability, the implementation falls back to a plain Newton step when the EA denominator becomes too small.
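The presentation cleanup of imaginary parts can be sketched with `jnp.where`, which keeps the output complex and traceable. The helper name `zero_tiny_imag`, and treating the threshold as an absolute bound of `factor * tol` on the imaginary part, are assumptions based on the note above.

```python
import jax.numpy as jnp


def zero_tiny_imag(roots, tol=1e-12, factor=10.0):
    """Hypothetical helper: set imaginary parts below factor * tol to zero."""
    tiny = jnp.abs(roots.imag) < factor * tol
    # The real branch is promoted back to the complex dtype of `roots`.
    return jnp.where(tiny, roots.real, roots)


raw = jnp.array([1.0 + 1e-15j, 2.0 + 0.5j, -3.0 - 2e-13j])
clean = zero_tiny_imag(raw)
```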
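A single guarded EA correction might look like the sketch below: where |1 - w_k S_k| drops under a threshold, the denominator is replaced by 1, which turns that root's update into the plain Newton step w_k. The function name `guarded_ea_step` and the parameter `denom_eps` are hypothetical; the threshold microjax actually uses is not documented here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def guarded_ea_step(coeffs, z, denom_eps=1e-10):
    """Hypothetical EA correction with a per-root Newton fallback (new z = z - step)."""
    n = coeffs.shape[-1] - 1
    dcoeffs = coeffs[:-1] * jnp.arange(n, 0, -1)
    w = jnp.polyval(coeffs, z) / jnp.polyval(dcoeffs, z)  # Newton corrections
    eye = jnp.eye(n, dtype=bool)
    diff = jnp.where(eye, 1.0, z[:, None] - z[None, :])   # avoid 1/0 on diagonal
    s = jnp.where(eye, 0.0, 1.0 / diff).sum(axis=1)       # sum_{j != k} 1/(z_k - z_j)
    denom = 1.0 - w * s
    # Newton fallback: divide by 1 instead of a near-zero EA denominator.
    safe_denom = jnp.where(jnp.abs(denom) < denom_eps, 1.0, denom)
    return w / safe_denom


c = jnp.array([1.0, -6.0, 11.0, -6.0])                  # roots 1, 2, 3
z_guess = jnp.array([0.9 + 0.1j, 2.1 - 0.05j, 3.2 + 0j])
newton = jnp.polyval(c, z_guess) / jnp.polyval(jnp.array([3.0, -12.0, 11.0]), z_guess)
forced_newton = guarded_ea_step(c, z_guess, denom_eps=jnp.inf)  # fallback everywhere
at_roots = guarded_ea_step(c, jnp.array([1.0, 2.0, 3.0]) + 0j)
```

Swapping the denominator rather than branching keeps the step a single vectorized expression and avoids dividing by tiny numbers inside a traced function.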

microjax.poly_solver.poly_roots(coeffs: Array, custom_init: bool = False, roots_init: Optional[Array] = None) → Array

Find all roots of a batch of polynomials using EA with early stopping.

Parameters
  • coeffs (jax.Array) – Coefficient array with shape [..., n+1] in polyval order. Real or complex values are supported.

  • custom_init (bool, optional) – If True, use roots_init as custom initial guesses.

  • roots_init (Optional[jax.Array], optional) – Initial guesses with shape [..., n] used when custom_init is True.

Returns

Complex roots with shape [..., n] in unspecified order.

Return type

jax.Array
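The input layout can be illustrated with a small batch: coefficients of shape [..., n+1] in polyval order and, for custom_init=True, initial guesses of shape [..., n], for example roots from a previous solve reused as a warm start after the coefficients change slightly. The warm-start scenario and the name `solve_warm` are illustrative assumptions; the microjax import is deferred so the sketch runs without the package installed.

```python
import jax
import jax.numpy as jnp

# Two quadratics in polyval order, shape [..., n+1] = [2, 3]:
# (x - 1)(x - 2) and x^2 + 1.
coeffs = jnp.array([[1.0, -3.0, 2.0],
                    [1.0, 0.0, 1.0]])
known_roots = jnp.array([[1.0 + 0j, 2.0 + 0j],
                         [1j, -1j]])

# p(z) evaluated row by row at the known roots; all entries should be ~0.
residuals = jax.vmap(jnp.polyval)(coeffs, known_roots)

# Warm start: perturb the constant terms and reuse the old roots, shape [..., n] = [2, 2].
new_coeffs = coeffs.at[:, -1].add(1e-3)
roots_init = known_roots


def solve_warm(c, z0):
    # Assumes microjax is installed; call signature as documented above.
    from microjax.poly_solver import poly_roots
    return poly_roots(c, custom_init=True, roots_init=z0)
```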

microjax.poly_solver.poly_roots_EA_multi(coeffs_matrix: Array, custom_init: bool = False, initial_roots_matrix: Optional[Array] = None, tol: float = 1e-12, max_iter: int = 100) → Array

Vectorized EA solver for batches of polynomials.

Parameters
  • coeffs_matrix (jax.Array) – Coefficient array with shape [..., n+1] in polyval order.

  • custom_init (bool, optional) – Use initial_roots_matrix as initial guesses when True.

  • initial_roots_matrix (Optional[jax.Array], optional) – Initial guesses with shape [..., n], used when custom_init is True.

  • tol (float, optional) – Absolute update tolerance per root; forwarded to poly_roots_EA(). Defaults to 1e-12.

  • max_iter (int, optional) – Maximum number of iterations per root; forwarded to poly_roots_EA(). Defaults to 100 (_DEFAULT_MAX_ITER).

Returns

Roots with shape [..., n].

Return type

jax.Array
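One common way to lift a single-polynomial solver to arbitrary leading batch dimensions is to flatten them, apply `jax.vmap`, and restore the shape. This is a sketch of that general technique, not necessarily how poly_roots_EA_multi is built; `batch_roots` is a hypothetical helper and `jnp.roots(..., strip_zeros=False)` stands in for the per-polynomial EA solver.

```python
import jax
import jax.numpy as jnp


def batch_roots(coeffs_matrix, single_solver):
    """Hypothetical helper: map a [n+1] -> [n] solver over shape [..., n+1]."""
    *batch, m = coeffs_matrix.shape
    flat = coeffs_matrix.reshape(-1, m)       # [B, n+1]
    roots = jax.vmap(single_solver)(flat)     # [B, n]
    return roots.reshape(*batch, m - 1)       # [..., n]


def stand_in_solver(p):
    # strip_zeros=False keeps jnp.roots traceable under vmap/jit.
    return jnp.roots(p, strip_zeros=False)


coeffs = jnp.array([[[1.0, -3.0, 2.0], [1.0, 0.0, -4.0]],
                    [[1.0, 0.0, 1.0], [2.0, -2.0, 0.0]]])   # shape [2, 2, 3]
roots = batch_roots(coeffs, stand_in_solver)                 # shape [2, 2, 2]
```

Flattening first means the solver only ever sees one batch axis, so the same code handles a single polynomial, a vector of them, or a grid.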