Adding info to CG and BICGSTAB

This commit is contained in:
botev 2023-01-22 21:33:09 +00:00
parent 12f7cdeeae
commit 73ed511d39

View File

@ -133,9 +133,9 @@ def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
gamma0 = _vdot_real_tree(r0, z0).astype(dtype)
initial_value = (x0, r0, gamma0, p0, 0)
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
return x_final
x_final, r, gamma, _, k = lax.while_loop(cond_fun, body_fun, initial_value)
rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
return x_final, (k, rs)
# aliases for working with pytrees
@ -182,9 +182,9 @@ def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
1, *dtypes._lattice_result_type(*tree_leaves(b)))
initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
return x_final
x_final, r, *_, k = lax.while_loop(cond_fun, body_fun, initial_value)
rs = _vdot_real_tree(r, r)
return x_final, (k, rs)
def _shapes(pytree):
@ -192,7 +192,7 @@ def _shapes(pytree):
def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
maxiter=None, M=None, check_symmetric=False):
maxiter=None, M=None, check_symmetric=False, has_info=False):
if x0 is None:
x0 = tree_map(jnp.zeros_like, b)
@ -225,11 +225,14 @@ def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
return not issubclass(x.dtype.type, np.complexfloating)
symmetric = all(map(real_valued, tree_leaves(b))) \
if check_symmetric else False
x = lax.custom_linear_solve(
x_maybe_info = lax.custom_linear_solve(
A, b, solve=isolve_solve, transpose_solve=isolve_solve,
symmetric=symmetric)
info = None
return x, info
symmetric=symmetric, has_aux=has_info)
if has_info:
return x_maybe_info
else:
return x_maybe_info, None
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
@ -259,9 +262,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
-------
x : array or tree of arrays
The converged solution. Has the same structure as ``b``.
info : None
Placeholder for convergence information. In the future, JAX will report
the number of iterations when convergence is not achieved, like SciPy.
info : A a pair with the number of iterations run until convergence and the
squared norm of the residual at the last iteration.
Other Parameters
----------------
@ -287,7 +289,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
"""
return _isolve(_cg_solve,
A=A, b=b, x0=x0, tol=tol, atol=atol,
maxiter=maxiter, M=M, check_symmetric=True)
maxiter=maxiter, M=M, check_symmetric=True, has_info=True)
def _safe_normalize(x, thresh=None):
@ -734,9 +736,8 @@ def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
-------
x : array or tree of arrays
The converged solution. Has the same structure as ``b``.
info : None
Placeholder for convergence information. In the future, JAX will report
the number of iterations when convergence is not achieved, like SciPy.
info : A a pair with the number of iterations run until convergence and the
squared norm of the residual at the last iteration.
Other Parameters
----------------
@ -763,4 +764,4 @@ def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
return _isolve(_bicgstab_solve,
A=A, b=b, x0=x0, tol=tol, atol=atol,
maxiter=maxiter, M=M)
maxiter=maxiter, M=M, has_info=True)