Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)

The permutation is computed more efficiently during the decomposition itself on TPU, and the only use case that would not require computing it is evaluating determinants.
Peter Hawkins 2020-09-10 11:16:35 -04:00 committed by GitHub
parent b67e42a373
commit cf65f6b24e
3 changed files with 49 additions and 37 deletions
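For reference, the LU primitive's `pivots` output is the LAPACK-style swap sequence (step i exchanges row i with row pivots[i]), while the new `permutation` output is the equivalent flat row permutation. A minimal NumPy sketch of how the two representations relate (illustration only, not the JAX implementation; the helper name is made up):

    import numpy as np

    def pivots_to_permutation(pivots, m):
        # Apply the LAPACK-style swaps: step i exchanges row i with row
        # pivots[i] (0-based), starting from the identity permutation.
        perm = np.arange(m)
        for i, p in enumerate(pivots):
            perm[i], perm[p] = perm[p], perm[i]
        return perm

    print(pivots_to_permutation(np.array([2, 2, 2]), 3))  # [2 0 1]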

Changed file 1 of 3: the lax_linalg module (LU primitive and LU solve implementation).

@@ -58,8 +58,8 @@ def eigh(x, lower=True, symmetrize_input=True):
   return v, w

 def lu(x):
-  lu, pivots = lu_p.bind(x)
-  return lu, pivots
+  lu, pivots, permutation = lu_p.bind(x)
+  return lu, pivots, permutation

 def qr(x, full_matrices=True):
   q, r = qr_p.bind(x, full_matrices=full_matrices)
@@ -534,13 +534,15 @@ def _lu_blocked(a, block_size=128):
   m, n = a.shape
   r = min(m, n)
   pivot = jnp.zeros((r,), dtype=jnp.int32)
+  perm = jnp.arange(m, dtype=jnp.int32)
   for k in range(0, r, block_size):
     b = min(r - k, block_size)
-    block_pivot, perm, lu_block = _lu_unblocked(a[k:, k:k+b])
+    block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b])

-    a = ops.index_update(a, ops.index[k:, :], a[perm + k, :])
-    a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)
     pivot = ops.index_update(pivot, ops.index[k:k+b], block_pivot + k)
+    perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k])
+    a = ops.index_update(a, ops.index[k:, :], a[block_perm + k, :])
+    a = ops.index_update(a, ops.index[k:, k:k+b], lu_block)

     if k + b < n:
       a = ops.index_update(
@@ -551,7 +553,7 @@ def _lu_blocked(a, block_size=128):
           a, ops.index[k+b:, k+b:],
           -lax.dot(a[k+b:, k:k+b], a[k:k+b, k+b:],
                    precision=lax.Precision.HIGHEST))
-  return pivot, a
+  return a, pivot, perm

 def _lu_python(x):
   """Default LU decomposition in Python, where no better version exists."""
@@ -559,16 +561,17 @@ def _lu_python(x):
   batch_dims = x.shape[:-2]
   if len(batch_dims) > 0:
     batch_size = np.prod(batch_dims, dtype=np.int64)
-    pivot, lu = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
-    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
+    lu, pivot, perm = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
     lu = lax.reshape(lu, batch_dims + (m, n))
+    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
+    perm = lax.reshape(perm, batch_dims + (m,))
   else:
-    pivot, lu = _lu_blocked(x)
-  return lu, pivot
+    lu, pivot, perm = _lu_blocked(x)
+  return lu, pivot, perm

 def _lu_impl(operand):
-  lu, pivot = xla.apply_primitive(lu_p, operand)
-  return lu, pivot
+  lu, pivot, perm = xla.apply_primitive(lu_p, operand)
+  return lu, pivot, perm

 def _lu_abstract_eval(operand):
   if isinstance(operand, ShapedArray):
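`_lu_python` handles leading batch dimensions by flattening them, vmap-ing the single-matrix routine, and reshaping every output back, now including `perm`. A self-contained sketch of the idiom with a stand-in kernel (`kernel` and `batched_kernel` are illustrative names):

    import jax
    import jax.numpy as jnp

    def kernel(x):
        # Stand-in for a per-matrix routine such as _lu_blocked.
        return x.T @ x

    def batched_kernel(x):
        *batch_dims, m, n = x.shape
        flat = x.reshape((-1, m, n))      # collapse all leading batch dims
        out = jax.vmap(kernel)(flat)      # map the kernel over the flat batch
        return out.reshape((*batch_dims, n, n))

    print(batched_kernel(jnp.ones((2, 3, 4, 5))).shape)  # (2, 3, 5, 5)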
@@ -579,21 +582,22 @@ def _lu_abstract_eval(operand):
     m = operand.shape[-2]
     n = operand.shape[-1]
     pivot = ShapedArray(batch_dims + (min(m, n),), jnp.int32)
+    perm = ShapedArray(batch_dims + (m,), jnp.int32)
   else:
     pivot = operand
-  return operand, pivot
+    perm = operand
+  return operand, pivot, perm

 def _lu_jvp_rule(primals, tangents):
   a, = primals
   a_dot, = tangents
-  lu, pivots = lu_p.bind(a)
+  lu, pivots, permutation = lu_p.bind(a)
   a_shape = jnp.shape(a)
   m, n = a_shape[-2:]
   dtype = lax.dtype(a)
   k = min(m, n)
-  permutation = lu_pivots_to_permutation(pivots, m)

   batch_dims = a_shape[:-2]
   iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
   x = a_dot[iotas[:-1] + (permutation, slice(None))]
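The JVP rule now reuses the primitive's `permutation` output instead of recomputing it from `pivots`, and applies it to the tangent's rows batch-wise via the `jnp.ix_` iota trick. A NumPy sketch of that batched row-permutation indexing pattern:

    import numpy as np

    a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
    perm = np.array([[2, 0, 1],
                     [1, 2, 0]])                  # one row permutation per batch
    # Open-grid batch indices, as jnp.ix_ over iotas builds them in the rule.
    iotas = np.ix_(np.arange(2), np.arange(1))
    x = a[iotas[:-1] + (perm, slice(None))]       # permute rows batch-wise
    print(np.array_equal(x[0], a[0][[2, 0, 1]]))  # True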
@@ -628,25 +632,29 @@ def _lu_jvp_rule(primals, tangents):
   l_dot = jnp.matmul(l, jnp.tril(lau, -1))
   u_dot = jnp.matmul(jnp.triu(lau), u)
   lu_dot = l_dot + u_dot
-  return (lu, pivots), (lu_dot, ad_util.Zero.from_value(pivots))
+  return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
+                                     ad_util.Zero.from_value(permutation))

 def _lu_batching_rule(batched_args, batch_dims):
   x, = batched_args
   bd, = batch_dims
   x = batching.moveaxis(x, bd, 0)
-  return lu_p.bind(x), (0, 0)
+  return lu_p.bind(x), (0, 0, 0)

 def _lu_cpu_gpu_translation_rule(getrf_impl, c, operand):
   shape = c.get_shape(operand)
   batch_dims = shape.dimensions()[:-2]
+  m = shape.dimensions()[-2]
   lu, pivot, info = getrf_impl(c, operand)
   # Subtract 1 from the pivot to get 0-based indices.
   pivot = xops.Sub(pivot, xops.ConstantLiteral(c, np.array(1, np.int32)))
   ok = xops.Ge(info, xops.ConstantLiteral(c, np.array(0, np.int32)))
   lu = _broadcasting_select(c, xops.Reshape(ok, batch_dims + (1, 1)), lu,
                             _nan_like(c, lu))
-  return xops.Tuple(c, [lu, pivot])
+  perm = xla.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
+                       multiple_results=False)(c, pivot)
+  return xops.Tuple(c, [lu, pivot, perm])

 lu_p = Primitive('lu')
@@ -695,15 +703,16 @@ def lu_pivots_to_permutation(swaps, m):
   permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,),
                                      len(batch_dims))
   if m == 0:
     return permutation
   result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32),
                             _lu_pivots_body_fn, (permutation, swaps))
   return result

 @partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
-def _lu_solve_core(lu, pivots, b, trans):
+def _lu_solve_core(lu, permutation, b, trans):
   m = lu.shape[0]
-  permutation = lu_pivots_to_permutation(pivots, m)
   x = jnp.reshape(b, (m, -1))
   if trans == 0:
     x = x[permutation, :]
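With the permutation in hand, an LU solve is just a row gather followed by two triangular solves. A SciPy-based sketch of the forward (`trans == 0`) case, under the convention P A = L U, with hand-computed factors for a 2x2 example (not the JAX code path):

    import numpy as np
    from scipy.linalg import solve_triangular

    def lu_solve_with_perm(lu, perm, b):
        # With P A = L U, A x = b becomes L U x = b[perm].
        y = solve_triangular(lu, b[perm], lower=True, unit_diagonal=True)
        return solve_triangular(lu, y, lower=False)

    a = np.array([[0., 2.], [3., 1.]])
    perm = np.array([1, 0])               # row order after partial pivoting
    lu = np.array([[3., 1.], [0., 2.]])   # packed unit-L and U of a[perm]
    b = np.array([2., 4.])
    print(np.allclose(a @ lu_solve_with_perm(lu, perm, b), b))  # True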
@@ -722,7 +731,7 @@ def _lu_solve_core(lu, pivots, b, trans):

 @partial(api.jit, static_argnums=(3,))
-def _lu_solve(lu, pivots, b, trans):
+def _lu_solve(lu, permutation, b, trans):
   if len(lu.shape) < 2 or lu.shape[-1] != lu.shape[-2]:
     raise ValueError("last two dimensions of LU decomposition must be equal, "
                      "got shape {}".format(lu.shape))
@@ -747,13 +756,13 @@ def _lu_solve(lu, pivots, b, trans):
                      "matrix (shape {}) and second to last axis of b array "
                      "(shape {}) must match"
                      .format(lu.shape, b.shape))
-  x = _lu_solve_core(lu, pivots, b, trans)
+  x = _lu_solve_core(lu, permutation, b, trans)
   return x[..., 0] if rhs_vector else x

-def lu_solve(lu, pivots, b, trans=0):
+def lu_solve(lu, permutation, b, trans=0):
   """LU solve with broadcasting."""
-  return _lu_solve(lu, pivots, b, trans)
+  return _lu_solve(lu, permutation, b, trans)

 # QR decomposition

Changed file 2 of 3: the jax.numpy.linalg wrappers (slogdet, _cofactor_solve, solve).

@@ -121,7 +121,7 @@ def slogdet(a):
   if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
     msg = "Argument to slogdet() must have shape [..., n, n], got {}"
     raise ValueError(msg.format(a_shape))
-  lu, pivot = lax_linalg.lu(a)
+  lu, pivot, _ = lax_linalg.lu(a)
   diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
   is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)
   parity = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
@@ -206,7 +206,7 @@ def _cofactor_solve(a, b):
   # lu contains u in the upper triangular matrix and l in the strict lower
   # triangular matrix.
   # The diagonal of l is set to ones without loss of generality.
-  lu, pivots = lax_linalg.lu(a)
+  lu, pivots, permutation = lax_linalg.lu(a)
   dtype = lax.dtype(a)
   batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
   x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
@@ -219,7 +219,6 @@ def _cofactor_solve(a, b):
   # partial_det[:, -2] contains det(u) / u_{nn}.
   partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
   lu = ops.index_update(lu, ops.index[..., -1, -1], 1.0 / partial_det[..., -2])
-  permutation = lax_linalg.lu_pivots_to_permutation(pivots, a_shape[-1])
   permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
   iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
   # filter out any matrices that are not full rank
@@ -465,12 +464,13 @@ def solve(a, b):
   # With custom_linear_solve, we can reuse the same factorization when
   # computing sensitivities. This is considerably faster.
-  lu, pivots = lax_linalg.lu(lax.stop_gradient(a))
+  lu, _, permutation = lax_linalg.lu(lax.stop_gradient(a))
   custom_solve = partial(
       lax.custom_linear_solve,
       lambda x: _matvec_multiply(a, x),
-      solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=0),
-      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, pivots, x, trans=1))
+      solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x, trans=0),
+      transpose_solve=lambda _, x: lax_linalg.lu_solve(lu, permutation, x,
+                                                       trans=1))
   if a.ndim == b.ndim + 1:
     # b.shape == [..., m]
     return custom_solve(b)
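Wiring the precomputed factorization into both `solve` and `transpose_solve` lets `lax.custom_linear_solve` reuse one factorization for the primal solve and its gradients. A minimal sketch of the pattern, using `jnp.linalg.solve` as a stand-in for the LU-based solver:

    import jax
    import jax.numpy as jnp
    from jax import lax

    def linear_solve(a, b):
        # Reuse one "factorization" for both the forward and transpose solves;
        # jnp.linalg.solve stands in for the LU-based lu_solve here.
        return lax.custom_linear_solve(
            lambda x: a @ x, b,
            solve=lambda _, x: jnp.linalg.solve(a, x),
            transpose_solve=lambda _, x: jnp.linalg.solve(a.T, x))

    a = jnp.array([[2., 1.], [1., 3.]])
    b = jnp.array([1., 2.])
    print(jax.grad(lambda b: linear_solve(a, b).sum())(b))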

Changed file 3 of 3: the jax.scipy.linalg wrappers (lu_factor, lu_solve, lu, expm_frechet).

@@ -105,23 +105,25 @@ def inv(a, overwrite_a=False, check_finite=True):

 @_wraps(scipy.linalg.lu_factor)
 def lu_factor(a, overwrite_a=False, check_finite=True):
   del overwrite_a, check_finite
   a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
-  return lax_linalg.lu(a)
+  lu, pivots, _ = lax_linalg.lu(a)
+  return lu, pivots

 @_wraps(scipy.linalg.lu_solve)
 def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
   del overwrite_b, check_finite
   lu, pivots = lu_and_piv
-  return lax_linalg.lu_solve(lu, pivots, b, trans)
+  m, n = lu.shape[-2:]
+  perm = lax_linalg.lu_pivots_to_permutation(pivots, m)
+  return lax_linalg.lu_solve(lu, perm, b, trans)

 @partial(jit, static_argnums=(1,))
 def _lu(a, permute_l):
   a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
-  lu, pivots = lax_linalg.lu(a)
+  lu, pivots, permutation = lax_linalg.lu(a)
   dtype = lax.dtype(a)
   m, n = jnp.shape(a)
-  permutation = lax_linalg.lu_pivots_to_permutation(pivots, m)
   p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype))
   k = min(m, n)
   l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
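Note that the SciPy-compatible wrappers keep the public `(lu, pivots)` convention and convert pivots to a permutation internally, so callers of `jax.scipy.linalg.lu_factor`/`lu_solve` are unaffected. A quick round-trip check against the public API:

    import jax.numpy as jnp
    from jax.scipy.linalg import lu_factor, lu_solve

    a = jnp.array([[0., 2.], [3., 1.]])
    b = jnp.array([2., 4.])
    lu_piv = lu_factor(a)           # still (lu, pivots), as in scipy.linalg
    x = lu_solve(lu_piv, b)
    print(jnp.allclose(a @ x, b))   # True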
@@ -442,9 +444,10 @@ def expm_frechet_algo_64(A, E):
   Lv = jnp.select((A_norm_1<=ell_table_61_local99), (Lv3579, Lv3579), Lv13)
   s = jnp.select((A_norm_1<=ell_table_61_local99), (s3579, s3579), s13)
-  lu_piv = lu_factor(-U + V)
-  R = lu_solve(lu_piv, U + V)
-  L = lu_solve(lu_piv, Lu + Lv + _precise_dot((Lu - Lv), R))
+  lu, _, permutation = lax_linalg.lu(-U + V)
+  R = lax_linalg.lu_solve(lu, permutation, U + V, trans=False)
+  L = lax_linalg.lu_solve(lu, permutation, Lu + Lv + _precise_dot((Lu - Lv), R),
+                          trans=False)
   # squaring
   def my_body_fun(i,my_arg):
     R, L = my_arg