mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead. This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation. Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition. Fixes https://github.com/google/jax/issues/2322 PiperOrigin-RevId: 449033350
This commit is contained in:
parent
744f6b4ee8
commit
909c0328b0
@ -38,6 +38,7 @@ from jax._src.lax.lax import (
|
||||
_input_dtype)
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import svd as lax_svd
|
||||
import jax._src.lib
|
||||
from jax._src.lib import lapack
|
||||
|
||||
from jax._src.lib import gpu_linalg
|
||||
@ -1265,6 +1266,173 @@ def lu_solve(lu, permutation, b, trans=0):
|
||||
|
||||
# QR decomposition
|
||||
|
||||
# QR decomposition is implemented as a composition of two lower-level primitives
|
||||
# geqrf and orgqr. The names, while cryptic Fortran alphabet soup, are LAPACK's
|
||||
# names for the primitives, and we stick with them for consistency.
|
||||
|
||||
def geqrf(a):
|
||||
"""Computes the QR decomposition of a matrix.
|
||||
|
||||
Args:
|
||||
a: an ``[..., m, n]`` batch of matrices, with floating-point or complex type.
|
||||
Returns:
|
||||
An ``(a, taus)`` pair where ``r`` is in the upper triangle of ``a``,
|
||||
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
|
||||
elementary Householder reflectors.
|
||||
"""
|
||||
a_out, taus = geqrf_p.bind(a)
|
||||
return a_out, taus
|
||||
|
||||
def _geqrf_abstract_eval(operand):
|
||||
if not isinstance(operand, ShapedArray):
|
||||
raise NotImplementedError("Unsupported aval in geqrf_abstract_eval: "
|
||||
f"{operand.aval}")
|
||||
if operand.ndim < 2:
|
||||
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
||||
*batch_dims, m, n = operand.shape
|
||||
taus = operand.update(shape=(*batch_dims, min(m, n)))
|
||||
return operand, taus
|
||||
|
||||
def _geqrf_batching_rule(batched_args, batch_dims):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
return geqrf(batching.moveaxis(x, bd, 0)), (0, 0)
|
||||
|
||||
def _geqrf_translation_rule(ctx, avals_in, avals_out, operand):
|
||||
return xops.QrDecomposition(operand)
|
||||
|
||||
def _geqrf_cpu_gpu_lowering(geqrf_impl, ctx, a):
|
||||
a_aval, taus_aval = ctx.avals_out
|
||||
*batch_dims, m, n = a_aval.shape
|
||||
|
||||
if m == 0 or n == 0:
|
||||
return mlir.full_like_aval(0, a_aval), mlir.full_like_aval(0, taus_aval)
|
||||
|
||||
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
|
||||
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
||||
ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED")
|
||||
ok_a = mhlo.BroadcastInDimOp(
|
||||
ir.RankedTensorType.get((*batch_dims, 1, 1),
|
||||
ir.IntegerType.get_signless(1)),
|
||||
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
|
||||
a_out = _broadcasting_select_mhlo(ok_a, a_out, _nan_like_mhlo(a_aval))
|
||||
ok_taus = mhlo.BroadcastInDimOp(
|
||||
ir.RankedTensorType.get((*batch_dims, 1,),
|
||||
ir.IntegerType.get_signless(1)),
|
||||
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
|
||||
taus = _broadcasting_select_mhlo(ok_taus, taus, _nan_like_mhlo(taus_aval))
|
||||
return a_out, taus
|
||||
|
||||
geqrf_p = Primitive('geqrf')
|
||||
geqrf_p.multiple_results = True
|
||||
geqrf_p.def_impl(partial(xla.apply_primitive, geqrf_p))
|
||||
geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
|
||||
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
|
||||
xla.register_translation(geqrf_p, _geqrf_translation_rule)
|
||||
|
||||
mlir.register_lowering(
|
||||
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo),
|
||||
platform='cpu')
|
||||
if gpu_solver is not None:
|
||||
mlir.register_lowering(
|
||||
geqrf_p,
|
||||
partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
geqrf_p,
|
||||
partial(_geqrf_cpu_gpu_lowering, gpu_solver.rocm_geqrf),
|
||||
platform='rocm')
|
||||
|
||||
if solver_apis is not None:
|
||||
mlir.register_lowering(
|
||||
geqrf_p,
|
||||
partial(_geqrf_cpu_gpu_lowering, solver_apis.geqrf_mhlo),
|
||||
platform='gpu')
|
||||
|
||||
|
||||
# orgqr: product of elementary Householder reflectors
|
||||
|
||||
def orgqr(a, taus):
|
||||
"""Product of elementary Householder reflectors.
|
||||
|
||||
Args:
|
||||
a: A matrix with shape ``[..., m, n]``, whose lower triangle contains
|
||||
elementary Householder reflectors.
|
||||
taus: A vector with shape ``[..., k]``, where ``k < min(m, n)``, containing
|
||||
the scalar factors of the elementary Householder reflectors.
|
||||
|
||||
Returns:
|
||||
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
|
||||
containing the products of the elementary Householder reflectors.
|
||||
"""
|
||||
return orgqr_p.bind(a, taus)
|
||||
|
||||
|
||||
def _orgqr_abstract_eval(a, taus):
|
||||
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
|
||||
raise NotImplementedError("Unsupported aval in orgqr_abstract_eval: "
|
||||
f"{a.aval} {taus.aval}")
|
||||
if a.ndim < 2:
|
||||
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
||||
*batch_dims, m, n = a.shape
|
||||
*taus_batch_dims, k = taus.shape
|
||||
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
|
||||
raise ValueError(f"Type mismatch for orgqr: a={a} taus={taus}")
|
||||
return a
|
||||
|
||||
def _orgqr_batching_rule(batched_args, batch_dims):
|
||||
a, taus = batched_args
|
||||
b_a, b_taus, = batch_dims
|
||||
return orgqr(batching.moveaxis(a, b_a, 0),
|
||||
batching.moveaxis(taus, b_taus, 0)), (0,)
|
||||
|
||||
def _orgqr_translation_rule(ctx, avals_in, avals_out, a, taus):
|
||||
return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]
|
||||
|
||||
def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
|
||||
a_aval, _ = ctx.avals_in
|
||||
*batch_dims, m, n = a_aval.shape
|
||||
|
||||
if m == 0 or n == 0:
|
||||
return [mlir.full_like_aval(0, a_aval)]
|
||||
|
||||
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
|
||||
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
||||
ok = mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED")
|
||||
ok = mhlo.BroadcastInDimOp(
|
||||
ir.RankedTensorType.get((*batch_dims, 1, 1),
|
||||
ir.IntegerType.get_signless(1)),
|
||||
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
|
||||
a = _broadcasting_select_mhlo(ok, a, _nan_like_mhlo(a_aval))
|
||||
return [a]
|
||||
|
||||
|
||||
orgqr_p = Primitive('orgqr')
|
||||
orgqr_p.def_impl(partial(xla.apply_primitive, orgqr_p))
|
||||
orgqr_p.def_abstract_eval(_orgqr_abstract_eval)
|
||||
batching.primitive_batchers[orgqr_p] = _orgqr_batching_rule
|
||||
xla.register_translation(orgqr_p, _orgqr_translation_rule)
|
||||
|
||||
mlir.register_lowering(
|
||||
orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
|
||||
platform='cpu')
|
||||
if gpu_solver is not None:
|
||||
mlir.register_lowering(
|
||||
orgqr_p,
|
||||
partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
orgqr_p,
|
||||
partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
|
||||
platform='rocm')
|
||||
|
||||
if solver_apis is not None:
|
||||
mlir.register_lowering(
|
||||
orgqr_p,
|
||||
partial(_orgqr_cpu_gpu_lowering, solver_apis.orgqr_mhlo),
|
||||
platform='gpu')
|
||||
|
||||
|
||||
def _qr_impl(operand, *, full_matrices):
|
||||
q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
|
||||
return q, r
|
||||
@ -1317,105 +1485,43 @@ def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
return qr_p.bind(x, full_matrices=full_matrices), (0, 0)
|
||||
|
||||
def _empty_qr(a, *, full_matrices):
|
||||
*batch_shape, m, n = a.shape
|
||||
k = m if full_matrices else min(m, n)
|
||||
q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_shape, m, k))
|
||||
r = jnp.empty((*batch_shape, k, n), dtype=a.dtype)
|
||||
return [q, r]
|
||||
|
||||
def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,
|
||||
full_matrices):
|
||||
operand_aval, = ctx.avals_in
|
||||
q_aval, r_aval = ctx.avals_out
|
||||
dims = operand_aval.shape
|
||||
m, n = dims[-2:]
|
||||
batch_dims = dims[:-2]
|
||||
|
||||
def _qr_lowering(a, *, full_matrices):
|
||||
*batch_dims, m, n = a.shape
|
||||
if m == 0 or n == 0:
|
||||
return mlir.lower_fun(_empty_qr, multiple_results=True)(
|
||||
ctx, operand, full_matrices=full_matrices)
|
||||
k = m if full_matrices else min(m, n)
|
||||
q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_dims, m, k))
|
||||
r = jnp.empty((*batch_dims, k, n), dtype=a.dtype)
|
||||
return q, r
|
||||
|
||||
r, tau, info_geqrf = geqrf_impl(operand_aval.dtype, operand)
|
||||
r, taus = geqrf(a)
|
||||
if m < n:
|
||||
q = mhlo.SliceOp(r,
|
||||
mlir.dense_int_elements([0] * len(dims)),
|
||||
mlir.dense_int_elements(list(batch_dims) + [m, m]),
|
||||
mlir.dense_int_elements([1] * len(dims))).result
|
||||
q, info_orgqr = orgqr_impl(operand_aval.dtype, q, tau)
|
||||
elif not full_matrices:
|
||||
q, info_orgqr = orgqr_impl(operand_aval.dtype, r, tau)
|
||||
r = mhlo.SliceOp(r,
|
||||
mlir.dense_int_elements([0] * len(dims)),
|
||||
mlir.dense_int_elements(list(batch_dims) + [n, n]),
|
||||
mlir.dense_int_elements([1] * len(dims))).result
|
||||
q = orgqr(r[..., :m, :m], taus)
|
||||
elif full_matrices:
|
||||
pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
|
||||
q = lax.pad(r, lax_internal._zero(r), pads)
|
||||
q = orgqr(q, taus)
|
||||
else:
|
||||
if jax._src.lib.mlir_api_version < 15:
|
||||
q = mhlo.PadOp(mlir.aval_to_ir_type(q_aval), r,
|
||||
mlir.ir_constant(np.array(0, dtype=operand_aval.dtype)),
|
||||
mlir.dense_int_elements([0] * len(dims)),
|
||||
mlir.dense_int_elements([0] * (len(dims) - 1) + [m - n]),
|
||||
mlir.dense_int_elements([0] * len(dims))).result
|
||||
else:
|
||||
q = mhlo.PadOp(r,
|
||||
mlir.ir_constant(np.array(0, dtype=operand_aval.dtype)),
|
||||
mlir.dense_int_elements([0] * len(dims)),
|
||||
mlir.dense_int_elements([0] * (len(dims) - 1) + [m - n]),
|
||||
mlir.dense_int_elements([0] * len(dims))).result
|
||||
q, info_orgqr = orgqr_impl(operand_aval.dtype, q, tau)
|
||||
if info_geqrf is not None:
|
||||
zeros = mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
||||
ok = mhlo.AndOp(
|
||||
mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED"),
|
||||
mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED"))
|
||||
ok = mhlo.BroadcastInDimOp(
|
||||
ir.RankedTensorType.get(batch_dims + (1, 1),
|
||||
ir.IntegerType.get_signless(1)),
|
||||
ok, mlir.dense_int_elements(range(len(batch_dims)))).result
|
||||
q = _broadcasting_select_mhlo(ok, q, _nan_like_mhlo(q_aval))
|
||||
r = _broadcasting_select_mhlo(ok, r, _nan_like_mhlo(r_aval))
|
||||
else:
|
||||
pass # rocsolver does not return info
|
||||
q = orgqr(r, taus)
|
||||
r = r[..., :n, :n]
|
||||
r = jnp.triu(r)
|
||||
return q, r
|
||||
|
||||
sub_ctx = mlir.LoweringRuleContext(module_context=ctx.module_context,
|
||||
primitive=None,
|
||||
avals_in=[r_aval],
|
||||
avals_out=[r_aval],
|
||||
tokens_in=ctx.tokens_in,
|
||||
tokens_out=ctx.tokens_out)
|
||||
r, = mlir.lower_fun(jnp.triu, multiple_results=False)(sub_ctx, r)
|
||||
return [q, r]
|
||||
|
||||
qr_p = Primitive('qr')
|
||||
qr_p.multiple_results = True
|
||||
qr_p.def_impl(_qr_impl)
|
||||
qr_p.def_abstract_eval(_qr_abstract_eval)
|
||||
xla.register_translation(qr_p, _qr_translation_rule)
|
||||
|
||||
# Older jaxlibs didn't expose geqrf and orgqr as separate XLA operations.
|
||||
# TODO(phawkins): remove after minimum jaxlib version is > 0.3.10.
|
||||
if jax._src.lib.xla_extension_version < 69:
|
||||
xla.register_translation(qr_p, _qr_translation_rule, platform="tpu")
|
||||
|
||||
ad.primitive_jvps[qr_p] = qr_jvp_rule
|
||||
batching.primitive_batchers[qr_p] = _qr_batching_rule
|
||||
|
||||
mlir.register_lowering(
|
||||
qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
|
||||
platform='cpu')
|
||||
mlir.register_lowering(qr_p, mlir.lower_fun(_qr_lowering));
|
||||
|
||||
if gpu_solver is not None:
|
||||
mlir.register_lowering(
|
||||
qr_p,
|
||||
partial(_qr_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
|
||||
gpu_solver.cuda_orgqr),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
qr_p,
|
||||
partial(_qr_cpu_gpu_lowering, gpu_solver.rocm_geqrf,
|
||||
gpu_solver.rocm_orgqr),
|
||||
platform='rocm')
|
||||
|
||||
|
||||
if solver_apis is not None:
|
||||
mlir.register_lowering(
|
||||
qr_p,
|
||||
partial(_qr_cpu_gpu_lowering, solver_apis.geqrf_mhlo, solver_apis.orgqr_mhlo),
|
||||
platform='gpu')
|
||||
|
||||
# Singular value decomposition
|
||||
|
||||
|
@ -477,13 +477,16 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
|
||||
@_wraps(np.linalg.qr)
|
||||
@partial(jit, static_argnames=('mode',))
|
||||
def qr(a, mode="reduced"):
|
||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||
if mode == "raw":
|
||||
a, taus = lax_linalg.geqrf(a)
|
||||
return _T(a), taus
|
||||
if mode in ("reduced", "r", "full"):
|
||||
full_matrices = False
|
||||
elif mode == "complete":
|
||||
full_matrices = True
|
||||
else:
|
||||
raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
|
||||
a, = _promote_dtypes_inexact(jnp.asarray(a))
|
||||
q, r = lax_linalg.qr(a, full_matrices=full_matrices)
|
||||
if mode == "r":
|
||||
return r
|
||||
|
@ -1011,6 +1011,8 @@ tf_not_yet_impl = [
|
||||
"all_gather",
|
||||
"lu_pivots_to_permutation",
|
||||
"xla_pmap",
|
||||
"geqrf",
|
||||
"orgqr",
|
||||
]
|
||||
|
||||
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
|
||||
|
@ -208,7 +208,7 @@ def orgqr_mhlo(dtype, a, tau):
|
||||
b *= d
|
||||
|
||||
tau_dims = ir.RankedTensorType(tau.type).shape
|
||||
assert tau_dims[:-1] == dims[:-2]
|
||||
assert tau_dims[:-1] == dims[:-2], (tau.type, a.type)
|
||||
k = tau_dims[-1]
|
||||
|
||||
if dtype == np.float32:
|
||||
|
@ -657,7 +657,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
"shape": shape, "dtype": dtype, "mode": mode}
|
||||
for shape in [(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)]
|
||||
for dtype in [np.float32]
|
||||
for mode in ["reduced", "r", "full", "complete"]))
|
||||
for mode in ["reduced", "r", "full", "complete", "raw"]))
|
||||
def testNumpyQrModes(self, shape, dtype, mode):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
|
||||
@ -665,7 +665,8 @@ class NumpyLinalgTest(jtu.JaxTestCase):
|
||||
if mode == "full":
|
||||
np_func = jtu.ignore_warning(category=DeprecationWarning, message="The 'full' option.*")(np_func)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1E-5, atol=1E-5)
|
||||
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1e-5, atol=1e-5,
|
||||
check_dtypes=(mode != "raw"))
|
||||
self._CompileAndCheck(jnp_func, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user