From b1b56ea0b074dfd18789b9a5313316b88fc2a32c Mon Sep 17 00:00:00 2001
From: tttc3 <97948946+tttc3@users.noreply.github.com>
Date: Fri, 10 Jan 2025 15:55:53 +0000
Subject: [PATCH] Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
---
CHANGELOG.md | 3 +
jax/_src/lax/linalg.py | 74 ++++++---
jax/_src/scipy/linalg.py | 4 +-
jaxlib/ffi_helpers.h | 13 ++
jaxlib/gpu/hybrid.cc | 5 +
jaxlib/gpu/hybrid_kernels.cc | 263 +++++++++++++++++++++++++++++++
jaxlib/gpu/hybrid_kernels.h | 1 +
jaxlib/gpu/solver_kernels_ffi.cc | 13 --
tests/linalg_test.py | 11 +-
tests/magma_linalg_test.py | 29 ++++
10 files changed, 379 insertions(+), 37 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 69c88f00f..b29ad3f1a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
+ * {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support
+ column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
+ {jax-issue}`#25955` for more details.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 0653cad7a..13f8885c9 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -311,20 +311,22 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
@overload
def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True,
- ) -> tuple[Array, Array]:
+ use_magma: bool | None = None) -> tuple[Array, Array]:
...
@overload
def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True,
- ) -> tuple[Array, Array, Array]:
+ use_magma: bool | None = None) -> tuple[Array, Array, Array]:
...
@overload
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
+ use_magma: bool | None = None
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
...
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
+ use_magma: bool | None = None
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
"""QR decomposition.
@@ -341,9 +343,14 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``,
compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P``
is chosen such that the diagonal of ``R`` is non-increasing. Currently
- supported on CPU backends only.
+ supported on CPU and GPU backends only.
full_matrices: Determines if full or reduced matrices are returned; see
below.
+ use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
+ pivoted `qr` factorization is computed using MAGMA. If ``False``, the
+ computation is done using LAPACK on the host CPU. If ``None`` (default),
+ the behavior is controlled by the ``jax_use_magma`` flag. This argument is
+ only used on GPU.
Returns:
A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``.
@@ -357,8 +364,16 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
``full_matrices=False``.
Array ``p`` is an index vector with shape [..., n]
+
+ Notes:
+ - `MAGMA `_ support is experimental - see
+ :func:`jax.lax.linalg.eig` for further assumptions and limitations.
+ - If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will
+ be used if the library can be found, and the input matrix is sufficiently
+ large (has at least 2048 columns).
"""
- q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices)
+ q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
+ use_magma=use_magma)
if pivoting:
return q, r, p[0]
return q, r
@@ -1854,22 +1869,28 @@ mlir.register_lowering(
platform='rocm')
-def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]:
+def geqp3(a: ArrayLike, jpvt: ArrayLike, *,
+ use_magma: bool | None = None) -> tuple[Array, Array, Array]:
"""Computes the column-pivoted QR decomposition of a matrix.
Args:
a: a ``[..., m, n]`` batch of matrices, with floating-point or complex type.
jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer type,
+ use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
+ `geqp3` is computed using MAGMA. If ``False``, the computation is done using
+ LAPACK on to the host CPU. If ``None`` (default), the behavior is controlled
+ by the ``jax_use_magma`` flag. This argument is only used on GPU.
Returns:
A ``(a, jpvt, taus)`` triple, where ``r`` is in the upper triangle of ``a``,
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
elementary Householder reflectors, and ``jpvt`` is the column-pivot indices
such that ``a[:, jpvt] = q @ r``.
"""
- a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt)
+ a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma)
return a_out, jpvt_out, taus
-def _geqp3_abstract_eval(a, jpvt):
+def _geqp3_abstract_eval(a, jpvt, *, use_magma):
+ del use_magma
if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray):
raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: "
f"{a.aval}, {jpvt.aval}")
@@ -1882,25 +1903,37 @@ def _geqp3_abstract_eval(a, jpvt):
taus = a.update(shape=(*batch_dims, core.min_dim(m, n)))
return a, jpvt, taus
-def _geqp3_batching_rule(batched_args, batch_dims):
+def _geqp3_batching_rule(batched_args, batch_dims, *, use_magma):
a, jpvt = batched_args
b_a, b_jpvt = batch_dims
a = batching.moveaxis(a, b_a, 0)
jpvt = batching.moveaxis(jpvt, b_jpvt, 0)
- return geqp3(a, jpvt), (0, 0, 0)
+ return geqp3(a, jpvt, use_magma=use_magma), (0, 0, 0)
-def _geqp3_cpu_lowering(ctx, a, jpvt):
+def _geqp3_cpu_lowering(ctx, a, jpvt, *, use_magma):
+ del use_magma
a_aval, _ = ctx.avals_in
target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype)
rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1})
return rule(ctx, a, jpvt)
+def _geqp3_gpu_lowering(target_name_prefix, ctx, a, jpvt, *, use_magma):
+ gpu_solver.initialize_hybrid_kernels()
+ magma = config.gpu_use_magma.value
+ target_name = f"{target_name_prefix}hybrid_geqp3"
+ if use_magma is not None:
+ magma = "on" if use_magma else "off"
+ rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1})
+ return rule(ctx, a, jpvt, magma=magma)
+
geqp3_p = Primitive('geqp3')
geqp3_p.multiple_results = True
geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p))
geqp3_p.def_abstract_eval(_geqp3_abstract_eval)
batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule
mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu")
+mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'cu'), platform="cuda")
+mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'hip'), platform="rocm")
# householder_product: product of elementary Householder reflectors
@@ -1988,12 +2021,13 @@ mlir.register_lowering(
platform='rocm')
-def _qr_impl(operand, *, pivoting, full_matrices):
+def _qr_impl(operand, *, pivoting, full_matrices, use_magma):
q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting,
- full_matrices=full_matrices)
+ full_matrices=full_matrices, use_magma=use_magma)
return (q, r, p[0]) if pivoting else (q, r)
-def _qr_abstract_eval(operand, *, pivoting, full_matrices):
+def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
+ del use_magma
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
@@ -2018,11 +2052,11 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
q, r, p = operand, operand, operand
return (q, r, p) if pivoting else (q, r)
-def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
+def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma):
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
x, = primals
dx, = tangents
- q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False)
+ q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False, use_magma=use_magma)
*_, m, n = x.shape
if m < n or (full_matrices and m != n):
raise NotImplementedError(
@@ -2043,14 +2077,16 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
return (q, r, p[0]), (dq, dr, dp)
return (q, r), (dq, dr)
-def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices):
+def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices,
+ use_magma):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
out_axes = (0, 0, 0) if pivoting else (0, 0)
- return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes
+ return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
+ use_magma=use_magma), out_axes
-def _qr_lowering(a, *, pivoting, full_matrices):
+def _qr_lowering(a, *, pivoting, full_matrices, use_magma):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
k = m if full_matrices else core.min_dim(m, n)
@@ -2065,7 +2101,7 @@ def _qr_lowering(a, *, pivoting, full_matrices):
if pivoting:
jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
- r, p, taus = geqp3(a, jpvt)
+ r, p, taus = geqp3(a, jpvt, use_magma=use_magma)
p -= 1 # Convert geqp3's 1-based indices to 0-based indices by subtracting 1.
else:
r, taus = geqrf(a)
diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py
index 0d7b815b1..d40608027 100644
--- a/jax/_src/scipy/linalg.py
+++ b/jax/_src/scipy/linalg.py
@@ -953,7 +953,9 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "
with ``K = min(M, N)``.
Notes:
- - At present, pivoting is only implemented on CPU backends.
+ - At present, pivoting is only implemented on the CPU and GPU backends. For further
+ details about the GPU implementation, see the documentation for
+ :func:`jax.lax.linalg.qr`.
See also:
- :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API
diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h
index 47505020f..5c6d80093 100644
--- a/jaxlib/ffi_helpers.h
+++ b/jaxlib/ffi_helpers.h
@@ -184,6 +184,19 @@ auto AllocateScratchMemory(std::size_t size)
return std::unique_ptr(new ValueType[size]);
}
+template
+inline absl::StatusOr AllocateWorkspace(
+ ::xla::ffi::ScratchAllocator& scratch, int64_t size,
+ std::string_view name) {
+ auto maybe_workspace = scratch.Allocate(sizeof(T) * size);
+ if (!maybe_workspace.has_value()) {
+ return absl::Status(
+ absl::StatusCode::kResourceExhausted,
+ absl::StrFormat("Unable to allocate workspace for %s", name));
+ }
+ return static_cast(maybe_workspace.value());
+}
+
} // namespace jax
#endif // JAXLIB_FFI_HELPERS_H_
diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc
index af0f2575d..94975a5b9 100644
--- a/jaxlib/gpu/hybrid.cc
+++ b/jaxlib/gpu/hybrid.cc
@@ -47,6 +47,10 @@ void GetLapackKernelsFromScipy() {
lapack_ptr("cgeev"));
AssignKernelFn>(
lapack_ptr("zgeev"));
+ AssignKernelFn>(lapack_ptr("sgeqp3"));
+ AssignKernelFn>(lapack_ptr("dgeqp3"));
+ AssignKernelFn>(lapack_ptr("cgeqp3"));
+ AssignKernelFn>(lapack_ptr("zgeqp3"));
});
}
@@ -57,6 +61,7 @@ NB_MODULE(_hybrid, m) {
nb::dict dict;
dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal);
dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp);
+ dict[JAX_GPU_PREFIX "hybrid_geqp3"] = EncapsulateFfiHandler(kGeqp3);
return dict;
});
}
diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc
index 1ce2e547b..8caa0e1d7 100644
--- a/jaxlib/gpu/hybrid_kernels.cc
+++ b/jaxlib/gpu/hybrid_kernels.cc
@@ -103,6 +103,30 @@ template <>
struct MagmaGeev {
static constexpr char name[] = "magma_zgeev";
};
+template
+struct MagmaGeqp3 {
+ static_assert(always_false::value, "unsupported data type");
+};
+template <>
+struct MagmaGeqp3 {
+ static constexpr char name[] = "magma_sgeqp3_gpu";
+ static constexpr char block_size_name[] = "magma_get_sgeqp3_nb";
+};
+template <>
+struct MagmaGeqp3 {
+ static constexpr char name[] = "magma_dgeqp3_gpu";
+ static constexpr char block_size_name[] = "magma_get_dgeqp3_nb";
+};
+template <>
+struct MagmaGeqp3 {
+ static constexpr char name[] = "magma_cgeqp3_gpu";
+ static constexpr char block_size_name[] = "magma_get_cgeqp3_nb";
+};
+template <>
+struct MagmaGeqp3 {
+ static constexpr char name[] = "magma_zgeqp3_gpu";
+ static constexpr char block_size_name[] = "magma_get_zgeqp3_nb";
+};
MagmaLookup::~MagmaLookup() {
if (initialized_) {
@@ -205,6 +229,245 @@ absl::StatusOr FindMagmaSymbol(const char name[]) {
return lookup.Find(name);
}
+// Column Pivoting QR Factorization
+
+// magma geqp3_gpu
+
+template
+class PivotingQrFactorizationHost {
+ using RealType = ffi::NativeType;
+ using ValueType = ffi::NativeType;
+
+ public:
+ explicit PivotingQrFactorizationHost() = default;
+ PivotingQrFactorizationHost(PivotingQrFactorizationHost&&) = default;
+
+ ffi::Error compute(int64_t batch, int64_t rows, int64_t cols,
+ gpuStream_t stream, ffi::ScratchAllocator& scratch,
+ ffi::AnyBuffer x, ffi::AnyBuffer jpvt,
+ ffi::Result x_out,
+ ffi::Result jpvt_out,
+ ffi::Result tau) {
+ FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows));
+ FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols));
+ auto min_dim = std::min(m, n);
+
+ FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n));
+ auto work = AllocateScratchMemory(lwork);
+
+ constexpr bool is_complex_dtype = ffi::IsComplexType();
+ std::unique_ptr rwork;
+ if constexpr (is_complex_dtype) {
+ rwork = AllocateScratchMemory(2 * n);
+ }
+
+ auto x_host = HostBuffer(x.element_count());
+ FFI_RETURN_IF_ERROR_STATUS(
+ x_host.CopyFromDevice(stream, x.typed_data()));
+ auto jpvt_host = HostBuffer(jpvt.element_count());
+ FFI_RETURN_IF_ERROR_STATUS(
+ jpvt_host.CopyFromDevice(stream, jpvt.typed_data()));
+ auto tau_host = HostBuffer(batch * min_dim);
+ auto info_host = HostBuffer(batch);
+
+ for (int64_t i = 0; i < batch; ++i) {
+ if constexpr (is_complex_dtype) {
+ PivotingQrFactorization::fn(
+ &m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n,
+ tau_host.get() + i * min_dim, work.get(), &lwork, rwork.get(),
+ info_host.get() + i);
+ } else {
+ PivotingQrFactorization::fn(
+ &m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n,
+ tau_host.get() + i * min_dim, work.get(), &lwork,
+ info_host.get() + i);
+ }
+ }
+
+ FFI_RETURN_IF_ERROR_STATUS(
+ x_host.CopyToDevice(stream, x_out->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(
+ jpvt_host.CopyToDevice(stream, jpvt_out->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(
+ tau_host.CopyToDevice(stream, tau->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+ return ffi::Error::Success();
+ }
+
+ private:
+ absl::StatusOr lwork(int m, int n) {
+ int64_t lwork = PivotingQrFactorization::GetWorkspaceSize(m, n);
+ return MaybeCastNoOverflow(lwork);
+ }
+};
+
+template
+class PivotingQrFactorizationMagma {
+ using RealType = ffi::NativeType;
+ using ValueType = ffi::NativeType;
+ using Fn = std::conditional_t<
+ ffi::IsComplexType(),
+ int(int m, int n, ValueType* dA, int ldda, int* jpvt, ValueType* tau,
+ ValueType* dwork, int lwork, RealType* rwork, int* info),
+ int(int m, int n, RealType* dA, int ldda, int* jpvt, RealType* tau,
+ RealType* dwork, int lwork, int* info)>;
+ using BlockSizeFn = int(int m, int n);
+
+ public:
+ explicit PivotingQrFactorizationMagma() = default;
+ PivotingQrFactorizationMagma(PivotingQrFactorizationMagma&&) = default;
+
+ ffi::Error compute(int64_t batch, int64_t rows, int64_t cols,
+ gpuStream_t stream, ffi::ScratchAllocator& scratch,
+ ffi::AnyBuffer x, ffi::AnyBuffer jpvt,
+ ffi::Result x_out,
+ ffi::Result jpvt_out,
+ ffi::Result tau) {
+ FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows));
+ FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols));
+ auto min_dim = std::min(m, n);
+
+ FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n));
+ FFI_ASSIGN_OR_RETURN(auto work,
+ AllocateWorkspace(scratch, lwork, "geqp3"));
+
+ constexpr bool is_complex_dtype = ffi::IsComplexType();
+ RealType* rwork;
+ if constexpr (is_complex_dtype) {
+ FFI_ASSIGN_OR_RETURN(
+ rwork, AllocateWorkspace(scratch, 2 * n, "geqp3"));
+ }
+
+ auto x_data = x.typed_data();
+ auto x_out_data = x_out->typed_data();
+ auto tau_data = tau->typed_data();
+ if (x_data != x_out_data) {
+ FFI_RETURN_IF_ERROR_STATUS(
+ JAX_AS_STATUS(gpuMemcpyAsync(x_out_data, x_data, x.size_bytes(),
+ gpuMemcpyDeviceToDevice, stream)));
+ }
+ auto jpvt_host = HostBuffer(jpvt.element_count());
+ FFI_RETURN_IF_ERROR_STATUS(
+ jpvt_host.CopyFromDevice(stream, jpvt.typed_data()));
+ auto info_host = HostBuffer(batch);
+
+ // TODO: do we need to wrap with synchronise due to non-stream safety.
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+ for (int64_t i = 0; i < batch; ++i) {
+ if constexpr (is_complex_dtype) {
+ fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n,
+ tau_data + i * min_dim, work, lwork, rwork, info_host.get() + i);
+ } else {
+ fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n,
+ tau_data + i * min_dim, work, lwork, info_host.get() + i);
+ }
+ }
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+ FFI_RETURN_IF_ERROR_STATUS(
+ jpvt_host.CopyToDevice(stream, jpvt_out->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+ return ffi::Error::Success();
+ }
+
+ private:
+ Fn* fn_ = nullptr;
+ BlockSizeFn* block_size_fn_ = nullptr;
+
+ absl::StatusOr lwork(int m, int n) {
+ // `{c,d,s,z}_geqp3_gpu` do not support a workspace query, but we can still
+ // assign the symbol here.
+ auto maybe_ptr = FindMagmaSymbol(MagmaGeqp3::name);
+ if (!maybe_ptr.ok()) return maybe_ptr.status();
+ fn_ = reinterpret_cast(*maybe_ptr);
+
+ auto block_size_maybe_ptr =
+ FindMagmaSymbol(MagmaGeqp3::block_size_name);
+ if (!block_size_maybe_ptr.ok()) return block_size_maybe_ptr.status();
+ block_size_fn_ = reinterpret_cast(*block_size_maybe_ptr);
+ int optimal_block_size = block_size_fn_(m, n);
+ if constexpr (ffi::IsComplexType()) {
+ return (n + 1) * optimal_block_size;
+ }
+ return (n + 1) * optimal_block_size + 2 * n;
+ }
+};
+
+ffi::Error PivotingQrFactorizationDispatch(
+ gpuStream_t stream, ffi::ScratchAllocator scratch, std::string_view magma,
+ ffi::AnyBuffer x, ffi::AnyBuffer jpvt, ffi::Result x_out,
+ ffi::Result jpvt_out, ffi::Result tau) {
+ auto dataType = x.element_type();
+ if (dataType != x_out->element_type() || dataType != tau->element_type()) {
+ return ffi::Error::InvalidArgument(
+ "The buffers 'x', 'x_out' and 'tau' must have the same element type.");
+ }
+ FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+ SplitBatch2D(x.dimensions()));
+ FFI_RETURN_IF_ERROR(
+ CheckShape(jpvt.dimensions(), {batch, cols}, "jpvt", "geqp3"));
+ FFI_RETURN_IF_ERROR(
+ CheckShape(x_out->dimensions(), {batch, rows, cols}, "x_out", "geqp3"));
+ FFI_RETURN_IF_ERROR(
+ CheckShape(jpvt_out->dimensions(), {batch, cols}, "jpvt_out", "geqp3"));
+ FFI_RETURN_IF_ERROR(CheckShape(
+ tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqp3"));
+
+ bool use_magma = magma == "on";
+ if (magma == "auto" && cols >= 2048) {
+ use_magma = FindMagmaSymbol("magma_init").ok();
+ }
+
+ switch (dataType) {
+ case ffi::F32:
+ if (use_magma) {
+ return PivotingQrFactorizationMagma().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ } else {
+ return PivotingQrFactorizationHost().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ }
+ case ffi::F64:
+ if (use_magma) {
+ return PivotingQrFactorizationMagma().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ } else {
+ return PivotingQrFactorizationHost().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ }
+ case ffi::C64:
+ if (use_magma) {
+ return PivotingQrFactorizationMagma().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ } else {
+ return PivotingQrFactorizationHost().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ }
+ case ffi::C128:
+ if (use_magma) {
+ return PivotingQrFactorizationMagma().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ } else {
+ return PivotingQrFactorizationHost().compute(
+ batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
+ }
+ default:
+ return ffi::Error::InvalidArgument(absl::StrFormat(
+ "Unsupported dtype %s in geqp3", absl::FormatStreamed(dataType)));
+ }
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(kGeqp3, PivotingQrFactorizationDispatch,
+ ffi::Ffi::Bind()
+ .Ctx>()
+ .Ctx()
+ .Attr("magma")
+ .Arg() // x
+ .Arg() // jpvt
+ .Ret() // x_out
+ .Ret() // jpvt_out
+ .Ret() // tau
+);
+
// Real-valued eigendecomposition
template
diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h
index 2890837a2..ba5510bb8 100644
--- a/jaxlib/gpu/hybrid_kernels.h
+++ b/jaxlib/gpu/hybrid_kernels.h
@@ -48,6 +48,7 @@ class MagmaLookup {
XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal);
XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(kGeqp3);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc
index 1d1701a49..79b5dff6c 100644
--- a/jaxlib/gpu/solver_kernels_ffi.cc
+++ b/jaxlib/gpu/solver_kernels_ffi.cc
@@ -51,19 +51,6 @@ namespace JAX_GPU_NAMESPACE {
namespace ffi = ::xla::ffi;
-template
-inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch,
- int64_t size,
- std::string_view name) {
- auto maybe_workspace = scratch.Allocate(sizeof(T) * size);
- if (!maybe_workspace.has_value()) {
- return absl::Status(
- absl::StatusCode::kResourceExhausted,
- absl::StrFormat("Unable to allocate workspace for %s", name));
- }
- return static_cast(maybe_workspace.value());
-}
-
#if JAX_GPU_HAVE_64_BIT
// Map an FFI buffer element type to the appropriate GPU solver type.
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 3027b36ff..528bec39a 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -1683,10 +1683,13 @@ class ScipyLinalgTest(jtu.JaxTestCase):
mode=["full", "r", "economic"],
pivoting=[False, True]
)
+ @jax.default_matmul_precision("float32")
def testScipyQrModes(self, shape, dtype, mode, pivoting):
- is_not_cpu_test_device = not jtu.test_device_matches(["cpu"])
- if pivoting and is_not_cpu_test_device:
- self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38")
+ if pivoting:
+ if not jtu.test_device_matches(["cpu", "gpu"]):
+ self.skipTest("Pivoting is only supported on CPU and GPU.")
+ if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0):
+ self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0")
rng = jtu.rand_default(self.rng())
jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)
sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting)
@@ -1700,7 +1703,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def qr_and_mul(a):
q, r, *p = jsp_func(a)
# To express the identity function we must "undo" the pivoting of `q @ r`.
- inverted_pivots = p[0][p[0]]
+ inverted_pivots = jnp.argsort(p[0])
return (q @ r)[:, inverted_pivots]
m, n = shape
diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py
index 27ab58aae..2400672d1 100644
--- a/tests/magma_linalg_test.py
+++ b/tests/magma_linalg_test.py
@@ -114,5 +114,34 @@ class MagmaLinalgTest(jtu.JaxTestCase):
hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text()
self.assertIn('magma = "on"', hlo)
+ @jtu.sample_product(
+ shape=[(3, 4), (3, 3), (4, 3), (4, 3)],
+ dtype=float_types + complex_types,
+ )
+ @jtu.run_on_devices("gpu")
+ def testPivotedQrFactorization(self, shape, dtype):
+ if jtu.jaxlib_version() <= (0, 5, 0):
+ self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
+ if not gpu_solver.has_magma():
+ self.skipTest("MAGMA is not installed or can't be loaded.")
+ rng = jtu.rand_default(self.rng())
+
+ lax_func = partial(lax_linalg.qr, full_matrices=True, pivoting=True, use_magma=True)
+ sp_func = partial(jax.scipy.linalg.qr, mode="full", pivoting=True)
+ args_maker = lambda: [rng(shape, dtype)]
+ self._CheckAgainstNumpy(sp_func, lax_func, args_maker, rtol=1E-5, atol=1E-5)
+ self._CompileAndCheck(lax_func, args_maker)
+
+ def testPivotedQrFactorizationMagmaConfig(self):
+ if jtu.jaxlib_version() <= (0, 5, 0):
+ self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
+ if not gpu_solver.has_magma():
+ self.skipTest("MAGMA is not installed or can't be loaded.")
+ rng = jtu.rand_default(self.rng())
+ a = rng((5, 5), np.float32)
+ with config.gpu_use_magma("on"):
+ hlo = jax.jit(partial(lax_linalg.qr, pivoting=True, use_magma=True)).lower(a).as_text()
+ self.assertIn('magma = "on"', hlo)
+
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())