mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU-compatible implementation of `geqp3` via MAGMA.
This commit is contained in:
parent
e14466a8fb
commit
b1b56ea0b0
@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
|
||||
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
|
||||
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
|
||||
* {func}`jax.lax.linalg.qr`, and {func}`jax.scipy.linalg.qr`, now support
|
||||
column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
|
||||
{jax-issue}`#25955` for more details.
|
||||
|
||||
* Changes
|
||||
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
|
||||
|
@ -311,20 +311,22 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:
|
||||
|
||||
@overload
|
||||
def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True,
|
||||
) -> tuple[Array, Array]:
|
||||
use_magma: bool | None = None) -> tuple[Array, Array]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True,
|
||||
) -> tuple[Array, Array, Array]:
|
||||
use_magma: bool | None = None) -> tuple[Array, Array, Array]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
|
||||
use_magma: bool | None = None
|
||||
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
|
||||
...
|
||||
|
||||
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
|
||||
use_magma: bool | None = None
|
||||
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
|
||||
"""QR decomposition.
|
||||
|
||||
@ -341,9 +343,14 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
|
||||
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``,
|
||||
compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P``
|
||||
is chosen such that the diagonal of ``R`` is non-increasing. Currently
|
||||
supported on CPU backends only.
|
||||
supported on CPU and GPU backends only.
|
||||
full_matrices: Determines if full or reduced matrices are returned; see
|
||||
below.
|
||||
use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
|
||||
pivoted `qr` factorization is computed using MAGMA. If ``False``, the
|
||||
computation is done using LAPACK on the host CPU. If ``None`` (default),
|
||||
the behavior is controlled by the ``jax_use_magma`` flag. This argument is
|
||||
only used on GPU.
|
||||
|
||||
Returns:
|
||||
A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``.
|
||||
@ -357,8 +364,16 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
|
||||
``full_matrices=False``.
|
||||
|
||||
Array ``p`` is an index vector with shape [..., n]
|
||||
|
||||
Notes:
|
||||
- `MAGMA <https://icl.utk.edu/magma/>`_ support is experimental - see
|
||||
:func:`jax.lax.linalg.eig` for further assumptions and limitations.
|
||||
- If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will
|
||||
be used if the library can be found, and the input matrix is sufficiently
|
||||
large (has at least 2048 columns).
|
||||
"""
|
||||
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices)
|
||||
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
|
||||
use_magma=use_magma)
|
||||
if pivoting:
|
||||
return q, r, p[0]
|
||||
return q, r
|
||||
@ -1854,22 +1869,28 @@ mlir.register_lowering(
|
||||
platform='rocm')
|
||||
|
||||
|
||||
def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]:
|
||||
def geqp3(a: ArrayLike, jpvt: ArrayLike, *,
|
||||
use_magma: bool | None = None) -> tuple[Array, Array, Array]:
|
||||
"""Computes the column-pivoted QR decomposition of a matrix.
|
||||
|
||||
Args:
|
||||
a: a ``[..., m, n]`` batch of matrices, with floating-point or complex type.
|
||||
jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer type.
|
||||
use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
|
||||
`geqp3` factorization is computed using MAGMA. If ``False``, the computation is done using
|
||||
LAPACK on the host CPU. If ``None`` (default), the behavior is controlled
|
||||
by the ``jax_use_magma`` flag. This argument is only used on GPU.
|
||||
Returns:
|
||||
A ``(a, jpvt, taus)`` triple, where ``r`` is in the upper triangle of ``a``,
|
||||
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
|
||||
elementary Householder reflectors, and ``jpvt`` is the column-pivot indices
|
||||
such that ``a[:, jpvt] = q @ r``.
|
||||
"""
|
||||
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt)
|
||||
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma)
|
||||
return a_out, jpvt_out, taus
|
||||
|
||||
def _geqp3_abstract_eval(a, jpvt):
|
||||
def _geqp3_abstract_eval(a, jpvt, *, use_magma):
|
||||
del use_magma
|
||||
if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray):
|
||||
raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: "
|
||||
f"{a.aval}, {jpvt.aval}")
|
||||
@ -1882,25 +1903,37 @@ def _geqp3_abstract_eval(a, jpvt):
|
||||
taus = a.update(shape=(*batch_dims, core.min_dim(m, n)))
|
||||
return a, jpvt, taus
|
||||
|
||||
def _geqp3_batching_rule(batched_args, batch_dims):
|
||||
def _geqp3_batching_rule(batched_args, batch_dims, *, use_magma):
|
||||
a, jpvt = batched_args
|
||||
b_a, b_jpvt = batch_dims
|
||||
a = batching.moveaxis(a, b_a, 0)
|
||||
jpvt = batching.moveaxis(jpvt, b_jpvt, 0)
|
||||
return geqp3(a, jpvt), (0, 0, 0)
|
||||
return geqp3(a, jpvt, use_magma=use_magma), (0, 0, 0)
|
||||
|
||||
def _geqp3_cpu_lowering(ctx, a, jpvt):
|
||||
def _geqp3_cpu_lowering(ctx, a, jpvt, *, use_magma):
|
||||
del use_magma
|
||||
a_aval, _ = ctx.avals_in
|
||||
target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype)
|
||||
rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1})
|
||||
return rule(ctx, a, jpvt)
|
||||
|
||||
def _geqp3_gpu_lowering(target_name_prefix, ctx, a, jpvt, *, use_magma):
  """Lowers ``geqp3`` on GPU to the ``<prefix>hybrid_geqp3`` FFI target.

  Args:
    target_name_prefix: platform prefix prepended to ``hybrid_geqp3`` to form
      the FFI target name (registered below with ``'cu'`` and ``'hip'``).
    ctx: the lowering rule context.
    a: the matrix operand.
    jpvt: the column-pivot operand.
    use_magma: ``True``/``False`` force the ``magma`` attribute to ``"on"`` /
      ``"off"``; ``None`` forwards the value of the ``gpu_use_magma`` config.

  Returns:
    Whatever the FFI lowering rule returns for ``(a, jpvt)``.
  """
  gpu_solver.initialize_hybrid_kernels()
  configured_mode = config.gpu_use_magma.value
  if use_magma is None:
    magma_mode = configured_mode
  elif use_magma:
    magma_mode = "on"
  else:
    magma_mode = "off"
  # Operands 0 and 1 are aliased to outputs 0 and 1 respectively.
  lower = _linalg_ffi_lowering(
      f"{target_name_prefix}hybrid_geqp3",
      operand_output_aliases={0: 0, 1: 1})
  return lower(ctx, a, jpvt, magma=magma_mode)
