mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[x64] minor weak_type changes to linalg.py
This commit is contained in:
parent
4dc31376b9
commit
f8e18e9a00
@ -34,11 +34,11 @@ _H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2))
|
||||
|
||||
def _promote_arg_dtypes(*args):
|
||||
"""Promotes `args` to a common inexact type."""
|
||||
def _to_inexact_type(type):
|
||||
return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
|
||||
inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
|
||||
dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
|
||||
args = [lax.convert_element_type(arg, dtype) for arg in args]
|
||||
dtype, weak_type = dtypes._lattice_result_type(*args)
|
||||
if not jnp.issubdtype(dtype, jnp.inexact):
|
||||
dtype, weak_type = jnp.float_, False
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
args = [lax._convert_element_type(arg, dtype, weak_type) for arg in args]
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
else:
|
||||
|
@ -533,8 +533,7 @@ def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
|
||||
carry = (V, H, False, 0)
|
||||
V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)
|
||||
|
||||
beta_vec = jnp.zeros((restart + 1,), dtype=dtype)
|
||||
beta_vec = beta_vec.at[0].set(residual_norm)
|
||||
beta_vec = jnp.zeros_like(H, shape=(restart + 1,)).at[0].set(residual_norm)
|
||||
y = _lstsq(H.T, beta_vec)
|
||||
dx = tree_map(lambda X: _dot(X[..., :-1], y), V)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user