Use the Jacobi algorithm for symmetric eigendecomposition for matrices with n <= 32.

Use the batched Jacobi algorithm for large batches of small matrices.
Peter Hawkins 2019-08-07 11:33:48 -04:00
parent d6bd59d716
commit 7160077cad
2 changed files with 180 additions and 9 deletions
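
For context, users never call these kernels directly; they sit behind jnp.linalg.eigh. A minimal sketch of exercising the new path, assuming a GPU backend so that jax.numpy.linalg.eigh lowers to these cusolver custom calls:

import numpy as np
from jax import numpy as jnp

# n = 16 <= 32, so per the dispatch added below this should hit the new
# syevj (Jacobi) kernel rather than syevd.
a = np.random.randn(1024, 16, 16).astype(np.float32)
a = a + np.swapaxes(a, -1, -2)  # symmetrize each matrix in the batch
w, v = jnp.linalg.eigh(jnp.asarray(a))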


@@ -314,7 +314,7 @@ void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
  }
}
// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd
struct SyevdDescriptor {
  Type type;
@@ -424,6 +424,171 @@ void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
  }
}
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.

struct SyevjDescriptor {
  Type type;
  cublasFillMode_t uplo;
  int batch, n;
  int lwork;
};

// Returns the workspace size and a descriptor for a syevj or syevjBatched
// operation.
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
                                               bool lower, int batch, int n) {
  Type type = DtypeToType(dtype);
  auto handle = SolverHandlePool::Borrow();
  int lwork;
  syevjInfo_t params;
  ThrowIfErrorStatus(cusolverDnCreateSyevjInfo(&params));
  std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
      params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); });
  cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
  cublasFillMode_t uplo =
      lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
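  // Workspace-size queries: the matrix and eigenvalue pointers are passed as
  // nullptr since only the problem dimensions are needed here.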
  if (batch == 1) {
    switch (type) {
      case Type::F32:
        ThrowIfErrorStatus(cusolverDnSsyevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params));
        break;
      case Type::F64:
        ThrowIfErrorStatus(cusolverDnDsyevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params));
        break;
      case Type::C64:
        ThrowIfErrorStatus(cusolverDnCheevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params));
        break;
      case Type::C128:
        ThrowIfErrorStatus(cusolverDnZheevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params));
        break;
    }
  } else {
    switch (type) {
      case Type::F32:
        ThrowIfErrorStatus(cusolverDnSsyevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch));
        break;
      case Type::F64:
        ThrowIfErrorStatus(cusolverDnDsyevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch));
        break;
      case Type::C64:
        ThrowIfErrorStatus(cusolverDnCheevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch));
        break;
      case Type::C128:
        ThrowIfErrorStatus(cusolverDnZheevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch));
        break;
    }
  }
  return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})};
}

void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len) {
  const SyevjDescriptor& d =
      *UnpackDescriptor<SyevjDescriptor>(opaque, opaque_len);
  auto handle = SolverHandlePool::Borrow(stream);
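  // Custom-call buffers: buffers[0] = input matrices, buffers[1] =
  // eigenvectors (syevj runs in place there), buffers[2] = eigenvalues,
  // buffers[3] = per-matrix info, buffers[4] = workspace. Copy the operand
  // into the output buffer unless the two already alias.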
  if (buffers[1] != buffers[0]) {
    ThrowIfError(cudaMemcpyAsync(buffers[1], buffers[0],
                                 SizeOfType(d.type) * d.batch * d.n * d.n,
                                 cudaMemcpyDeviceToDevice, stream));
  }
  syevjInfo_t params;
  ThrowIfErrorStatus(cusolverDnCreateSyevjInfo(&params));
  std::unique_ptr<syevjInfo, void (*)(syevjInfo*)> params_cleanup(
      params, [](syevjInfo* p) { cusolverDnDestroySyevjInfo(p); });
  cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR;
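  // One status value per matrix; 0 indicates the Jacobi sweeps converged.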
  int* info = static_cast<int*>(buffers[3]);
  void* work = buffers[4];
  if (d.batch == 1) {
    switch (d.type) {
      case Type::F32: {
        float* a = static_cast<float*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnSsyevj(handle.get(), jobz, d.uplo, d.n, a,
                                            d.n, w, static_cast<float*>(work),
                                            d.lwork, info, params));
        break;
      }
      case Type::F64: {
        double* a = static_cast<double*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnDsyevj(handle.get(), jobz, d.uplo, d.n, a,
                                            d.n, w, static_cast<double*>(work),
                                            d.lwork, info, params));
        break;
      }
      case Type::C64: {
        cuComplex* a = static_cast<cuComplex*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnCheevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<cuComplex*>(work), d.lwork, info, params));
        break;
      }
      case Type::C128: {
        cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnZheevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<cuDoubleComplex*>(work), d.lwork, info, params));
        break;
      }
    }
  } else {
    switch (d.type) {
      case Type::F32: {
        float* a = static_cast<float*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnSsyevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<float*>(work), d.lwork, info, params, d.batch));
        break;
      }
      case Type::F64: {
        double* a = static_cast<double*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnDsyevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<double*>(work), d.lwork, info, params, d.batch));
        break;
      }
      case Type::C64: {
        cuComplex* a = static_cast<cuComplex*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        ThrowIfErrorStatus(cusolverDnCheevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<cuComplex*>(work), d.lwork, info, params, d.batch));
        break;
      }
      case Type::C128: {
        cuDoubleComplex* a = static_cast<cuDoubleComplex*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        ThrowIfErrorStatus(
            cusolverDnZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
                                    static_cast<cuDoubleComplex*>(work),
                                    d.lwork, info, params, d.batch));
        break;
      }
    }
  }
}
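
The per-matrix statuses above are left in device memory for the caller. As a side note (not part of this commit), a host-side check could look like the following sketch; CheckSyevjInfo is a hypothetical helper, and it assumes <vector>, <string>, and <stdexcept> are available along with this file's ThrowIfError:

void CheckSyevjInfo(cudaStream_t stream, const int* dev_info, int batch) {
  std::vector<int> info(batch);
  // Copy the statuses back and wait for the copy (and the solve) to finish.
  ThrowIfError(cudaMemcpyAsync(info.data(), dev_info, sizeof(int) * batch,
                               cudaMemcpyDeviceToHost, stream));
  ThrowIfError(cudaStreamSynchronize(stream));
  for (int i = 0; i < batch; ++i) {
    // 0 means the Jacobi iteration converged for matrix i; other values
    // report invalid arguments or hitting the sweep limit.
    if (info[i] != 0) {
      throw std::runtime_error("syevj failed for batch element " +
                               std::to_string(i));
    }
  }
}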
// Singular value decomposition: gesvd
struct GesvdDescriptor {
@@ -500,7 +665,6 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
        u += d.m * d.m;
        vt += d.n * d.n;
        ++info;
      }
      break;
    }
@@ -528,10 +692,9 @@ void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
      cuComplex* u = static_cast<cuComplex*>(buffers[3]);
      cuComplex* vt = static_cast<cuComplex*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        ThrowIfErrorStatus(cusolverDnCgesvd(
            handle.get(), d.jobu, d.jobvt, d.m, d.n, a, d.m, s, u, d.m, vt, d.n,
            static_cast<cuComplex*>(work), d.lwork, /*rwork=*/nullptr, info));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
        u += d.m * d.m;
@@ -570,6 +733,7 @@ py::dict Registrations() {
  py::dict dict;
  dict["cusolver_getrf"] = EncapsulateFunction(Getrf);
  dict["cusolver_syevd"] = EncapsulateFunction(Syevd);
  dict["cusolver_syevj"] = EncapsulateFunction(Syevj);
  dict["cusolver_gesvd"] = EncapsulateFunction(Gesvd);
  return dict;
}
@@ -578,6 +742,7 @@ PYBIND11_MODULE(cusolver_kernels, m) {
m.def("registrations", &Registrations);
m.def("build_getrf_descriptor", &BuildGetrfDescriptor);
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
}


@@ -97,12 +97,18 @@ def syevd(c, a, lower=False):
  b *= d
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  if n <= 32:
    kernel = b"cusolver_syevj"
    lwork, opaque = cusolver_kernels.build_syevj_descriptor(
        np.dtype(dtype), lower, b, n)
  else:
    kernel = b"cusolver_syevd"
    lwork, opaque = cusolver_kernels.build_syevd_descriptor(
        np.dtype(dtype), lower, b, n)
  eigvals_type = _real_type(dtype)
  out = c.CustomCall(
      kernel,
      operands=(a,),
      shape_with_layout=_Shape.tuple_shape((
          _Shape.array_shape(dtype, dims, layout),
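
Not part of the commit, but a quick numerical sanity check of the Jacobi path; this sketch assumes a GPU backend so that the n <= 32 branch above is taken:

import numpy as np
from jax import numpy as jnp

a = np.random.randn(8, 16, 16).astype(np.float32)
a = a + np.swapaxes(a, -1, -2)  # symmetrize
w, v = jnp.linalg.eigh(jnp.asarray(a))
# Eigenvectors are only defined up to sign, so check the reconstruction
# V diag(w) V^T == A rather than comparing V against another solver directly.
recon = np.einsum('bij,bj,bkj->bik', np.asarray(v), np.asarray(w), np.asarray(v))
np.testing.assert_allclose(recon, a, atol=1e-3)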