mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use Jacobi algorithm for symmetric eigendecomposition for matrices with n <= 32.
Use the batched Jacobi algorithm for large batches of small matrices.
This commit is contained in:
parent
d6bd59d716
commit
7160077cad
@ -314,7 +314,7 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
}
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition: syevd/heevd
|
||||
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
|
||||
|
||||
struct SyevdDescriptor {
|
||||
Type type;
|
||||
@ -424,6 +424,171 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
}
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
struct SyevjDescriptor {
|
||||
Type type;
|
||||
cublasFillMode_t uplo;
|
||||
int batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
// Returns the workspace size and a descriptor for a syevj_batched operation.
|
||||
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
|
||||
bool lower, int batch, int n) {
|
||||
Type type = DtypeToType(dtype);
|
||||
auto handle = SolverHandlePool::Borrow();
|
||||
int lwork;
|
||||
syevjInfo_t params;
|
||||
ThrowIfErrorStatus(cusolverDnCreateSyevjInfo(¶ms));
|
||||
std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
|
||||
params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); });
|
||||
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
|
||||
cublasFillMode_t uplo =
|
||||
lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
|
||||
if (batch == 1) {
|
||||
switch (type) {
|
||||
case Type::F32:
|
||||
ThrowIfErrorStatus(cusolverDnSsyevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params));
|
||||
break;
|
||||
case Type::F64:
|
||||
ThrowIfErrorStatus(cusolverDnDsyevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params));
|
||||
break;
|
||||
case Type::C64:
|
||||
ThrowIfErrorStatus(cusolverDnCheevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params));
|
||||
break;
|
||||
case Type::C128:
|
||||
ThrowIfErrorStatus(cusolverDnZheevj_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params));
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
switch (type) {
|
||||
case Type::F32:
|
||||
ThrowIfErrorStatus(cusolverDnSsyevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch));
|
||||
break;
|
||||
case Type::F64:
|
||||
ThrowIfErrorStatus(cusolverDnDsyevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch));
|
||||
break;
|
||||
case Type::C64:
|
||||
ThrowIfErrorStatus(cusolverDnCheevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch));
|
||||
break;
|
||||
case Type::C128:
|
||||
ThrowIfErrorStatus(cusolverDnZheevjBatched_bufferSize(
|
||||
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
|
||||
/*W=*/nullptr, &lwork, params, batch));
|
||||
break;
|
||||
}
|
||||
}
|
||||
return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})};
|
||||
}
|
||||
|
||||
// Runs the Jacobi symmetric (Hermitian) eigensolver (syevj/heevj) described
// by the SyevjDescriptor packed in `opaque`.
//
// Buffer usage, as read by this function:
//   buffers[0]: input matrices.
//   buffers[1]: matrix buffer handed to cuSOLVER as `A`; the input is copied
//               here first unless it is the same pointer as buffers[0].
//   buffers[2]: `W` buffer (float for F32/C64, double for F64/C128).
//   buffers[3]: int `info` buffer.
//   buffers[4]: workspace, passed to cuSOLVER together with d.lwork.
//
// The non-batched solver is used when d.batch == 1 and the batched solver
// otherwise. CUDA and cuSOLVER status codes are checked with ThrowIfError /
// ThrowIfErrorStatus.
void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len) {
  const SyevjDescriptor& d =
      *UnpackDescriptor<SyevjDescriptor>(opaque, opaque_len);
  auto handle = SolverHandlePool::Borrow(stream);
  if (buffers[1] != buffers[0]) {
    // Widen before multiplying so that the byte count of a large batch is
    // not computed in (overflow-prone) int arithmetic.
    const size_t num_bytes =
        static_cast<size_t>(SizeOfType(d.type)) * d.batch * d.n * d.n;
    ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0], num_bytes,
                                 cudaMemcpyDeviceToDevice, stream));
  }
  syevjInfo_t params;
  ThrowIfErrorStatus(cusolverDnCreateSyevjInfo(&params));
  // Destroys the syevj parameter object on every exit path, including when
  // a solver call below reports an error.
  std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
      params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); });

  // Must match the job mode used for the workspace query in
  // BuildSyevjDescriptor.
  cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
  int* info = static_cast<int*>(buffers[3]);
  void* work = buffers[4];
  if (d.batch == 1) {
    // Single matrix: non-batched solver.
    switch (d.type) {
      case Type::F32: {
        float* a = static_cast<float*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnSsyevj(handle.get(), jobz, d.uplo, d.n, a,
                                            d.n, w, static_cast<float*>(work),
                                            d.lwork, info, params));
        break;
      }
      case Type::F64: {
        double* a = static_cast<double*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnDsyevj(handle.get(), jobz, d.uplo, d.n, a,
                                            d.n, w, static_cast<double*>(work),
                                            d.lwork, info, params));
        break;
      }
      case Type::C64: {
        cuComplex* a = static_cast<cuComplex*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnCheevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<cuComplex*>(work), d.lwork, info, params));
        break;
      }
      case Type::C128: {
        cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnZheevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<cuDoubleComplex*>(work), d.lwork, info, params));
        break;
      }
    }
  } else {
    // Several matrices: batched solver.
    switch (d.type) {
      case Type::F32: {
        float* a = static_cast<float*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnSsyevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<float*>(work), d.lwork, info, params, d.batch));
        break;
      }
      case Type::F64: {
        double* a = static_cast<double*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnDsyevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<double*>(work), d.lwork, info, params, d.batch));
        break;
      }
      case Type::C64: {
        cuComplex* a = static_cast<cuComplex*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnCheevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<cuComplex*>(work), d.lwork, info, params, d.batch));
        break;
      }
      case Type::C128: {
        cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(
            cusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
                                    static_cast<cuDoubleComplex*>(work),
                                    d.lwork, info, params, d.batch));
        break;
      }
    }
  }
}
|
||||
|
||||
// Singular value decomposition: gesvd
|
||||
|
||||
struct GesvdDescriptor {
|
||||
@ -500,7 +665,6 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
u += d.m * d.m;
|
||||
vt += d.n * d.n;
|
||||
++info;
|
||||
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -528,10 +692,9 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
|
||||
cuComplex* u = static_cast<cuComplex*>(buffers[3]);
|
||||
cuComplex* vt = static_cast<cuComplex*>(buffers[4]);
|
||||
for (int i = 0; i < d.batch; ++i) {
|
||||
ThrowIfErrorStatus(
|
||||
cusolverDnCgesvd(handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s,
|
||||
u, d.m, vt, d.n, static_cast<cuComplex*>(work),
|
||||
d.lwork, /*rwork=*/nullptr, info));
|
||||
ThrowIfErrorStatus(cusolverDnCgesvd(
|
||||
handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
|
||||
static_cast<cuComplex*>(work), d.lwork, /*rwork=*/nullptr, info));
|
||||
a += d.m * d.n;
|
||||
s += std::min(d.m, d.n);
|
||||
u += d.m * d.m;
|
||||
@ -570,6 +733,7 @@ py::dict Registrations() {
|
||||
py::dict dict;
|
||||
dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
|
||||
dict["cusolver_syevd"] = EncapsulateFunction(Syevd);
|
||||
dict["cusolver_syevj"] = EncapsulateFunction(Syevj);
|
||||
dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd);
|
||||
return dict;
|
||||
}
|
||||
@ -578,6 +742,7 @@ PYBIND11_MODULE(cusolver_kernels, m) {
|
||||
m.def("registrations", &Registrations);
|
||||
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
|
||||
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
|
||||
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
|
||||
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
|
||||
}
|
||||
|
||||
|
@ -97,12 +97,18 @@ def syevd(c, a, lower=False):
|
||||
b *= d
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
|
||||
lwork, opaque = cusolver_kernels.build_syevd_descriptor(
|
||||
np.dtype(dtype), lower, b, n)
|
||||
if n <= 32:
|
||||
kernel = b"cusolver_syevj"
|
||||
lwork, opaque = cusolver_kernels.build_syevj_descriptor(
|
||||
np.dtype(dtype), lower, b, n)
|
||||
else:
|
||||
kernel = b"cusolver_syevd"
|
||||
lwork, opaque = cusolver_kernels.build_syevd_descriptor(
|
||||
np.dtype(dtype), lower, b, n)
|
||||
eigvals_type = _real_type(dtype)
|
||||
|
||||
out = c.CustomCall(
|
||||
b"cusolver_syevd",
|
||||
kernel,
|
||||
operands=(a,),
|
||||
shape_with_layout=_Shape.tuple_shape((
|
||||
_Shape.array_shape(dtype, dims, layout),
|
||||
|
Loading…
x
Reference in New Issue
Block a user