|
||||
|
||||
geqp3_p = Primitive('geqp3')
|
||||
geqp3_p.multiple_results = True
|
||||
geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p))
|
||||
geqp3_p.def_abstract_eval(_geqp3_abstract_eval)
|
||||
batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule
|
||||
mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu")
|
||||
mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'cu'), platform="cuda")
|
||||
mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'hip'), platform="rocm")
|
||||
|
||||
# householder_product: product of elementary Householder reflectors
|
||||
|
||||
@ -1988,12 +2021,13 @@ mlir.register_lowering(
|
||||
platform='rocm')
|
||||
|
||||
|
||||
def _qr_impl(operand, *, pivoting, full_matrices):
|
||||
def _qr_impl(operand, *, pivoting, full_matrices, use_magma):
|
||||
q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting,
|
||||
full_matrices=full_matrices)
|
||||
full_matrices=full_matrices, use_magma=use_magma)
|
||||
return (q, r, p[0]) if pivoting else (q, r)
|
||||
|
||||
def _qr_abstract_eval(operand, *, pivoting, full_matrices):
|
||||
def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
|
||||
del use_magma
|
||||
if isinstance(operand, ShapedArray):
|
||||
if operand.ndim < 2:
|
||||
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
||||
@ -2018,11 +2052,11 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
|
||||
q, r, p = operand, operand, operand
|
||||
return (q, r, p) if pivoting else (q, r)
|
||||
|
||||
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
|
||||
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma):
|
||||
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
|
||||
x, = primals
|
||||
dx, = tangents
|
||||
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False)
|
||||
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False, use_magma=use_magma)
|
||||
*_, m, n = x.shape
|
||||
if m < n or (full_matrices and m != n):
|
||||
raise NotImplementedError(
|
||||
@ -2043,14 +2077,16 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
|
||||
return (q, r, p[0]), (dq, dr, dp)
|
||||
return (q, r), (dq, dr)
|
||||
|
||||
def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices):
|
||||
def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices,
|
||||
use_magma):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
out_axes = (0, 0, 0) if pivoting else (0, 0)
|
||||
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes
|
||||
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
|
||||
use_magma=use_magma), out_axes
|
||||
|
||||
def _qr_lowering(a, *, pivoting, full_matrices):
|
||||
def _qr_lowering(a, *, pivoting, full_matrices, use_magma):
|
||||
*batch_dims, m, n = a.shape
|
||||
if m == 0 or n == 0:
|
||||
k = m if full_matrices else core.min_dim(m, n)
|
||||
@ -2065,7 +2101,7 @@ def _qr_lowering(a, *, pivoting, full_matrices):
|
||||
|
||||
if pivoting:
|
||||
jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
|
||||
r, p, taus = geqp3(a, jpvt)
|
||||
r, p, taus = geqp3(a, jpvt, use_magma=use_magma)
|
||||
p -= 1 # Convert geqp3's 1-based indices to 0-based indices by subtracting 1.
|
||||
else:
|
||||
r, taus = geqrf(a)
|
||||
|
@ -953,7 +953,9 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "
|
||||
with ``K = min(M, N)``.
|
||||
|
||||
Notes:
|
||||
- At present, pivoting is only implemented on CPU backends.
|
||||
- At present, pivoting is only implemented on the CPU and GPU backends. For further
|
||||
details about the GPU implementation, see the documentation for
|
||||
:func:`jax.lax.linalg.qr`.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API
|
||||
|
@ -184,6 +184,19 @@ auto AllocateScratchMemory(std::size_t size)
|
||||
return std::unique_ptr<ValueType[]>(new ValueType[size]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline absl::StatusOr<T*> AllocateWorkspace(
|
||||
::xla::ffi::ScratchAllocator& scratch, int64_t size,
|
||||
std::string_view name) {
|
||||
auto maybe_workspace = scratch.Allocate(sizeof(T) * size);
|
||||
if (!maybe_workspace.has_value()) {
|
||||
return absl::Status(
|
||||
absl::StatusCode::kResourceExhausted,
|
||||
absl::StrFormat("Unable to allocate workspace for %s", name));
|
||||
}
|
||||
return static_cast<T*>(maybe_workspace.value());
|
||||
}
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_FFI_HELPERS_H_
|
||||
|
@ -47,6 +47,10 @@ void GetLapackKernelsFromScipy() {
|
||||
lapack_ptr("cgeev"));
|
||||
AssignKernelFn<EigenvalueDecompositionComplex<ffi::C128>>(
|
||||
lapack_ptr("zgeev"));
|
||||
AssignKernelFn<PivotingQrFactorization<ffi::F32>>(lapack_ptr("sgeqp3"));
|
||||
AssignKernelFn<PivotingQrFactorization<ffi::F64>>(lapack_ptr("dgeqp3"));
|
||||
AssignKernelFn<PivotingQrFactorization<ffi::C64>>(lapack_ptr("cgeqp3"));
|
||||
AssignKernelFn<PivotingQrFactorization<ffi::C128>>(lapack_ptr("zgeqp3"));
|
||||
});
|
||||
}
|
||||
|
||||
@ -57,6 +61,7 @@ NB_MODULE(_hybrid, m) {
|
||||
nb::dict dict;
|
||||
dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal);
|
||||
dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp);
|
||||
dict[JAX_GPU_PREFIX "hybrid_geqp3"] = EncapsulateFfiHandler(kGeqp3);
|
||||
return dict;
|
||||
});
|
||||
}
|
||||
|
@ -103,6 +103,30 @@ template <>
|
||||
struct MagmaGeev<ffi::C128> {
|
||||
static constexpr char name[] = "magma_zgeev";
|
||||
};
|
||||
// Compile-time table of MAGMA symbol names for column-pivoted QR, keyed by
// element type. `name` is the factorization routine and `block_size_name` is
// the routine queried for the block size when sizing the workspace. The
// primary template rejects any dtype without an explicit specialization.
template <ffi::DataType DataType>
|
||||
struct MagmaGeqp3 {
|
||||
static_assert(always_false<DataType>::value, "unsupported data type");
|
||||
};
|
||||
template <>
|
||||
struct MagmaGeqp3<ffi::F32> {
|
||||
static constexpr char name[] = "magma_sgeqp3_gpu";
|
||||
static constexpr char block_size_name[] = "magma_get_sgeqp3_nb";
|
||||
};
|
||||
template <>
|
||||
struct MagmaGeqp3<ffi::F64> {
|
||||
static constexpr char name[] = "magma_dgeqp3_gpu";
|
||||
static constexpr char block_size_name[] = "magma_get_dgeqp3_nb";
|
||||
};
|
||||
template <>
|
||||
struct MagmaGeqp3<ffi::C64> {
|
||||
static constexpr char name[] = "magma_cgeqp3_gpu";
|
||||
static constexpr char block_size_name[] = "magma_get_cgeqp3_nb";
|
||||
};
|
||||
template <>
|
||||
struct MagmaGeqp3<ffi::C128> {
|
||||
static constexpr char name[] = "magma_zgeqp3_gpu";
|
||||
static constexpr char block_size_name[] = "magma_get_zgeqp3_nb";
|
||||
};
|
||||
|
||||
MagmaLookup::~MagmaLookup() {
|
||||
if (initialized_) {
|
||||
@ -205,6 +229,245 @@ absl::StatusOr<void*> FindMagmaSymbol(const char name[]) {
|
||||
return lookup.Find(name);
|
||||
}
|
||||
|
||||
// Column Pivoting QR Factorization
|
||||
|
||||
// magma geqp3_gpu
|
||||
|
||||
template <ffi::DataType DataType>
|
||||
class PivotingQrFactorizationHost {
|
||||
using RealType = ffi::NativeType<ffi::ToReal(DataType)>;
|
||||
using ValueType = ffi::NativeType<DataType>;
|
||||
|
||||
public:
|
||||
explicit PivotingQrFactorizationHost() = default;
|
||||
PivotingQrFactorizationHost(PivotingQrFactorizationHost&&) = default;
|
||||
|
||||
ffi::Error compute(int64_t batch, int64_t rows, int64_t cols,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
ffi::AnyBuffer x, ffi::AnyBuffer jpvt,
|
||||
ffi::Result<ffi::AnyBuffer> x_out,
|
||||
ffi::Result<ffi::AnyBuffer> jpvt_out,
|
||||
ffi::Result<ffi::AnyBuffer> tau) {
|
||||
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
auto min_dim = std::min(m, n);
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n));
|
||||
auto work = AllocateScratchMemory<DataType>(lwork);
|
||||
|
||||
constexpr bool is_complex_dtype = ffi::IsComplexType<DataType>();
|
||||
std::unique_ptr<RealType[]> rwork;
|
||||
if constexpr (is_complex_dtype) {
|
||||
rwork = AllocateScratchMemory<ffi::ToReal(DataType)>(2 * n);
|
||||
}
|
||||
|
||||
auto x_host = HostBuffer<ValueType>(x.element_count());
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
x_host.CopyFromDevice(stream, x.typed_data<ValueType>()));
|
||||
auto jpvt_host = HostBuffer<int>(jpvt.element_count());
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
jpvt_host.CopyFromDevice(stream, jpvt.typed_data<int>()));
|
||||
auto tau_host = HostBuffer<ValueType>(batch * min_dim);
|
||||
auto info_host = HostBuffer<int>(batch);
|
||||
|
||||
for (int64_t i = 0; i < batch; ++i) {
|
||||
if constexpr (is_complex_dtype) {
|
||||
PivotingQrFactorization<DataType>::fn(
|
||||
&m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n,
|
||||
tau_host.get() + i * min_dim, work.get(), &lwork, rwork.get(),
|
||||
info_host.get() + i);
|
||||
} else {
|
||||
PivotingQrFactorization<DataType>::fn(
|
||||
&m, &n, x_host.get() + i * m * n, &m, jpvt_host.get() + i * n,
|
||||
tau_host.get() + i * min_dim, work.get(), &lwork,
|
||||
info_host.get() + i);
|
||||
}
|
||||
}
|
||||
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
x_host.CopyToDevice(stream, x_out->typed_data<ValueType>()));
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
jpvt_host.CopyToDevice(stream, jpvt_out->typed_data<int>()));
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
tau_host.CopyToDevice(stream, tau->typed_data<ValueType>()));
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<int> lwork(int m, int n) {
|
||||
int64_t lwork = PivotingQrFactorization<DataType>::GetWorkspaceSize(m, n);
|
||||
return MaybeCastNoOverflow<int>(lwork);
|
||||
}
|
||||
};
|
||||
|
||||
template <ffi::DataType DataType>
|
||||
class PivotingQrFactorizationMagma {
|
||||
using RealType = ffi::NativeType<ffi::ToReal(DataType)>;
|
||||
using ValueType = ffi::NativeType<DataType>;
|
||||
using Fn = std::conditional_t<
|
||||
ffi::IsComplexType<DataType>(),
|
||||
int(int m, int n, ValueType* dA, int ldda, int* jpvt, ValueType* tau,
|
||||
ValueType* dwork, int lwork, RealType* rwork, int* info),
|
||||
int(int m, int n, RealType* dA, int ldda, int* jpvt, RealType* tau,
|
||||
RealType* dwork, int lwork, int* info)>;
|
||||
using BlockSizeFn = int(int m, int n);
|
||||
|
||||
public:
|
||||
explicit PivotingQrFactorizationMagma() = default;
|
||||
PivotingQrFactorizationMagma(PivotingQrFactorizationMagma&&) = default;
|
||||
|
||||
ffi::Error compute(int64_t batch, int64_t rows, int64_t cols,
|
||||
gpuStream_t stream, ffi::ScratchAllocator& scratch,
|
||||
ffi::AnyBuffer x, ffi::AnyBuffer jpvt,
|
||||
ffi::Result<ffi::AnyBuffer> x_out,
|
||||
ffi::Result<ffi::AnyBuffer> jpvt_out,
|
||||
ffi::Result<ffi::AnyBuffer> tau) {
|
||||
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
|
||||
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
|
||||
auto min_dim = std::min(m, n);
|
||||
|
||||
FFI_ASSIGN_OR_RETURN(int lwork, lwork(m, n));
|
||||
FFI_ASSIGN_OR_RETURN(auto work,
|
||||
AllocateWorkspace<ValueType>(scratch, lwork, "geqp3"));
|
||||
|
||||
constexpr bool is_complex_dtype = ffi::IsComplexType<DataType>();
|
||||
RealType* rwork;
|
||||
if constexpr (is_complex_dtype) {
|
||||
FFI_ASSIGN_OR_RETURN(
|
||||
rwork, AllocateWorkspace<RealType>(scratch, 2 * n, "geqp3"));
|
||||
}
|
||||
|
||||
auto x_data = x.typed_data<ValueType>();
|
||||
auto x_out_data = x_out->typed_data<ValueType>();
|
||||
auto tau_data = tau->typed_data<ValueType>();
|
||||
if (x_data != x_out_data) {
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
JAX_AS_STATUS(gpuMemcpyAsync(x_out_data, x_data, x.size_bytes(),
|
||||
gpuMemcpyDeviceToDevice, stream)));
|
||||
}
|
||||
auto jpvt_host = HostBuffer<int>(jpvt.element_count());
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
jpvt_host.CopyFromDevice(stream, jpvt.typed_data<int>()));
|
||||
auto info_host = HostBuffer<int>(batch);
|
||||
|
||||
// TODO: do we need to wrap with synchronise due to non-stream safety.
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
for (int64_t i = 0; i < batch; ++i) {
|
||||
if constexpr (is_complex_dtype) {
|
||||
fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n,
|
||||
tau_data + i * min_dim, work, lwork, rwork, info_host.get() + i);
|
||||
} else {
|
||||
fn_(m, n, x_out_data + i * m * n, m, jpvt_host.get() + i * n,
|
||||
tau_data + i * min_dim, work, lwork, info_host.get() + i);
|
||||
}
|
||||
}
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
FFI_RETURN_IF_ERROR_STATUS(
|
||||
jpvt_host.CopyToDevice(stream, jpvt_out->typed_data<int>()));
|
||||
FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
private:
|
||||
Fn* fn_ = nullptr;
|
||||
BlockSizeFn* block_size_fn_ = nullptr;
|
||||
|
||||
absl::StatusOr<int> lwork(int m, int n) {
|
||||
// `{c,d,s,z}_geqp3_gpu` do not support a workspace query, but we can still
|
||||
// assign the symbol here.
|
||||
auto maybe_ptr = FindMagmaSymbol(MagmaGeqp3<DataType>::name);
|
||||
if (!maybe_ptr.ok()) return maybe_ptr.status();
|
||||
fn_ = reinterpret_cast<Fn*>(*maybe_ptr);
|
||||
|
||||
auto block_size_maybe_ptr =
|
||||
FindMagmaSymbol(MagmaGeqp3<DataType>::block_size_name);
|
||||
if (!block_size_maybe_ptr.ok()) return block_size_maybe_ptr.status();
|
||||
block_size_fn_ = reinterpret_cast<BlockSizeFn*>(*block_size_maybe_ptr);
|
||||
int optimal_block_size = block_size_fn_(m, n);
|
||||
if constexpr (ffi::IsComplexType<DataType>()) {
|
||||
return (n + 1) * optimal_block_size;
|
||||
}
|
||||
return (n + 1) * optimal_block_size + 2 * n;
|
||||
}
|
||||
};
|
||||
|
||||
ffi::Error PivotingQrFactorizationDispatch(
|
||||
gpuStream_t stream, ffi::ScratchAllocator scratch, std::string_view magma,
|
||||
ffi::AnyBuffer x, ffi::AnyBuffer jpvt, ffi::Result<ffi::AnyBuffer> x_out,
|
||||
ffi::Result<ffi::AnyBuffer> jpvt_out, ffi::Result<ffi::AnyBuffer> tau) {
|
||||
auto dataType = x.element_type();
|
||||
if (dataType != x_out->element_type() || dataType != tau->element_type()) {
|
||||
return ffi::Error::InvalidArgument(
|
||||
"The buffers 'x', 'x_out' and 'tau' must have the same element type.");
|
||||
}
|
||||
FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
|
||||
SplitBatch2D(x.dimensions()));
|
||||
FFI_RETURN_IF_ERROR(
|
||||
CheckShape(jpvt.dimensions(), {batch, cols}, "jpvt", "geqp3"));
|
||||
FFI_RETURN_IF_ERROR(
|
||||
CheckShape(x_out->dimensions(), {batch, rows, cols}, "x_out", "geqp3"));
|
||||
FFI_RETURN_IF_ERROR(
|
||||
CheckShape(jpvt_out->dimensions(), {batch, cols}, "jpvt_out", "geqp3"));
|
||||
FFI_RETURN_IF_ERROR(CheckShape(
|
||||
tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqp3"));
|
||||
|
||||
bool use_magma = magma == "on";
|
||||
if (magma == "auto" && cols >= 2048) {
|
||||
use_magma = FindMagmaSymbol("magma_init").ok();
|
||||
}
|
||||
|
||||
switch (dataType) {
|
||||
case ffi::F32:
|
||||
if (use_magma) {
|
||||
return PivotingQrFactorizationMagma<ffi::F32>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
} else {
|
||||
return PivotingQrFactorizationHost<ffi::F32>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
}
|
||||
case ffi::F64:
|
||||
if (use_magma) {
|
||||
return PivotingQrFactorizationMagma<ffi::F64>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
} else {
|
||||
return PivotingQrFactorizationHost<ffi::F64>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
}
|
||||
case ffi::C64:
|
||||
if (use_magma) {
|
||||
return PivotingQrFactorizationMagma<ffi::C64>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
} else {
|
||||
return PivotingQrFactorizationHost<ffi::C64>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
}
|
||||
case ffi::C128:
|
||||
if (use_magma) {
|
||||
return PivotingQrFactorizationMagma<ffi::C128>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
} else {
|
||||
return PivotingQrFactorizationHost<ffi::C128>().compute(
|
||||
batch, rows, cols, stream, scratch, x, jpvt, x_out, jpvt_out, tau);
|
||||
}
|
||||
default:
|
||||
return ffi::Error::InvalidArgument(absl::StrFormat(
|
||||
"Unsupported dtype %s in geqp3", absl::FormatStreamed(dataType)));
|
||||
}
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(kGeqp3, PivotingQrFactorizationDispatch,
|
||||
ffi::Ffi::Bind()
|
||||
.Ctx<ffi::PlatformStream<gpuStream_t>>()
|
||||
.Ctx<ffi::ScratchAllocator>()
|
||||
.Attr<std::string_view>("magma")
|
||||
.Arg<ffi::AnyBuffer>() // x
|
||||
.Arg<ffi::AnyBuffer>() // jpvt
|
||||
.Ret<ffi::AnyBuffer>() // x_out
|
||||
.Ret<ffi::AnyBuffer>() // jpvt_out
|
||||
.Ret<ffi::AnyBuffer>() // tau
|
||||
);
|
||||
|
||||
// Real-valued eigendecomposition
|
||||
|
||||
template <ffi::DataType DataType>
|
||||
|
@ -48,6 +48,7 @@ class MagmaLookup {
|
||||
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(kGeqp3);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
@ -51,19 +51,6 @@ namespace JAX_GPU_NAMESPACE {
|
||||
|
||||
namespace ffi = ::xla::ffi;
|
||||
|
||||
template <typename T>
|
||||
inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
|
||||
int64_t size,
|
||||
std::string_view name) {
|
||||
auto maybe_workspace = scratch.Allocate(sizeof(T) * size);
|
||||
if (!maybe_workspace.has_value()) {
|
||||
return absl::Status(
|
||||
absl::StatusCode::kResourceExhausted,
|
||||
absl::StrFormat("Unable to allocate workspace for %s", name));
|
||||
}
|
||||
return static_cast<T*>(maybe_workspace.value());
|
||||
}
|
||||
|
||||
#if JAX_GPU_HAVE_64_BIT
|
||||
|
||||
// Map an FFI buffer element type to the appropriate GPU solver type.
|
||||
|
@ -1683,10 +1683,13 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
mode=["full", "r", "economic"],
|
||||
pivoting=[False, True]
|
||||
)
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testScipyQrModes(self, shape, dtype, mode, pivoting):
|
||||
is_not_cpu_test_device = not jtu.test_device_matches(["cpu"])
|
||||
if pivoting and is_not_cpu_test_device:
|
||||
self.skipTest("Pivoting is only supported on CPU with jaxlib > 0.4.38")
|
||||
if pivoting:
|
||||
if not jtu.test_device_matches(["cpu", "gpu"]):
|
||||
self.skipTest("Pivoting is only supported on CPU and GPU.")
|
||||
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 5, 0):
|
||||
self.skipTest("Pivoting is only supported on GPU for jaxlib > 0.5.0")
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jsp_func = partial(jax.scipy.linalg.qr, mode=mode, pivoting=pivoting)
|
||||
sp_func = partial(scipy.linalg.qr, mode=mode, pivoting=pivoting)
|
||||
@ -1700,7 +1703,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
def qr_and_mul(a):
|
||||
q, r, *p = jsp_func(a)
|
||||
# To express the identity function we must "undo" the pivoting of `q @ r`.
|
||||
inverted_pivots = p[0][p[0]]
|
||||
inverted_pivots = jnp.argsort(p[0])
|
||||
return (q @ r)[:, inverted_pivots]
|
||||
|
||||
m, n = shape
|
||||
|
@ -114,5 +114,34 @@ class MagmaLinalgTest(jtu.JaxTestCase):
|
||||
hlo = jax.jit(partial(lax_linalg.eig, use_magma=True)).lower(a).as_text()
|
||||
self.assertIn('magma = "on"', hlo)
|
||||
|
||||
@jtu.sample_product(
  # One wide, one square and one tall matrix; each shape is listed once so the
  # sweep does not contain a redundant case.
  shape=[(3, 4), (3, 3), (4, 3)],
  dtype=float_types + complex_types,
)
@jtu.run_on_devices("gpu")
def testPivotedQrFactorization(self, shape, dtype):
  """Checks MAGMA-backed pivoted QR against ``jax.scipy.linalg.qr``.

  The test is skipped when jaxlib is not newer than 0.5.0 or when MAGMA cannot
  be loaded.
  """
  if jtu.jaxlib_version() <= (0, 5, 0):
    self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
  if not gpu_solver.has_magma():
    self.skipTest("MAGMA is not installed or can't be loaded.")
  rng = jtu.rand_default(self.rng())

  # `lax_func` forces the MAGMA path; `sp_func` leaves `use_magma` unset and is
  # used as the reference result.
  lax_func = partial(lax_linalg.qr, full_matrices=True, pivoting=True, use_magma=True)
  sp_func = partial(jax.scipy.linalg.qr, mode="full", pivoting=True)
  args_maker = lambda: [rng(shape, dtype)]
  self._CheckAgainstNumpy(sp_func, lax_func, args_maker, rtol=1E-5, atol=1E-5)
  self._CompileAndCheck(lax_func, args_maker)
|
||||
|
||||
def testPivotedQrFactorizationMagmaConfig(self):
  """Checks that the lowered pivoted QR carries the ``magma = "on"`` attribute.

  The test is skipped when jaxlib is not newer than 0.5.0 or when MAGMA cannot
  be loaded.
  """
  if jtu.jaxlib_version() <= (0, 5, 0):
    self.skipTest("qr with `pivoting=True` on GPU requires jaxlib version > 0.5.0")
  if not gpu_solver.has_magma():
    self.skipTest("MAGMA is not installed or can't be loaded.")
  make_random = jtu.rand_default(self.rng())
  matrix = make_random((5, 5), np.float32)
  with config.gpu_use_magma("on"):
    pivoted_qr = partial(lax_linalg.qr, pivoting=True, use_magma=True)
    lowered_text = jax.jit(pivoted_qr).lower(matrix).as_text()
    self.assertIn('magma = "on"', lowered_text)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user