From b1b56ea0b074dfd18789b9a5313316b88fc2a32c Mon Sep 17 00:00:00 2001 From: tttc3 <97948946+tttc3@users.noreply.github.com> Date: Fri, 10 Jan 2025 15:55:53 +0000 Subject: [PATCH] Enable pivoted QR on GPU via MAGMA. Originally noted in #20282, this commit provides a GPU compatible implementation of `geqp3` via MAGMA. --- CHANGELOG.md | 3 + jax/_src/lax/linalg.py | 74 ++++++--- jax/_src/scipy/linalg.py | 4 +- jaxlib/ffi_helpers.h | 13 ++ jaxlib/gpu/hybrid.cc | 5 + jaxlib/gpu/hybrid_kernels.cc | 263 +++++++++++++++++++++++++++++++ jaxlib/gpu/hybrid_kernels.h | 1 + jaxlib/gpu/solver_kernels_ffi.cc | 13 -- tests/linalg_test.py | 11 +- tests/magma_linalg_test.py | 29 ++++ 10 files changed, 379 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69c88f00f..b29ad3f1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`, {func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`, {func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`. + * {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support + column-pivoting on CPU and GPU. See {jax-issue}`#20282` and + {jax-issue}`#25955` for more details. * Changes * `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0653cad7a..13f8885c9 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -311,20 +311,22 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]: @overload def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True, - ) -> tuple[Array, Array]: + use_magma: bool | None = None) -> tuple[Array, Array]: ... @overload def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True, - ) -> tuple[Array, Array, Array]: + use_magma: bool | None = None) -> tuple[Array, Array, Array]: ... @overload def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True, + use_magma: bool | None = None ) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True, + use_magma: bool | None = None ) -> tuple[Array, Array] | tuple[Array, Array, Array]: """QR decomposition. @@ -341,9 +343,14 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True, pivoting: Allows the QR decomposition to be rank-revealing. If ``True``, compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P`` is chosen such that the diagonal of ``R`` is non-increasing. Currently - supported on CPU backends only. + supported on CPU and GPU backends only. full_matrices: Determines if full or reduced matrices are returned; see below. + use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the + pivoted `qr` factorization is computed using MAGMA. If ``False``, the + computation is done using LAPACK on the host CPU. If ``None`` (default), + the behavior is controlled by the ``jax_use_magma`` flag. This argument is + only used on GPU. Returns: A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``. @@ -357,8 +364,16 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True, ``full_matrices=False``. Array ``p`` is an index vector with shape [..., n] + + Notes: + - `MAGMA `_ support is experimental - see + :func:`jax.lax.linalg.eig` for further assumptions and limitations. + - If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will + be used if the library can be found, and the input matrix is sufficiently + large (has at least 2048 columns). """ - q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices) + q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices, + use_magma=use_magma) if pivoting: return q, r, p[0] return q, r @@ -1854,22 +1869,28 @@ mlir.register_lowering( platform='rocm') -def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]: +def geqp3(a: ArrayLike, jpvt: ArrayLike, *, + use_magma: bool | None = None) -> tuple[Array, Array, Array]: """Computes the column-pivoted QR decomposition of a matrix. Args: a: a ``[..., m, n]`` batch of matrices, with floating-point or complex type. jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer type, + use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the + `geqp3` is computed using MAGMA. If ``False``, the computation is done using + LAPACK on to the host CPU. If ``None`` (default), the behavior is controlled + by the ``jax_use_magma`` flag. This argument is only used on GPU. Returns: A ``(a, jpvt, taus)`` triple, where ``r`` is in the upper triangle of ``a``, ``q`` is represented in the lower triangle of ``a`` and in ``taus`` as elementary Householder reflectors, and ``jpvt`` is the column-pivot indices such that ``a[:, jpvt] = q @ r``. """ - a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt) + a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma) return a_out, jpvt_out, taus -def _geqp3_abstract_eval(a, jpvt): +def _geqp3_abstract_eval(a, jpvt, *, use_magma): + del use_magma if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray): raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: " f"{a.aval}, {jpvt.aval}") @@ -1882,25 +1903,37 @@ def _geqp3_abstract_eval(a, jpvt): taus = a.update(shape=(*batch_dims, core.min_dim(m, n))) return a, jpvt, taus -def _geqp3_batching_rule(batched_args, batch_dims): +def _geqp3_batching_rule(batched_args, batch_dims, *, use_magma): a, jpvt = batched_args b_a, b_jpvt = batch_dims a = batching.moveaxis(a, b_a, 0) jpvt = batching.moveaxis(jpvt, b_jpvt, 0) - return geqp3(a, jpvt), (0, 0, 0) + return geqp3(a, jpvt, use_magma=use_magma), (0, 0, 0) -def _geqp3_cpu_lowering(ctx, a, jpvt): +def _geqp3_cpu_lowering(ctx, a, jpvt, *, use_magma): + del use_magma a_aval, _ = ctx.avals_in target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype) rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1}) return rule(ctx, a, jpvt) +def _geqp3_gpu_lowering(target_name_prefix, ctx, a, jpvt, *, use_magma): + gpu_solver.initialize_hybrid_kernels() + magma = config.gpu_use_magma.value + target_name = f"{target_name_prefix}hybrid_geqp3" + if use_magma is not None: + magma = "on" if use_magma else "off" + rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1}) + return rule(ctx, a, jpvt, magma=magma) + geqp3_p = Primitive('geqp3') geqp3_p.multiple_results = True geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p)) geqp3_p.def_abstract_eval(_geqp3_abstract_eval) batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu") +mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'cu'), platform="cuda") +mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'hip'), platform="rocm") # householder_product: product of elementary Householder reflectors @@ -1988,12 +2021,13 @@ mlir.register_lowering( platform='rocm') -def _qr_impl(operand, *, pivoting, full_matrices): +def _qr_impl(operand, *, pivoting, full_matrices, use_magma): q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting, - full_matrices=full_matrices) + full_matrices=full_matrices, use_magma=use_magma) return (q, r, p[0]) if pivoting else (q, r) -def _qr_abstract_eval(operand, *, pivoting, full_matrices): +def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma): + del use_magma if isinstance(operand, ShapedArray): if operand.ndim < 2: raise ValueError("Argument to QR decomposition must have ndims >= 2") @@ -2018,11 +2052,11 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices): q, r, p = operand, operand, operand return (q, r, p) if pivoting else (q, r) -def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices): +def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma): # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation. x, = primals dx, = tangents - q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False) + q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False, use_magma=use_magma) *_, m, n = x.shape if m < n or (full_matrices and m != n): raise NotImplementedError( @@ -2043,14 +2077,16 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices): return (q, r, p[0]), (dq, dr, dp) return (q, r), (dq, dr) -def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices): +def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices, + use_magma): x, = batched_args bd, = batch_dims x = batching.moveaxis(x, bd, 0) out_axes = (0, 0, 0) if pivoting else (0, 0) - return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes + return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices, + use_magma=use_magma), out_axes -def _qr_lowering(a, *, pivoting, full_matrices): +def _qr_lowering(a, *, pivoting, full_matrices, use_magma): *batch_dims, m, n = a.shape if m == 0 or n == 0: k = m if full_matrices else core.min_dim(m, n) @@ -2065,7 +2101,7 @@ def _qr_lowering(a, *, pivoting, full_matrices): if pivoting: jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32)) - r, p, taus = geqp3(a, jpvt) + r, p, taus = geqp3(a, jpvt, use_magma=use_magma) p -= 1 # Convert geqp3's 1-based indices to 0-based indices by subtracting 1. else: r, taus = geqrf(a) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 0d7b815b1..d40608027 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -953,7 +953,9 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = " with ``K = min(M, N)``. Notes: - - At present, pivoting is only implemented on CPU backends. + - At present, pivoting is only implemented on the CPU and GPU backends. For further + details about the GPU implementation, see the documentation for + :func:`jax.lax.linalg.qr`. See also: - :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 47505020f..5c6d80093 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -184,6 +184,19 @@ auto AllocateScratchMemory(std::size_t size) return std::unique_ptr(new ValueType[size]); } +template +inline absl::StatusOr AllocateWorkspace( + ::xla::ffi::ScratchAllocator& scratch, int64_t size, + std::string_view name) { + auto maybe_workspace = scratch.Allocate(sizeof(T) * size); + if (!maybe_workspace.has_value()) { + return absl::Status( + absl::StatusCode::kResourceExhausted, + absl::StrFormat("Unable to allocate workspace for %s", name)); + } + return static_cast(maybe_workspace.value()); +} + } // namespace jax #endif // JAXLIB_FFI_HELPERS_H_ diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc index af0f2575d..94975a5b9 100644 --- a/jaxlib/gpu/hybrid.cc +++ b/jaxlib/gpu/hybrid.cc @@ -47,6 +47,10 @@ void GetLapackKernelsFromScipy() { lapack_ptr("cgeev")); AssignKernelFn>( lapack_ptr("zgeev")); + AssignKernelFn>(lapack_ptr("sgeqp3")); + AssignKernelFn>(lapack_ptr("dgeqp3")); + AssignKernelFn>(lapack_ptr("cgeqp3")); + AssignKernelFn>(lapack_ptr("zgeqp3")); }); } @@ -57,6 +61,7 @@ NB_MODULE(_hybrid, m) { nb::dict dict; dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal); dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp); + dict[JAX_GPU_PREFIX "hybrid_geqp3"] = EncapsulateFfiHandler(kGeqp3); return dict; }); } diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc index 1ce2e547b..8caa0e1d7 100644 --- a/jaxlib/gpu/hybrid_kernels.cc +++ b/jaxlib/gpu/hybrid_kernels.cc @@ -103,6 +103,30 @@ template <> struct MagmaGeev { static constexpr char name[] = "magma_zgeev"; }; +template +struct MagmaGeqp3 { + static_assert(always_false::value, "unsupported data type"); +}; +template <> +struct MagmaGeqp3 { + static constexpr char name[] = "magma_sgeqp3_gpu"; + static constexpr char block_size_name[] = "magma_get_sgeqp3_nb"; +}; +template <> +struct MagmaGeqp3 { + static constexpr char name[] = "magma_dgeqp3_gpu"; + static constexpr char block_size_name[] = "magma_get_dgeqp3_nb"; +}; +template <> +struct MagmaGeqp3 { + static constexpr char name[] = "magma_cgeqp3_gpu"; + static constexpr char block_size_name[] = "magma_get_cgeqp3_nb"; +}; +template <> +struct MagmaGeqp3 { + static constexpr char name[] = "magma_zgeqp3_gpu"; + static constexpr char block_size_name[] = "magma_get_zgeqp3_nb"; +}; MagmaLookup::~MagmaLookup() { if (initialized_) { @@ -205,6 +229,245 @@ absl::StatusOr FindMagmaSymbol(const char name[]) { return lookup.Find(name); } +// Column Pivoting QR Factorization + +// magma geqp3_gpu + +template +class PivotingQrFactorizationHost { + using RealType = ffi::NativeType; + using ValueType = ffi::NativeType; + + public: + explicit PivotingQrFactorizationHost() = default; + PivotingQrFactorizationHost(PivotingQrFactorizationHost&&) = default; + + ffi::Error compute(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer x, ffi::AnyBuffer jpvt, + ffi::Result x_out, + ffi::Result jpvt_out, + ffi::Result tau) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + auto min_dim = std::min(m, n); + + FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n)); + auto work = AllocateScratchMemory(lwork); + + constexpr bool is_complex_dtype = ffi::IsComplexType(); + std::unique_ptr rwork; + if constexpr (is_complex_dtype) { + rwork = AllocateScratchMemory(2 * n); + } + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + auto jpvt_host = HostBuffer(jpvt.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + jpvt_host.CopyFromDevice(stream, jpvt.typed_data())); + auto tau_host = HostBuffer(batch * min_dim); + auto info_host = HostBuffer(batch); + + for (int64_t i = 0; i < batch; ++i) { + if constexpr (is_complex_dtype) { + PivotingQrFactorization::fn( + &m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n, + tau_host.get() + i * min_dim, work.get(), &lwork, rwork.get(), + info_host.get() + i); + } else { + PivotingQrFactorization::fn( + &m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n, + tau_host.get() + i * min_dim, work.get(), &lwork, + info_host.get() + i); + } + } + + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyToDevice(stream, x_out->typed_data())); + FFI_RETURN_IF_ERROR_STATUS( + jpvt_host.CopyToDevice(stream, jpvt_out->typed_data())); + FFI_RETURN_IF_ERROR_STATUS( + tau_host.CopyToDevice(stream, tau->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + return ffi::Error::Success(); + } + + private: + absl::StatusOr lwork(int m, int n) { + int64_t lwork = PivotingQrFactorization::GetWorkspaceSize(m, n); + return MaybeCastNoOverflow(lwork); + } +}; + +template +class PivotingQrFactorizationMagma { + using RealType = ffi::NativeType; + using ValueType = ffi::NativeType; + using Fn = std::conditional_t< + ffi::IsComplexType(), + int(int m, int n, ValueType* dA, int ldda, int* jpvt, ValueType* tau, + ValueType* dwork, int lwork, RealType* rwork, int* info), + int(int m, int n, RealType* dA, int ldda, int* jpvt, RealType* tau, + RealType* dwork, int lwork, int* info)>; + using BlockSizeFn = int(int m, int n); + + public: + explicit PivotingQrFactorizationMagma() = default; + PivotingQrFactorizationMagma(PivotingQrFactorizationMagma&&) = default; + + ffi::Error compute(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer x, ffi::AnyBuffer jpvt, + ffi::Result x_out, + ffi::Result jpvt_out, + ffi::Result tau) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + auto min_dim = std::min(m, n); + + FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n)); + FFI_ASSIGN_OR_RETURN(auto work, + AllocateWorkspace(scratch, lwork, "geqp3")); + + constexpr bool is_complex_dtype = ffi::IsComplexType(); + RealType* rwork; + if constexpr (is_complex_dtype) { + FFI_ASSIGN_OR_RETURN( + rwork, AllocateWorkspace(scratch, 2 * n, "geqp3")); + } + + auto x_data = x.typed_data(); + auto x_out_data = x_out->typed_data(); + auto tau_data = tau->typed_data(); + if (x_data != x_out_data) { + FFI_RETURN_IF_ERROR_STATUS( + JAX_AS_STATUS(gpuMemcpyAsync(x_out_data, x_data, x.size_bytes(), + gpuMemcpyDeviceToDevice, stream))); + } + auto jpvt_host = HostBuffer(jpvt.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + jpvt_host.CopyFromDevice(stream, jpvt.typed_data())); + auto info_host = HostBuffer(batch); + + // TODO: do we need to wrap with synchronise due to non-stream safety. + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + for (int64_t i = 0; i < batch; ++i) { + if constexpr (is_complex_dtype) { + fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n, + tau_data + i * min_dim, work, lwork, rwork, info_host.get() + i); + } else { + fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n, + tau_data + i * min_dim, work, lwork, info_host.get() + i); + } + } + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + FFI_RETURN_IF_ERROR_STATUS( + jpvt_host.CopyToDevice(stream, jpvt_out->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + return ffi::Error::Success(); + } + + private: + Fn* fn_ = nullptr; + BlockSizeFn* block_size_fn_ = nullptr; + + absl::StatusOr lwork(int m, int n) { + // `{c,d,s,z}_geqp3_gpu` do not support a workspace query, but we can still + // assign the symbol here. + auto maybe_ptr = FindMagmaSymbol(MagmaGeqp3::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + auto block_size_maybe_ptr = + FindMagmaSymbol(MagmaGeqp3::block_size_name); + if (!block_size_maybe_ptr.ok()) return block_size_maybe_ptr.status(); + block_size_fn_ = reinterpret_cast(*block_size_maybe_ptr); + int optimal_block_size = block_size_fn_(m, n); + if constexpr (ffi::IsComplexType()) { + return (n + 1) * optimal_block_size; + } + return (n + 1) * optimal_block_size + 2 * n; + } +}; + +ffi::Error PivotingQrFactorizationDispatch( + gpuStream_t stream, ffi::ScratchAllocator scratch, std::string_view magma, + ffi::AnyBuffer x, ffi::AnyBuffer jpvt, ffi::Result x_out, + ffi::Result jpvt_out, ffi::Result tau) { + auto dataType = x.element_type(); + if (dataType != x_out->element_type() || dataType != tau->element_type()) { + return ffi::Error::InvalidArgument( + "The buffers 'x', 'x_out' and 'tau' must have the same element type."); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(jpvt.dimensions(), {batch, cols}, "jpvt", "geqp3")); + FFI_RETURN_IF_ERROR( + CheckShape(x_out->dimensions(), {batch, rows, cols}, "x_out", "geqp3")); + FFI_RETURN_IF_ERROR( + CheckShape(jpvt_out->dimensions(), {batch, cols}, "jpvt_out", "geqp3")); + FFI_RETURN_IF_ERROR(CheckShape( + tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqp3")); + + bool use_magma = magma == "on"; + if (magma == "auto" && cols >= 2048) { + use_magma = FindMagmaSymbol("magma_init").ok(); + } + + switch (dataType) { + case ffi::F32: + if (use_magma) { + return PivotingQrFactorizationMagma().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } else { + return PivotingQrFactorizationHost().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } + case ffi::F64: + if (use_magma) { + return PivotingQrFactorizationMagma().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } else { + return PivotingQrFactorizationHost().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } + case ffi::C64: + if (use_magma) { + return PivotingQrFactorizationMagma().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } else { + return PivotingQrFactorizationHost().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } + case ffi::C128: + if (use_magma) { + return PivotingQrFactorizationMagma().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } else { + return PivotingQrFactorizationHost().compute( + batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in geqp3", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kGeqp3, PivotingQrFactorizationDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("magma") + .Arg() // x + .Arg() // jpvt + .Ret() // x_out + .Ret() // jpvt_out + .Ret() // tau +); + // Real-valued eigendecomposition template diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h index 2890837a2..ba5510bb8 100644 --- a/jaxlib/gpu/hybrid_kernels.h +++ b/jaxlib/gpu/hybrid_kernels.h @@ -48,6 +48,7 @@ class MagmaLookup { XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal); XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kGeqp3); } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 1d1701a49..79b5dff6c 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -51,19 +51,6 @@ namespace JAX_GPU_NAMESPACE { namespace ffi = ::xla::ffi; -template -inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, - int64_t size, - std::string_view name) { - auto maybe_workspace = scratch.Allocate(sizeof(T) * size); - if (!maybe_workspace.has_value()) { - return absl::Status( - absl::StatusCode::kResourceExhausted, - absl::StrFormat("Unable to allocate workspace for %s", name)); - } - return static_cast(maybe_workspace.value()); -} - #if JAX_GPU_HAVE_64_BIT // Map an FFI buffer element type to the appropriate GPU solver type. diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 3027b36ff..528bec39a 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -1683,10 +1683,13 @@ class ScipyLinalgTest(jtu.JaxTestCase): mode=["full", "r", "economic"], pivoting=[False, True] ) + @jax.default_matmul_precision("float32") def testScipyQrModes(self, shape, dtype, mode, pivoting): - is_not_cpu_test_device = not jtu.test_device_matches(["cpu"]) - if pivoting and is_not_cpu_test_device: - self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38") + if pivoting: + if not jtu.test_device_matches(["cpu", "gpu"]): + self.skipTest("Pivoting is only supported on CPU and GPU.") + if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0): + self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0") rng = jtu.rand_default(self.rng()) jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting) sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting) @@ -1700,7 +1703,7 @@ class ScipyLinalgTest(jtu.JaxTestCase): def qr_and_mul(a): q, r, *p = jsp_func(a) # To express the identity function we must "undo" the pivoting of `q @ r`. - inverted_pivots = p[0][p[0]] + inverted_pivots = jnp.argsort(p[0]) return (q @ r)[:, inverted_pivots] m, n = shape diff --git a/tests/magma_linalg_test.py b/tests/magma_linalg_test.py index 27ab58aae..2400672d1 100644 --- a/tests/magma_linalg_test.py +++ b/tests/magma_linalg_test.py @@ -114,5 +114,34 @@ class MagmaLinalgTest(jtu.JaxTestCase): hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text() self.assertIn('magma = "on"', hlo) + @jtu.sample_product( + shape=[(3, 4), (3, 3), (4, 3), (4, 3)], + dtype=float_types + complex_types, + ) + @jtu.run_on_devices("gpu") + def testPivotedQrFactorization(self, shape, dtype): + if jtu.jaxlib_version() <= (0, 5, 0): + self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + rng = jtu.rand_default(self.rng()) + + lax_func = partial(lax_linalg.qr, full_matrices=True, pivoting=True, use_magma=True) + sp_func = partial(jax.scipy.linalg.qr, mode="full", pivoting=True) + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(sp_func, lax_func, args_maker, rtol=1E-5, atol=1E-5) + self._CompileAndCheck(lax_func, args_maker) + + def testPivotedQrFactorizationMagmaConfig(self): + if jtu.jaxlib_version() <= (0, 5, 0): + self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0") + if not gpu_solver.has_magma(): + self.skipTest("MAGMA is not installed or can't be loaded.") + rng = jtu.rand_default(self.rng()) + a = rng((5, 5), np.float32) + with config.gpu_use_magma("on"): + hlo = jax.jit(partial(lax_linalg.qr, pivoting=True, use_magma=True)).lower(a).as_text() + self.assertIn('magma = "on"', hlo) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())