mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Adding info to CG and BICGSTAB
This commit is contained in:
parent
12f7cdeeae
commit
73ed511d39
@ -133,9 +133,9 @@ def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
|
||||
gamma0 = _vdot_real_tree(r0, z0).astype(dtype)
|
||||
initial_value = (x0, r0, gamma0, p0, 0)
|
||||
|
||||
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
|
||||
|
||||
return x_final
|
||||
x_final, r, gamma, _, k = lax.while_loop(cond_fun, body_fun, initial_value)
|
||||
rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
|
||||
return x_final, (k, rs)
|
||||
|
||||
|
||||
# aliases for working with pytrees
|
||||
@ -182,9 +182,9 @@ def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
|
||||
1, *dtypes._lattice_result_type(*tree_leaves(b)))
|
||||
initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)
|
||||
|
||||
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
|
||||
|
||||
return x_final
|
||||
x_final, r, *_, k = lax.while_loop(cond_fun, body_fun, initial_value)
|
||||
rs = _vdot_real_tree(r, r)
|
||||
return x_final, (k, rs)
|
||||
|
||||
|
||||
def _shapes(pytree):
|
||||
@ -192,7 +192,7 @@ def _shapes(pytree):
|
||||
|
||||
|
||||
def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
|
||||
maxiter=None, M=None, check_symmetric=False):
|
||||
maxiter=None, M=None, check_symmetric=False, has_info=False):
|
||||
if x0 is None:
|
||||
x0 = tree_map(jnp.zeros_like, b)
|
||||
|
||||
@ -225,11 +225,14 @@ def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
|
||||
return not issubclass(x.dtype.type, np.complexfloating)
|
||||
symmetric = all(map(real_valued, tree_leaves(b))) \
|
||||
if check_symmetric else False
|
||||
x = lax.custom_linear_solve(
|
||||
|
||||
x_maybe_info = lax.custom_linear_solve(
|
||||
A, b, solve=isolve_solve, transpose_solve=isolve_solve,
|
||||
symmetric=symmetric)
|
||||
info = None
|
||||
return x, info
|
||||
symmetric=symmetric, has_aux=has_info)
|
||||
if has_info:
|
||||
return x_maybe_info
|
||||
else:
|
||||
return x_maybe_info, None
|
||||
|
||||
|
||||
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
@ -259,9 +262,8 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
-------
|
||||
x : array or tree of arrays
|
||||
The converged solution. Has the same structure as ``b``.
|
||||
info : None
|
||||
Placeholder for convergence information. In the future, JAX will report
|
||||
the number of iterations when convergence is not achieved, like SciPy.
|
||||
info : A a pair with the number of iterations run until convergence and the
|
||||
squared norm of the residual at the last iteration.
|
||||
|
||||
Other Parameters
|
||||
----------------
|
||||
@ -287,7 +289,7 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
"""
|
||||
return _isolve(_cg_solve,
|
||||
A=A, b=b, x0=x0, tol=tol, atol=atol,
|
||||
maxiter=maxiter, M=M, check_symmetric=True)
|
||||
maxiter=maxiter, M=M, check_symmetric=True, has_info=True)
|
||||
|
||||
|
||||
def _safe_normalize(x, thresh=None):
|
||||
@ -734,9 +736,8 @@ def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
-------
|
||||
x : array or tree of arrays
|
||||
The converged solution. Has the same structure as ``b``.
|
||||
info : None
|
||||
Placeholder for convergence information. In the future, JAX will report
|
||||
the number of iterations when convergence is not achieved, like SciPy.
|
||||
info : A a pair with the number of iterations run until convergence and the
|
||||
squared norm of the residual at the last iteration.
|
||||
|
||||
Other Parameters
|
||||
----------------
|
||||
@ -763,4 +764,4 @@ def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
|
||||
|
||||
return _isolve(_bicgstab_solve,
|
||||
A=A, b=b, x0=x0, tol=tol, atol=atol,
|
||||
maxiter=maxiter, M=M)
|
||||
maxiter=maxiter, M=M, has_info=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user