Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.

In essence, this lifts the implementation of QR decomposition out of the backend-specific lowering rules and into JAX-level code.

This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we want the Householder reflectors themselves rather than their product. Currently neither geqrf nor orgqr is differentiable in isolation.
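As a rough illustration of what that access enables (a sketch, not part of the change itself): with the ``mode="raw"`` option this change adds to ``jnp.linalg.qr``, the packed geqrf output and the ``taus`` factors can be pulled out and Q rebuilt one reflector at a time. The reconstruction loop below is purely illustrative and assumes a real-valued input.

import numpy as np
import jax.numpy as jnp

a = jnp.asarray(np.random.default_rng(0).standard_normal((5, 3)), dtype=jnp.float32)

# mode="raw" (added below) exposes geqrf's packed result: R in the upper triangle,
# the Householder vectors below the diagonal, and the tau scalar factors.
h_t, taus = jnp.linalg.qr(a, mode="raw")
h = h_t.T  # raw mode follows NumPy and returns the transposed (Fortran-layout) array

m, n = a.shape
q = jnp.eye(m, dtype=a.dtype)
for i in range(n):
  # The i-th elementary reflector is H_i = I - tau_i * v_i * v_i^T, where v_i has an
  # implicit 1 on the diagonal and column i of h below the diagonal.
  v = jnp.zeros(m, dtype=a.dtype).at[i].set(1.0).at[i + 1:].set(h[i + 1:, i])
  q = q @ (jnp.eye(m, dtype=a.dtype) - taus[i] * jnp.outer(v, v))

r = jnp.triu(h)[:n, :]
np.testing.assert_allclose(np.asarray(q[:, :n] @ r), np.asarray(a), atol=1e-4)
q_reduced, _ = jnp.linalg.qr(a)  # the product of the reflectors, i.e. what orgqr computes
np.testing.assert_allclose(np.asarray(q[:, :n]), np.asarray(q_reduced), atol=1e-4)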

This change is in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.
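For real square matrices the idea is roughly the following (a hedged sketch only; the helper name slogdet_via_qr is hypothetical and this is not the implementation being prepared): |det A| is the product of the magnitudes of R's diagonal, and every elementary reflector with a nonzero tau contributes a factor of -1 to det Q.

import numpy as np
import jax.numpy as jnp

def slogdet_via_qr(a):
  # Sketch for real square matrices; the complex case needs extra care with the
  # phases of taus and of R's diagonal.
  h_t, taus = jnp.linalg.qr(a, mode="raw")              # packed geqrf output
  diag = jnp.diagonal(h_t, axis1=-2, axis2=-1)          # diagonal of R
  logabsdet = jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)  # log|det R|
  sign_r = jnp.prod(jnp.sign(diag), axis=-1)
  # A real elementary reflector with tau != 0 has determinant -1; tau == 0 means H = I.
  sign_q = jnp.prod(jnp.where(taus == 0, 1.0, -1.0), axis=-1)
  return sign_r * sign_q, logabsdet

a = jnp.asarray(np.random.default_rng(1).standard_normal((4, 4)), dtype=jnp.float32)
sign, logabsdet = slogdet_via_qr(a)
ref_sign, ref_logabsdet = np.linalg.slogdet(np.asarray(a, dtype=np.float64))
np.testing.assert_allclose(np.asarray(sign), ref_sign)
np.testing.assert_allclose(np.asarray(logabsdet), ref_logabsdet, rtol=1e-3, atol=1e-3)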

Fixes https://github.com/google/jax/issues/2322

PiperOrigin-RevId: 449033350
Peter Hawkins 2022-05-16 12:59:27 -07:00 committed by jax authors
parent 744f6b4ee8
commit 909c0328b0
5 changed files with 201 additions and 89 deletions


@@ -38,6 +38,7 @@ from jax._src.lax.lax import (
_input_dtype)
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
import jax._src.lib
from jax._src.lib import lapack
from jax._src.lib import gpu_linalg
@@ -1265,6 +1266,173 @@ def lu_solve(lu, permutation, b, trans=0):
# QR decomposition
# QR decomposition is implemented as a composition of two lower-level primitives
# geqrf and orgqr. The names, while cryptic Fortran alphabet soup, are LAPACK's
# names for the primitives, and we stick with them for consistency.
def geqrf(a):
"""Computes the QR decomposition of a matrix.
Args:
a: an ``[..., m, n]`` batch of matrices, with floating-point or complex type.
Returns:
An ``(a, taus)`` pair where ``r`` is in the upper triangle of ``a``,
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
elementary Householder reflectors.
"""
a_out, taus = geqrf_p.bind(a)
return a_out, taus
def _geqrf_abstract_eval(operand):
if not isinstance(operand, ShapedArray):
raise NotImplementedError("Unsupported aval in geqrf_abstract_eval: "
f"{operand.aval}")
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
taus = operand.update(shape=(*batch_dims, min(m, n)))
return operand, taus
def _geqrf_batching_rule(batched_args, batch_dims):
x, = batched_args
bd, = batch_dims
return geqrf(batching.moveaxis(x, bd, 0)), (0, 0)
def _geqrf_translation_rule(ctx, avals_in, avals_out, operand):
return xops.QrDecomposition(operand)
def _geqrf_cpu_gpu_lowering(geqrf_impl, ctx, a):
a_aval, taus_aval = ctx.avals_out
*batch_dims, m, n = a_aval.shape
if m == 0 or n == 0:
return mlir.full_like_aval(0, a_aval), mlir.full_like_aval(0, taus_aval)
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED")
ok_a = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get((*batch_dims, 1, 1),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
a_out = _broadcasting_select_mhlo(ok_a, a_out, _nan_like_mhlo(a_aval))
ok_taus = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get((*batch_dims, 1,),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
taus = _broadcasting_select_mhlo(ok_taus, taus, _nan_like_mhlo(taus_aval))
return a_out, taus
geqrf_p = Primitive('geqrf')
geqrf_p.multiple_results = True
geqrf_p.def_impl(partial(xla.apply_primitive, geqrf_p))
geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
xla.register_translation(geqrf_p, _geqrf_translation_rule)
mlir.register_lowering(
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo),
platform='cpu')
if gpu_solver is not None:
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf),
platform='cuda')
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf),
platform='rocm')
if solver_apis is not None:
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, solver_apis.geqrf_mhlo),
platform='gpu')
# orgqr: product of elementary Householder reflectors
def orgqr(a, taus):
"""Product of elementary Householder reflectors.
Args:
a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
elementary Householder reflectors.
taus: A vector with shape ``[..., k]``, where ``k < min(m, n)``, containing
the scalar factors of the elementary Householder reflectors.
Returns:
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
containing the products of the elementary Householder reflectors.
"""
return orgqr_p.bind(a, taus)
def _orgqr_abstract_eval(a, taus):
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
raise NotImplementedError("Unsupported aval in orgqr_abstract_eval: "
f"{a.aval} {taus.aval}")
if a.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = a.shape
*taus_batch_dims, k = taus.shape
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
raise ValueError(f"Type mismatch for orgqr: a={a} taus={taus}")
return a
def _orgqr_batching_rule(batched_args, batch_dims):
a, taus = batched_args
b_a, b_taus, = batch_dims
return orgqr(batching.moveaxis(a, b_a, 0),
batching.moveaxis(taus, b_taus, 0)), (0,)
def _orgqr_translation_rule(ctx, avals_in, avals_out, a, taus):
return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]
def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
a_aval, _ = ctx.avals_in
*batch_dims, m, n = a_aval.shape
if m == 0 or n == 0:
return [mlir.full_like_aval(0, a_aval)]
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED")
ok = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get((*batch_dims, 1, 1),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
a = _broadcasting_select_mhlo(ok, a, _nan_like_mhlo(a_aval))
return [a]
orgqr_p = Primitive('orgqr')
orgqr_p.def_impl(partial(xla.apply_primitive, orgqr_p))
orgqr_p.def_abstract_eval(_orgqr_abstract_eval)
batching.primitive_batchers[orgqr_p] = _orgqr_batching_rule
xla.register_translation(orgqr_p, _orgqr_translation_rule)
mlir.register_lowering(
orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
platform='cpu')
if gpu_solver is not None:
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
platform='cuda')
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
platform='rocm')
if solver_apis is not None:
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, solver_apis.orgqr_mhlo),
platform='gpu')
def _qr_impl(operand, *, full_matrices):
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
return q, r
@@ -1317,105 +1485,43 @@ def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
x = batching.moveaxis(x, bd, 0)
return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
def _empty_qr(a, *, full_matrices):
*batch_shape, m, n = a.shape
k = m if full_matrices else min(m, n)
q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_shape, m, k))
r = jnp.empty((*batch_shape, k, n), dtype=a.dtype)
return [q, r]
def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,
full_matrices):
operand_aval, = ctx.avals_in
q_aval, r_aval = ctx.avals_out
dims = operand_aval.shape
m, n = dims[-2:]
batch_dims = dims[:-2]
def _qr_lowering(a, *, full_matrices):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
return mlir.lower_fun(_empty_qr, multiple_results=True)(
ctx, operand, full_matrices=full_matrices)
k = m if full_matrices else min(m, n)
q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_dims, m, k))
r = jnp.empty((*batch_dims, k, n), dtype=a.dtype)
return q, r
r, tau, info_geqrf = geqrf_impl(operand_aval.dtype, operand)
r, taus = geqrf(a)
if m < n:
q = mhlo.SliceOp(r,
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements(list(batch_dims) + [m, m]),
mlir.dense_int_elements([1] * len(dims))).result
q, info_orgqr = orgqr_impl(operand_aval.dtype, q, tau)
elif not full_matrices:
q, info_orgqr = orgqr_impl(operand_aval.dtype, r, tau)
r = mhlo.SliceOp(r,
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements(list(batch_dims) + [n, n]),
mlir.dense_int_elements([1] * len(dims))).result
q = orgqr(r[..., :m, :m], taus)
elif full_matrices:
pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
q = lax.pad(r, lax_internal._zero(r), pads)
q = orgqr(q, taus)
else:
if jax._src.lib.mlir_api_version < 15:
q = mhlo.PadOp(mlir.aval_to_ir_type(q_aval), r,
mlir.ir_constant(np.array(0, dtype=operand_aval.dtype)),
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements([0] * (len(dims) - 1) + [m - n]),
mlir.dense_int_elements([0] * len(dims))).result
else:
q = mhlo.PadOp(r,
mlir.ir_constant(np.array(0, dtype=operand_aval.dtype)),
mlir.dense_int_elements([0] * len(dims)),
mlir.dense_int_elements([0] * (len(dims) - 1) + [m - n]),
mlir.dense_int_elements([0] * len(dims))).result
q, info_orgqr = orgqr_impl(operand_aval.dtype, q, tau)
if info_geqrf is not None:
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mhlo.AndOp(
mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED"),
mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED"))
ok = mhlo.BroadcastInDimOp(
ir.RankedTensorType.get(batch_dims + (1, 1),
ir.IntegerType.get_signless(1)),
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
q = _broadcasting_select_mhlo(ok, q, _nan_like_mhlo(q_aval))
r = _broadcasting_select_mhlo(ok, r, _nan_like_mhlo(r_aval))
else:
pass # rocsolver does not return info
q = orgqr(r, taus)
r = r[..., :n, :n]
r = jnp.triu(r)
return q, r
sub_ctx = mlir.LoweringRuleContext(module_context=ctx.module_context,
primitive=None,
avals_in=[r_aval],
avals_out=[r_aval],
tokens_in=ctx.tokens_in,
tokens_out=ctx.tokens_out)
r, = mlir.lower_fun(jnp.triu, multiple_results=False)(sub_ctx, r)
return [q, r]
qr_p = Primitive('qr')
qr_p.multiple_results = True
qr_p.def_impl(_qr_impl)
qr_p.def_abstract_eval(_qr_abstract_eval)
xla.register_translation(qr_p, _qr_translation_rule)
# Older jaxlibs didn't expose geqrf and orgqr as separate XLA operations.
# TODO(phawkins): remove after minimum jaxlib version is > 0.3.10.
if jax._src.lib.xla_extension_version < 69:
xla.register_translation(qr_p, _qr_translation_rule, platform="tpu")
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = _qr_batching_rule
mlir.register_lowering(
qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
platform='cpu')
mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering))
if gpu_solver is not None:
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
gpu_solver.cuda_orgqr),
platform='cuda')
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
gpu_solver.rocm_orgqr),
platform='rocm')
if solver_apis is not None:
mlir.register_lowering(
qr_p,
partial(_qr_cpu_gpu_lowering, solver_apis.geqrf_mhlo, solver_apis.orgqr_mhlo),
platform='gpu')
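As a usage sketch of the composition described in the comment at the top of this hunk (and of what the new _qr_lowering does for a tall matrix in reduced mode), the two subprimitives can be combined directly at the JAX level. Note this reaches into the internal module jax._src.lax.linalg, as the jnp.linalg.qr change below does; geqrf and orgqr are not exported publicly by this change.

import numpy as np
import jax.numpy as jnp
from jax._src.lax import linalg as lax_linalg  # internal module edited above

a = jnp.asarray(np.random.default_rng(2).standard_normal((6, 4)), dtype=jnp.float32)
m, n = a.shape  # tall case, m >= n

packed, taus = lax_linalg.geqrf(a)  # R above the diagonal, reflectors below, taus separate
q = lax_linalg.orgqr(packed, taus)  # product of the reflectors: the reduced Q, shape (m, n)
r = jnp.triu(packed[:n, :])         # the n x n upper-triangular factor

np.testing.assert_allclose(np.asarray(q.T @ q), np.eye(n, dtype=np.float32), atol=1e-4)
np.testing.assert_allclose(np.asarray(q @ r), np.asarray(a), atol=1e-4)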
# Singular value decomposition


@@ -477,13 +477,16 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a, mode="reduced"):
a, = _promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
a, taus = lax_linalg.geqrf(a)
return _T(a), taus
if mode in ("reduced", "r", "full"):
full_matrices = False
elif mode == "complete":
full_matrices = True
else:
raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
a, = _promote_dtypes_inexact(jnp.asarray(a))
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
if mode == "r":
return r
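A brief usage sketch of the mode handling above, including the newly supported "raw" mode; shapes in the comments assume a 5x3 input, and the final comparison mirrors the np.linalg.qr parity that the updated test below exercises.

import numpy as np
import jax.numpy as jnp

a_np = np.random.default_rng(3).standard_normal((5, 3)).astype(np.float32)
a = jnp.asarray(a_np)

q, r = jnp.linalg.qr(a, mode="reduced")     # q: (5, 3), r: (3, 3)
qc, rc = jnp.linalg.qr(a, mode="complete")  # qc: (5, 5), rc: (5, 3)
r_only = jnp.linalg.qr(a, mode="r")         # r alone: (3, 3)
h, taus = jnp.linalg.qr(a, mode="raw")      # h: (3, 5), packed and transposed; taus: (3,)

# The raw outputs follow NumPy's convention for mode="raw".
np_h, np_taus = np.linalg.qr(a_np, mode="raw")
np.testing.assert_allclose(np.asarray(h), np_h, atol=1e-4)
np.testing.assert_allclose(np.asarray(taus), np_taus, atol=1e-4)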


@@ -1011,6 +1011,8 @@ tf_not_yet_impl = [
"all_gather",
"lu_pivots_to_permutation",
"xla_pmap",
"geqrf",
"orgqr",
]
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient


@@ -208,7 +208,7 @@ def orgqr_mhlo(dtype, a, tau):
b *= d
tau_dims = ir.RankedTensorType(tau.type).shape
assert tau_dims[:-1] == dims[:-2]
assert tau_dims[:-1] == dims[:-2], (tau.type, a.type)
k = tau_dims[-1]
if dtype == np.float32:


@@ -657,7 +657,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
"shape": shape, "dtype": dtype, "mode": mode}
for shape in [(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)]
for dtype in [np.float32]
for mode in ["reduced", "r", "full", "complete"]))
for mode in ["reduced", "r", "full", "complete", "raw"]))
def testNumpyQrModes(self, shape, dtype, mode):
rng = jtu.rand_default(self.rng())
jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
@@ -665,7 +665,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
if mode == "full":
np_func = jtu.ignore_warning(category=DeprecationWarning, message="The 'full' option.*")(np_func)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1E-5, atol=1E-5)
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1e-5, atol=1e-5,
check_dtypes=(mode != "raw"))
self._CompileAndCheck(jnp_func, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(