[x64] minor weak_type changes to linalg.py

This commit is contained in:
Jake VanderPlas 2021-12-07 16:27:29 -08:00
parent 4dc31376b9
commit f8e18e9a00
2 changed files with 6 additions and 7 deletions

View File

@ -34,11 +34,11 @@ _H = lambda x: jnp.conjugate(jnp.swapaxes(x, -1, -2))
def _promote_arg_dtypes(*args):
"""Promotes `args` to a common inexact type."""
def _to_inexact_type(type):
return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
args = [lax.convert_element_type(arg, dtype) for arg in args]
dtype, weak_type = dtypes._lattice_result_type(*args)
if not jnp.issubdtype(dtype, jnp.inexact):
dtype, weak_type = jnp.float_, False
dtype = dtypes.canonicalize_dtype(dtype)
args = [lax._convert_element_type(arg, dtype, weak_type) for arg in args]
if len(args) == 1:
return args[0]
else:

View File

@ -533,8 +533,7 @@ def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
carry = (V, H, False, 0)
V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)
beta_vec = jnp.zeros((restart + 1,), dtype=dtype)
beta_vec = beta_vec.at[0].set(residual_norm)
beta_vec = jnp.zeros_like(H, shape=(restart + 1,)).at[0].set(residual_norm)
y = _lstsq(H.T, beta_vec)
dx = tree_map(lambda X: _dot(X[..., :-1], y), V)