Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)
The permutation is computed more efficiently during the decomposition on TPU, and the only use case that does not require it is evaluating determinants.
commit cf65f6b24e
parent b67e42a373
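For orientation, here is a minimal sketch of what the new three-result convention means. The import path is the one in use at the time of this commit (`jax.lax_linalg`); later releases move it, so treat the import as an assumption.

    # Sketch of the changed API (hypothetical standalone example, not from the diff).
    # lax_linalg.lu now returns three values:
    #   lu          -- L (unit lower triangular) and U packed into one matrix
    #   pivots      -- LAPACK-style swaps: row i was exchanged with row pivots[i]
    #   permutation -- the same information as a permutation array, so that
    #                  a[permutation] == l @ u
    import jax.numpy as jnp
    from jax import lax_linalg  # import path as of this commit; it has since moved

    a = jnp.array([[2., 1.],
                   [4., 3.]])
    lu, pivots, permutation = lax_linalg.lu(a)
    l = jnp.tril(lu, -1) + jnp.eye(2)
    u = jnp.triu(lu)
    assert jnp.allclose(a[permutation], l @ u)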
@@ -58,8 +58,8 @@ def eigh(x, lower=True, symmetrize_input=True):
   return v, w
 
 
 def lu(x):
-  lu, pivots = lu_p.bind(x)
-  return lu, pivots
+  lu, pivots, permutation = lu_p.bind(x)
+  return lu, pivots, permutation
 
 
 def qr(x, full_matrices=True):
   q, r = qr_p.bind(x, full_matrices=full_matrices)
@@ -534,13 +534,15 @@ def _lu_blocked(a, block_size=128):
   m, n = a.shape
   r = min(m, n)
   pivot = jnp.zeros((r,), dtype=jnp.int32)
+  perm = jnp.arange(m, dtype=jnp.int32)
   for k in range(0, r, block_size):
     b = min(r - k, block_size)
-    block_pivot, perm, lu_block = _lu_unblocked(a[k:, k:k+b])
+    block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b])
 
-    a = ops.index_update(a, ops.index[k:, :], a[perm + k, :])
-    a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)
     pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k)
+    perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k])
+    a = ops.index_update(a, ops.index[k:, :], a[block_perm + k, :])
+    a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)
 
     if k + b < n:
       a = ops.index_update(
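The one subtle new line above is `perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k])`: each block's local permutation (whose indices are relative to row k) is folded into the running global permutation. A minimal plain-NumPy sketch of that update, with made-up values:

    import numpy as np

    perm = np.arange(6)                  # global permutation accumulated so far
    k = 2                                # current block covers rows k..m-1
    block_perm = np.array([2, 0, 1, 3])  # local permutation returned for the block

    perm[k:] = perm[block_perm + k]      # the diff's index_update, in NumPy form
    print(perm)                          # [0 1 4 2 3 5]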
@@ -551,7 +553,7 @@ def _lu_blocked(a, block_size=128):
           a, ops.index[k+b:, k+b:],
           -lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
                    precision=lax.Precision.HIGHEST))
-  return pivot, a
+  return a, pivot, perm
 
 
 def _lu_python(x):
   """Default LU decomposition in Python, where no better version exists."""
@@ -559,16 +561,17 @@ def _lu_python(x):
   batch_dims = x.shape[:-2]
   if len(batch_dims) > 0:
     batch_size = np.prod(batch_dims, dtype=np.int64)
-    pivot, lu = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
-    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
+    lu, pivot, perm = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
     lu = lax.reshape(lu, batch_dims + (m, n))
+    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
+    perm = lax.reshape(perm, batch_dims + (m,))
   else:
-    pivot, lu = _lu_blocked(x)
-  return lu, pivot
+    lu, pivot, perm = _lu_blocked(x)
+  return lu, pivot, perm
 
 
 def _lu_impl(operand):
-  lu, pivot = xla.apply_primitive(lu_p, operand)
-  return lu, pivot
+  lu, pivot, perm = xla.apply_primitive(lu_p, operand)
+  return lu, pivot, perm
 
 
 def _lu_abstract_eval(operand):
   if isinstance(operand, ShapedArray):
@@ -579,21 +582,22 @@ def _lu_abstract_eval(operand):
     m = operand.shape[-2]
     n = operand.shape[-1]
     pivot = ShapedArray(batch_dims + (min(m, n),), jnp.int32)
+    perm = ShapedArray(batch_dims + (m,), jnp.int32)
   else:
     pivot = operand
-  return operand, pivot
+    perm = operand
+  return operand, pivot, perm
 
 
 def _lu_jvp_rule(primals, tangents):
   a, = primals
   a_dot, = tangents
-  lu, pivots = lu_p.bind(a)
+  lu, pivots, permutation = lu_p.bind(a)
 
   a_shape = jnp.shape(a)
   m, n = a_shape[-2:]
   dtype = lax.dtype(a)
   k = min(m, n)
 
-  permutation = lu_pivots_to_permutation(pivots, m)
   batch_dims = a_shape[:-2]
   iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
   x = a_dot[iotas[:-1] + (permutation, slice(None))]
@@ -628,25 +632,29 @@ def _lu_jvp_rule(primals, tangents):
   l_dot = jnp.matmul(l, jnp.tril(lau, -1))
   u_dot = jnp.matmul(jnp.triu(lau), u)
   lu_dot = l_dot + u_dot
-  return (lu, pivots), (lu_dot, ad_util.Zero.from_value(pivots))
+  return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
+                                     ad_util.Zero.from_value(permutation))
 
 
 def _lu_batching_rule(batched_args, batch_dims):
   x, = batched_args
   bd, = batch_dims
   x = batching.moveaxis(x, bd, 0)
-  return lu_p.bind(x), (0, 0)
+  return lu_p.bind(x), (0, 0, 0)
 
 
 def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
   shape = c.get_shape(operand)
   batch_dims = shape.dimensions()[:-2]
+  m = shape.dimensions()[-2]
   lu, pivot, info = getrf_impl(c, operand)
   # Subtract 1 from the pivot to get 0-based indices.
   pivot = xops.Sub(pivot, xops.ConstantLiteral(c, np.array(1, np.int32)))
   ok = xops.Ge(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
   lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                             _nan_like(c, lu))
-  return xops.Tuple(c, [lu, pivot])
+  perm = xla.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
+                       multiple_results=False)(c, pivot)
+  return xops.Tuple(c, [lu, pivot, perm])
 
 
 lu_p = Primitive('lu')
@@ -695,15 +703,16 @@ def lu_pivots_to_permutation(swaps, m):
 
   permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,),
                                      len(batch_dims))
+  if m == 0:
+    return permutation
   result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32),
                             _lu_pivots_body_fn, (permutation, swaps))
   return result
 
 
 @partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
-def _lu_solve_core(lu, pivots, b, trans):
+def _lu_solve_core(lu, permutation, b, trans):
   m = lu.shape[0]
-  permutation = lu_pivots_to_permutation(pivots, m)
   x = jnp.reshape(b, (m, -1))
   if trans == 0:
     x = x[permutation, :]
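For reference, the semantics of `lu_pivots_to_permutation(swaps, m)` can be written as a plain Python loop (a sketch for clarity; the real code above uses `lax.fori_loop` so it stays jittable):

    import numpy as np

    def pivots_to_permutation_reference(pivots, m):
      permutation = np.arange(m)
      for i, p in enumerate(pivots):
        # LAPACK convention: at elimination step i, row i was swapped with row p.
        permutation[i], permutation[p] = permutation[p], permutation[i]
      return permutation

    print(pivots_to_permutation_reference(np.array([1, 1]), 2))  # [1 0]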
@@ -722,7 +731,7 @@ def _lu_solve_core(lu, pivots, b, trans):
 
 
 @partial(api.jit, static_argnums=(3,))
-def _lu_solve(lu, pivots, b, trans):
+def _lu_solve(lu, permutation, b, trans):
   if len(lu.shape) < 2 or lu.shape[-1] != lu.shape[-2]:
     raise ValueError("last two dimensions of LU decomposition must be equal, "
                      "got shape {}".format(lu.shape))
@@ -747,13 +756,13 @@ def _lu_solve(lu, pivots, b, trans):
                      "matrix (shape {}) and second to last axis of b array "
                      "(shape {}) must match"
                      .format(lu.shape, b.shape))
-  x = _lu_solve_core(lu, pivots, b, trans)
+  x = _lu_solve_core(lu, permutation, b, trans)
   return x[..., 0] if rhs_vector else x
 
 
-def lu_solve(lu, pivots, b, trans=0):
+def lu_solve(lu, permutation, b, trans=0):
   """LU solve with broadcasting."""
-  return _lu_solve(lu, pivots, b, trans)
+  return _lu_solve(lu, permutation, b, trans)
 
 
 # QR decomposition
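A usage sketch of the updated `lu_solve`, which now takes the permutation rather than the pivots (import path as of this commit; treat it as an assumption on newer releases):

    import jax.numpy as jnp
    from jax import lax_linalg  # path as of this commit

    a = jnp.array([[2., 1.], [4., 3.]])
    b = jnp.array([3., 7.])
    lu, pivots, permutation = lax_linalg.lu(a)
    x = lax_linalg.lu_solve(lu, permutation, b, trans=0)
    assert jnp.allclose(a @ x, b)  # x == [1., 1.]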
@@ -121,7 +121,7 @@ def slogdet(a):
   if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
     msg = "Argument to slogdet() must have shape [..., n, n], got {}"
     raise ValueError(msg.format(a_shape))
-  lu, pivot = lax_linalg.lu(a)
+  lu, pivot, _ = lax_linalg.lu(a)
   diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
   is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
   parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
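Note that `slogdet` keeps consuming the pivot representation and discards the permutation: every position with `pivot[i] != i` is one genuine row swap, and each swap flips the determinant's sign. This matches the commit message's point that determinants are the one use case not needing the permutation. A small sketch of the parity computation:

    import numpy as np

    pivot = np.array([1, 1])  # swaps from a 2x2 factorization
    parity = np.count_nonzero(pivot != np.arange(2))  # == 1, one real swap
    sign = (-1.0) ** parity   # det(P) == -1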
@@ -206,7 +206,7 @@ def _cofactor_solve(a, b):
   # lu contains u in the upper triangular matrix and l in the strict lower
   # triangular matrix.
   # The diagonal of l is set to ones without loss of generality.
-  lu, pivots = lax_linalg.lu(a)
+  lu, pivots, permutation = lax_linalg.lu(a)
   dtype = lax.dtype(a)
   batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
   x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
@@ -219,7 +219,6 @@ def _cofactor_solve(a, b):
   # partial_det[:, -2] contains det(u) / u_{nn}.
   partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
   lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
-  permutation = lax_linalg.lu_pivots_to_permutation(pivots, a_shape[-1])
   permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
   iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
   # filter out any matrices that are not full rank
@@ -465,12 +464,13 @@ def solve(a, b):
 
   # With custom_linear_solve, we can reuse the same factorization when
   # computing sensitivities. This is considerably faster.
-  lu, pivots = lax_linalg.lu(lax.stop_gradient(a))
+  lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))
   custom_solve = partial(
       lax.custom_linear_solve,
       lambda x: _matvec_multiply(a, x),
-      solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=0),
-      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=1))
+      solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0),
+      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x,
+                                                       trans=1))
   if a.ndim == b.ndim + 1:
     # b.shape == [..., m]
     return custom_solve(b)
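Because `custom_linear_solve` closes over the same `(lu, permutation)` pair for both the forward and transpose solves, gradients through `jnp.linalg.solve` reuse the factorization instead of re-decomposing. An illustrative end-to-end check (public API only; the internals shown above are implied, not observed):

    import jax
    import jax.numpy as jnp

    a = jnp.array([[3., 1.], [1., 2.]])
    b = jnp.array([9., 8.])

    x = jnp.linalg.solve(a, b)  # forward solve via the LU path above; x == [2., 3.]
    g = jax.grad(lambda b: jnp.linalg.solve(a, b).sum())(b)  # transpose solve reuses lu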
@@ -105,23 +105,25 @@ def inv(a, overwrite_a=False, check_finite=True):
 def lu_factor(a, overwrite_a=False, check_finite=True):
   del overwrite_a, check_finite
   a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
-  return lax_linalg.lu(a)
+  lu, pivots, _ = lax_linalg.lu(a)
+  return lu, pivots
 
 
 @_wraps(scipy.linalg.lu_solve)
 def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
   del overwrite_b, check_finite
   lu, pivots = lu_and_piv
-  return lax_linalg.lu_solve(lu, pivots, b, trans)
+  m, n = lu.shape[-2:]
+  perm = lax_linalg.lu_pivots_to_permutation(pivots, m)
+  return lax_linalg.lu_solve(lu, perm, b, trans)
 
 
 @partial(jit, static_argnums=(1,))
 def _lu(a, permute_l):
   a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
-  lu, pivots = lax_linalg.lu(a)
+  lu, pivots, permutation = lax_linalg.lu(a)
   dtype = lax.dtype(a)
   m, n = jnp.shape(a)
-  permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
   p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype))
   k = min(m, n)
   l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
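The `jax.scipy.linalg` wrappers keep the SciPy-compatible interface: `lu_factor` still returns `(lu, pivots)`, and `lu_solve` converts the pivots to a permutation internally. A usage sketch:

    import jax.numpy as jnp
    from jax.scipy.linalg import lu_factor, lu_solve

    a = jnp.array([[2., 1.], [4., 3.]])
    b = jnp.array([3., 7.])
    lu_and_piv = lu_factor(a)     # (lu, pivots), interface unchanged
    x = lu_solve(lu_and_piv, b)   # pivots -> permutation happens inside
    assert jnp.allclose(a @ x, b)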
@@ -442,9 +444,10 @@ def expm_frechet_algo_64(A, E):
     Lv = jnp.select((A_norm_1<=ell_table_61_local99), (Lv3579, Lv3579), Lv13)
     s = jnp.select((A_norm_1<=ell_table_61_local99), (s3579, s3579), s13)
 
-  lu_piv = lu_factor(-U + V)
-  R = lu_solve(lu_piv, U + V)
-  L = lu_solve(lu_piv, Lu + Lv + _precise_dot((Lu - Lv), R))
+  lu, _, permutation = lax_linalg.lu(-U + V)
+  R = lax_linalg.lu_solve(lu, permutation, U + V, trans=False)
+  L = lax_linalg.lu_solve(lu, permutation, Lu + Lv + _precise_dot((Lu - Lv), R),
+                          trans=False)
   # squaring
   def my_body_fun(i,my_arg):
     R, L = my_arg