Enable pivoted QR on GPU via MAGMA.

Originally noted in #20282, this commit provides a GPU-compatible
implementation of `geqp3` via MAGMA.
tttc3 2025-01-10 15:55:53 +00:00
parent e14466a8fb
commit b1b56ea0b0
10 changed files with 379 additions and 37 deletions

View File

@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
* {func}`jax.lax.linalg.qr` and {func}`jax.scipy.linalg.qr` now support
column pivoting on CPU and GPU. See {jax-issue}`#20282` and
{jax-issue}`#25955` for more details.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as

View File

@ -311,20 +311,22 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
@overload
def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True,
) -> tuple[Array, Array]:
use_magma: bool | None = None) -> tuple[Array, Array]:
...
@overload
def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True,
) -> tuple[Array, Array, Array]:
use_magma: bool | None = None) -> tuple[Array, Array, Array]:
...
@overload
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
use_magma: bool | None = None
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
...
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
use_magma: bool | None = None
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
"""QR decomposition.
@ -341,9 +343,14 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``,
compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P``
is chosen such that the diagonal of ``R`` is non-increasing. Currently
supported on CPU backends only.
supported on CPU and GPU backends only.
full_matrices: Determines if full or reduced matrices are returned; see
below.
use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
pivoted `qr` factorization is computed using MAGMA. If ``False``, the
computation is done using LAPACK on the host CPU. If ``None`` (default),
the behavior is controlled by the ``jax_use_magma`` flag. This argument is
only used on GPU.
Returns:
A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``.
@ -357,8 +364,16 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
``full_matrices=False``.
Array ``p`` is an index vector with shape [..., n]
Notes:
- `MAGMA <https://icl.utk.edu/magma/>`_ support is experimental - see
:func:`jax.lax.linalg.eig` for further assumptions and limitations.
- If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will
be used if the library can be found and the input matrix is sufficiently
large (at least 2048 columns).
"""
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
use_magma=use_magma)
if pivoting:
return q, r, p[0]
return q, r
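
For reference, a minimal usage sketch of the extended API (not part of the diff; the input matrix is illustrative, and `use_magma` only has an effect on GPU):

import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0],
               [1.0, 0.0, 1.0]])
# use_magma=None (the default) defers to the jax_use_magma flag.
q, r, p = lax_linalg.qr(a, pivoting=True, full_matrices=False, use_magma=None)
# p is a 0-based permutation and the diagonal of r is non-increasing.
assert jnp.allclose(a[:, p], q @ r, atol=1e-5)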
@ -1854,22 +1869,28 @@ mlir.register_lowering(
platform='rocm')
def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]:
def geqp3(a: ArrayLike, jpvt: ArrayLike, *,
use_magma: bool | None = None) -> tuple[Array, Array, Array]:
"""Computes the column-pivoted QR decomposition of a matrix.
Args:
a: a ``[..., m, n]`` batch of matrices, with floating-point or complex type.
jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer type,
use_magma: Locally override the ``jax_use_magma`` flag. If ``True``,
`geqp3` is computed using MAGMA. If ``False``, the computation is done using
LAPACK on the host CPU. If ``None`` (default), the behavior is controlled
by the ``jax_use_magma`` flag. This argument is only used on GPU.
Returns:
A ``(a, jpvt, taus)`` triple, where ``r`` is stored in the upper triangle of
``a``, ``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
elementary Householder reflectors, and ``jpvt`` contains the column-pivot
indices such that ``a[:, jpvt] = q @ r``.
"""
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt)
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma)
return a_out, jpvt_out, taus
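
To make the raw outputs concrete, here is a hedged sketch (not from the commit) that rebuilds the factorization for a single tall matrix, assuming `geqp3` and `householder_product` are exposed from `jax.lax.linalg` as in this module:

import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[2.0, 1.0],
               [1.0, 3.0],
               [0.0, 1.0]])                    # m=3, n=2
jpvt = jnp.zeros((2,), jnp.int32)              # zeros mark all columns as free
out, pivots, taus = lax_linalg.geqp3(a, jpvt)
r = jnp.triu(out)[:2, :]                       # R sits in the upper triangle
q = lax_linalg.householder_product(out, taus)  # reduced Q from the reflectors
p = pivots - 1                                 # LAPACK-style 1-based indices
assert jnp.allclose(a[:, p], q @ r, atol=1e-5)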
def _geqp3_abstract_eval(a, jpvt):
def _geqp3_abstract_eval(a, jpvt, *, use_magma):
del use_magma
if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray):
raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: "
f"{a.aval}, {jpvt.aval}")
@ -1882,25 +1903,37 @@ def _geqp3_abstract_eval(a, jpvt):
taus = a.update(shape=(*batch_dims, core.min_dim(m, n)))
return a, jpvt, taus
def _geqp3_batching_rule(batched_args, batch_dims):
def _geqp3_batching_rule(batched_args, batch_dims, *, use_magma):
a, jpvt = batched_args
b_a, b_jpvt = batch_dims
a = batching.moveaxis(a, b_a, 0)
jpvt = batching.moveaxis(jpvt, b_jpvt, 0)
return geqp3(a, jpvt), (0, 0, 0)
return geqp3(a, jpvt, use_magma=use_magma), (0, 0, 0)
def _geqp3_cpu_lowering(ctx, a, jpvt):
def _geqp3_cpu_lowering(ctx, a, jpvt, *, use_magma):
del use_magma
a_aval, _ = ctx.avals_in
target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype)
rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1})
return rule(ctx, a, jpvt)
def _geqp3_gpu_lowering(target_name_prefix, ctx, a, jpvt, *, use_magma):
gpu_solver.initialize_hybrid_kernels()
magma = config.gpu_use_magma.value
target_name = f"{target_name_prefix}hybrid_geqp3"
if use_magma is not None:
magma = "on" if use_magma else "off"
rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1})
return rule(ctx, a, jpvt, magma=magma)
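
The override resolution above amounts to a small truth table; an equivalent sketch in plain Python (flag values are the strings accepted by `jax_use_magma`):

def resolve_magma(use_magma, flag_value):
    # use_magma=True/False wins; None defers to the flag ("on"/"off"/"auto").
    if use_magma is None:
        return flag_value
    return "on" if use_magma else "off"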
geqp3_p = Primitive('geqp3')
geqp3_p.multiple_results = True
geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p))
geqp3_p.def_abstract_eval(_geqp3_abstract_eval)
batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule
mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu")
mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'cu'), platform="cuda")
mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'hip'), platform="rocm")
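
With these registrations in place, the selected FFI target is visible in the lowered HLO. A hedged sketch (exact custom-call names depend on backend and dtype):

import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.zeros((4, 3), jnp.float32)
jpvt = jnp.zeros((3,), jnp.int32)
hlo = jax.jit(lax_linalg.geqp3).lower(a, jpvt).as_text()
# Expect a custom call such as "lapack_sgeqp3_ffi" on CPU, or
# "cuhybrid_geqp3" with a magma attribute on CUDA.
print(hlo)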
# householder_product: product of elementary Householder reflectors
@ -1988,12 +2021,13 @@ mlir.register_lowering(
platform='rocm')
def _qr_impl(operand, *, pivoting, full_matrices):
def _qr_impl(operand, *, pivoting, full_matrices, use_magma):
q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting,
full_matrices=full_matrices)
full_matrices=full_matrices, use_magma=use_magma)
return (q, r, p[0]) if pivoting else (q, r)
def _qr_abstract_eval(operand, *, pivoting, full_matrices):
def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
del use_magma
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
@ -2018,11 +2052,11 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
q, r, p = operand, operand, operand
return (q, r, p) if pivoting else (q, r)
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma):
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
x, = primals
dx, = tangents
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False, use_magma=use_magma)
*_, m, n = x.shape
if m < n or (full_matrices and m != n):
raise NotImplementedError(
@ -2043,14 +2077,16 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
return (q, r, p[0]), (dq, dr, dp)
return (q, r), (dq, dr)
def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices):
def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices,
use_magma):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
out_axes = (0, 0, 0) if pivoting else (0, 0)
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
use_magma=use_magma), out_axes
def _qr_lowering(a, *, pivoting, full_matrices):
def _qr_lowering(a, *, pivoting, full_matrices, use_magma):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
k = m if full_matrices else core.min_dim(m, n)
@ -2065,7 +2101,7 @@ def _qr_lowering(a, *, pivoting, full_matrices):
if pivoting:
jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
r, p, taus = geqp3(a, jpvt)
r, p, taus = geqp3(a, jpvt, use_magma=use_magma)
p -= 1 # Convert geqp3's 1-based indices to 0-based indices by subtracting 1.
else:
r, taus = geqrf(a)

View File

@ -953,7 +953,9 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "
with ``K = min(M, N)``.
Notes:
- At present, pivoting is only implemented on CPU backends.
- At present, pivoting is only implemented on the CPU and GPU backends. For further
details about the GPU implementation, see the documentation for
:func:`jax.lax.linalg.qr`.
See also:
- :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API
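
A matching usage sketch for this SciPy-style wrapper (illustrative input; not part of the diff):

import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])
q, r, p = jsp_linalg.qr(a, mode="economic", pivoting=True)
assert jnp.allclose(a[:, p], q @ r, atol=1e-5)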

View File

@ -184,6 +184,19 @@ auto AllocateScratchMemory(std::size_t size)
return std::unique_ptr<ValueType[]>(new ValueType[size]);
}
template <typename T>
inline absl::StatusOr<T*> AllocateWorkspace(
::xla::ffi::ScratchAllocator& scratch, int64_t size,
std::string_view name) {
auto maybe_workspace = scratch.Allocate(sizeof(T) * size);
if (!maybe_workspace.has_value()) {
return absl::Status(
absl::StatusCode::kResourceExhausted,
absl::StrFormat("Unable to allocate workspace for %s", name));
}
return static_cast<T*>(maybe_workspace.value());
}
} // namespace jax
#endif // JAXLIB_FFI_HELPERS_H_

View File

@ -47,6 +47,10 @@ void GetLapackKernelsFromScipy() {
lapack_ptr("cgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<ffi::C128>>(
lapack_ptr("zgeev"));
AssignKernelFn<PivotingQrFactorization<ffi::F32>>(lapack_ptr("sgeqp3"));
AssignKernelFn<PivotingQrFactorization<ffi::F64>>(lapack_ptr("dgeqp3"));
AssignKernelFn<PivotingQrFactorization<ffi::C64>>(lapack_ptr("cgeqp3"));
AssignKernelFn<PivotingQrFactorization<ffi::C128>>(lapack_ptr("zgeqp3"));
});
}
@ -57,6 +61,7 @@ NB_MODULE(_hybrid, m) {
nb::dict dict;
dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal);
dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp);
dict[JAX_GPU_PREFIX "hybrid_geqp3"] = EncapsulateFfiHandler(kGeqp3);
return dict;
});
}

View File

@ -103,6 +103,30 @@ template <>
struct MagmaGeev<ffi::C128> {
static constexpr char name[] = "magma_zgeev";
};
template <ffi::DataType DataType>
struct MagmaGeqp3 {
static_assert(always_false<DataType>::value, "unsupported data type");
};
template <>
struct MagmaGeqp3<ffi::F32> {
static constexpr char name[] = "magma_sgeqp3_gpu";
static constexpr char block_size_name[] = "magma_get_sgeqp3_nb";
};
template <>
struct MagmaGeqp3<ffi::F64> {
static constexpr char name[] = "magma_dgeqp3_gpu";
static constexpr char block_size_name[] = "magma_get_dgeqp3_nb";
};
template <>
struct MagmaGeqp3<ffi::C64> {
static constexpr char name[] = "magma_cgeqp3_gpu";
static constexpr char block_size_name[] = "magma_get_cgeqp3_nb";
};
template <>
struct MagmaGeqp3<ffi::C128> {
static constexpr char name[] = "magma_zgeqp3_gpu";
static constexpr char block_size_name[] = "magma_get_zgeqp3_nb";
};
MagmaLookup::~MagmaLookup() {
if (initialized_) {
@ -205,6 +229,245 @@ absl::StatusOr<void*> FindMagmaSymbol(const char name[]) {
return lookup.Find(name);
}
// Column Pivoting QR Factorization
// magma geqp3_gpu
template <ffi::DataType DataType>
class PivotingQrFactorizationHost {
using RealType = ffi::NativeType<ffi::ToReal(DataType)>;
using ValueType = ffi::NativeType<DataType>;
public:
explicit PivotingQrFactorizationHost() = default;
PivotingQrFactorizationHost(PivotingQrFactorizationHost&&) = default;
ffi::Error compute(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
ffi::AnyBuffer x, ffi::AnyBuffer jpvt,
ffi::Result<ffi::AnyBuffer> x_out,
ffi::Result<ffi::AnyBuffer> jpvt_out,
ffi::Result<ffi::AnyBuffer> tau) {
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
auto min_dim = std::min(m, n);
FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n));
auto work = AllocateScratchMemory<DataType>(lwork);
constexpr bool is_complex_dtype = ffi::IsComplexType<DataType>();
std::unique_ptr<RealType[]> rwork;
if constexpr (is_complex_dtype) {
rwork = AllocateScratchMemory<ffi::ToReal(DataType)>(2 * n);
}
auto x_host = HostBuffer<ValueType>(x.element_count());
FFI_RETURN_IF_ERROR_STATUS(
x_host.CopyFromDevice(stream, x.typed_data<ValueType>()));
auto jpvt_host = HostBuffer<int>(jpvt.element_count());
FFI_RETURN_IF_ERROR_STATUS(
jpvt_host.CopyFromDevice(stream, jpvt.typed_data<int>()));
auto tau_host = HostBuffer<ValueType>(batch * min_dim);
auto info_host = HostBuffer<int>(batch);
for (int64_t i = 0; i < batch; ++i) {
if constexpr (is_complex_dtype) {
PivotingQrFactorization<DataType>::fn(
&m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n,
tau_host.get() + i * min_dim, work.get(), &lwork, rwork.get(),
info_host.get() + i);
} else {
PivotingQrFactorization<DataType>::fn(
&m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n,
tau_host.get() + i * min_dim, work.get(), &lwork,
info_host.get() + i);
}
}
FFI_RETURN_IF_ERROR_STATUS(
x_host.CopyToDevice(stream, x_out->typed_data<ValueType>()));
FFI_RETURN_IF_ERROR_STATUS(
jpvt_host.CopyToDevice(stream, jpvt_out->typed_data<int>()));
FFI_RETURN_IF_ERROR_STATUS(
tau_host.CopyToDevice(stream, tau->typed_data<ValueType>()));
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
return ffi::Error::Success();
}
private:
absl::StatusOr<int> lwork(int m, int n) {
int64_t lwork = PivotingQrFactorization<DataType>::GetWorkspaceSize(m, n);
return MaybeCastNoOverflow<int>(lwork);
}
};
template <ffi::DataType DataType>
class PivotingQrFactorizationMagma {
using RealType = ffi::NativeType<ffi::ToReal(DataType)>;
using ValueType = ffi::NativeType<DataType>;
using Fn = std::conditional_t<
ffi::IsComplexType<DataType>(),
int(int m, int n, ValueType* dA, int ldda, int* jpvt, ValueType* tau,
ValueType* dwork, int lwork, RealType* rwork, int* info),
int(int m, int n, RealType* dA, int ldda, int* jpvt, RealType* tau,
RealType* dwork, int lwork, int* info)>;
using BlockSizeFn = int(int m, int n);
public:
explicit PivotingQrFactorizationMagma() = default;
PivotingQrFactorizationMagma(PivotingQrFactorizationMagma&&) = default;
ffi::Error compute(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
ffi::AnyBuffer x, ffi::AnyBuffer jpvt,
ffi::Result<ffi::AnyBuffer> x_out,
ffi::Result<ffi::AnyBuffer> jpvt_out,
ffi::Result<ffi::AnyBuffer> tau) {
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
auto min_dim = std::min(m, n);
FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n));
FFI_ASSIGN_OR_RETURN(auto work,
AllocateWorkspace<ValueType>(scratch, lwork, "geqp3"));
constexpr bool is_complex_dtype = ffi::IsComplexType<DataType>();
RealType* rwork;
if constexpr (is_complex_dtype) {
FFI_ASSIGN_OR_RETURN(
rwork, AllocateWorkspace<RealType>(scratch, 2 * n, "geqp3"));
}
auto x_data = x.typed_data<ValueType>();
auto x_out_data = x_out->typed_data<ValueType>();
auto tau_data = tau->typed_data<ValueType>();
if (x_data != x_out_data) {
FFI_RETURN_IF_ERROR_STATUS(
JAX_AS_STATUS(gpuMemcpyAsync(x_out_data, x_data, x.size_bytes(),
gpuMemcpyDeviceToDevice, stream)));
}
auto jpvt_host = HostBuffer<int>(jpvt.element_count());
FFI_RETURN_IF_ERROR_STATUS(
jpvt_host.CopyFromDevice(stream, jpvt.typed_data<int>()));
auto info_host = HostBuffer<int>(batch);
// TODO: Check whether these synchronizations are needed because MAGMA's
// geqp3 kernels are not stream-safe.
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
for (int64_t i = 0; i < batch; ++i) {
if constexpr (is_complex_dtype) {
fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n,
tau_data + i * min_dim, work, lwork, rwork, info_host.get() + i);
} else {
fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n,
tau_data + i * min_dim, work, lwork, info_host.get() + i);
}
}
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
FFI_RETURN_IF_ERROR_STATUS(
jpvt_host.CopyToDevice(stream, jpvt_out->typed_data<int>()));
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
return ffi::Error::Success();
}
private:
Fn* fn_ = nullptr;
BlockSizeFn* block_size_fn_ = nullptr;
absl::StatusOr<int> lwork(int m, int n) {
// The MAGMA `{s,d,c,z}geqp3_gpu` routines do not support a workspace query,
// but we can still resolve and assign the symbol here.
auto maybe_ptr = FindMagmaSymbol(MagmaGeqp3<DataType>::name);
if (!maybe_ptr.ok()) return maybe_ptr.status();
fn_ = reinterpret_cast<Fn*>(*maybe_ptr);
auto block_size_maybe_ptr =
FindMagmaSymbol(MagmaGeqp3<DataType>::block_size_name);
if (!block_size_maybe_ptr.ok()) return block_size_maybe_ptr.status();
block_size_fn_ = reinterpret_cast<BlockSizeFn*>(*block_size_maybe_ptr);
int optimal_block_size = block_size_fn_(m, n);
if constexpr (ffi::IsComplexType<DataType>()) {
return (n + 1) * optimal_block_size;
}
return (n + 1) * optimal_block_size + 2 * n;
}
};
ffi::Error PivotingQrFactorizationDispatch(
gpuStream_t stream, ffi::ScratchAllocator scratch, std::string_view magma,
ffi::AnyBuffer x, ffi::AnyBuffer jpvt, ffi::Result<ffi::AnyBuffer> x_out,
ffi::Result<ffi::AnyBuffer> jpvt_out, ffi::Result<ffi::AnyBuffer> tau) {
auto dataType = x.element_type();
if (dataType != x_out->element_type() || dataType != tau->element_type()) {
return ffi::Error::InvalidArgument(
"The buffers 'x', 'x_out' and 'tau' must have the same element type.");
}
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
SplitBatch2D(x.dimensions()));
FFI_RETURN_IF_ERROR(
CheckShape(jpvt.dimensions(), {batch, cols}, "jpvt", "geqp3"));
FFI_RETURN_IF_ERROR(
CheckShape(x_out->dimensions(), {batch, rows, cols}, "x_out", "geqp3"));
FFI_RETURN_IF_ERROR(
CheckShape(jpvt_out->dimensions(), {batch, cols}, "jpvt_out", "geqp3"));
FFI_RETURN_IF_ERROR(CheckShape(
tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqp3"));
bool use_magma = magma == "on";
if (magma == "auto" && cols >= 2048) {
use_magma = FindMagmaSymbol("magma_init").ok();
}
switch (dataType) {
case ffi::F32:
if (use_magma) {
return PivotingQrFactorizationMagma<ffi::F32>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
} else {
return PivotingQrFactorizationHost<ffi::F32>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
}
case ffi::F64:
if (use_magma) {
return PivotingQrFactorizationMagma<ffi::F64>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
} else {
return PivotingQrFactorizationHost<ffi::F64>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
}
case ffi::C64:
if (use_magma) {
return PivotingQrFactorizationMagma<ffi::C64>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
} else {
return PivotingQrFactorizationHost<ffi::C64>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
}
case ffi::C128:
if (use_magma) {
return PivotingQrFactorizationMagma<ffi::C128>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
} else {
return PivotingQrFactorizationHost<ffi::C128>().compute(
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
}
default:
return ffi::Error::InvalidArgument(absl::StrFormat(
"Unsupported dtype %s in geqp3", absl::FormatStreamed(dataType)));
}
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(kGeqp3, PivotingQrFactorizationDispatch,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<gpuStream_t>>()
.Ctx<ffi::ScratchAllocator>()
.Attr<std::string_view>("magma")
.Arg<ffi::AnyBuffer>() // x
.Arg<ffi::AnyBuffer>() // jpvt
.Ret<ffi::AnyBuffer>() // x_out
.Ret<ffi::AnyBuffer>() // jpvt_out
.Ret<ffi::AnyBuffer>() // tau
);
// Real-valued eigendecomposition
template <ffi::DataType DataType>

View File

@ -48,6 +48,7 @@ class MagmaLookup {
XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal);
XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp);
XLA_FFI_DECLARE_HANDLER_SYMBOL(kGeqp3);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -51,19 +51,6 @@ namespace JAX_GPU_NAMESPACE {
namespace ffi = ::xla::ffi;
template <typename T>
inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
int64_t size,
std::string_view name) {
auto maybe_workspace = scratch.Allocate(sizeof(T) * size);
if (!maybe_workspace.has_value()) {
return absl::Status(
absl::StatusCode::kResourceExhausted,
absl::StrFormat("Unable to allocate workspace for %s", name));
}
return static_cast<T*>(maybe_workspace.value());
}
#if JAX_GPU_HAVE_64_BIT
// Map an FFI buffer element type to the appropriate GPU solver type.

View File

@ -1683,10 +1683,13 @@ class ScipyLinalgTest(jtu.JaxTestCase):
mode=["full", "r", "economic"],
pivoting=[False, True]
)
@jax.default_matmul_precision("float32")
def testScipyQrModes(self, shape, dtype, mode, pivoting):
is_not_cpu_test_device = not jtu.test_device_matches(["cpu"])
if pivoting and is_not_cpu_test_device:
self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38")
if pivoting:
if not jtu.test_device_matches(["cpu", "gpu"]):
self.skipTest("Pivoting is only supported on CPU and GPU.")
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0")
rng = jtu.rand_default(self.rng())
jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)
sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting)
@ -1700,7 +1703,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def qr_and_mul(a):
q, r, *p = jsp_func(a)
# To express the identity function we must "undo" the pivoting of `q @ r`.
inverted_pivots = p[0][p[0]]
inverted_pivots = jnp.argsort(p[0])
return (q @ r)[:, inverted_pivots]
m, n = shape
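
The switch to `jnp.argsort` here matters: `p[0][p[0]]` composes the permutation with itself, which only equals the inverse when the permutation is self-inverse, whereas `argsort` always yields the true inverse. A tiny illustration (assumed values):

import jax.numpy as jnp

p = jnp.array([1, 2, 3, 0])     # a 4-cycle, so p[p] is not the inverse
inv = jnp.argsort(p)            # true inverse permutation: [3, 0, 1, 2]
x = jnp.array([10., 20., 30., 40.])
assert jnp.array_equal(x[p][inv], x)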

View File

@ -114,5 +114,34 @@ class MagmaLinalgTest(jtu.JaxTestCase):
hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text()
self.assertIn('magma = "on"', hlo)
@jtu.sample_product(
shape=[(3, 4), (3, 3), (4, 3)],
dtype=float_types + complex_types,
)
@jtu.run_on_devices("gpu")
def testPivotedQrFactorization(self, shape, dtype):
if jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng())
lax_func = partial(lax_linalg.qr, full_matrices=True, pivoting=True, use_magma=True)
sp_func = partial(jax.scipy.linalg.qr, mode="full", pivoting=True)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(sp_func, lax_func, args_maker, rtol=1E-5, atol=1E-5)
self._CompileAndCheck(lax_func, args_maker)
def testPivotedQrFactorizationMagmaConfig(self):
if jtu.jaxlib_version() <= (0, 5, 0):
self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
if not gpu_solver.has_magma():
self.skipTest("MAGMA is not installed or can't be loaded.")
rng = jtu.rand_default(self.rng())
a = rng((5, 5), np.float32)
with config.gpu_use_magma("on"):
hlo = jax.jit(partial(lax_linalg.qr, pivoting=True, use_magma=True)).lower(a).as_text()
self.assertIn('magma = "on"', hlo)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())