[JAX:GPU] Implement the full_matrices=False case of SVD without generating the full matrices and then slicing.

PiperOrigin-RevId: 432425681
Peter Hawkins 2022-03-04 05:55:08 -08:00 committed by jax authors
parent f6a5f0dca2
commit 7d02949d24
4 changed files with 32 additions and 25 deletions
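
Note (not part of the change): the user-facing contract being targeted is the reduced ("economy") SVD. With full_matrices=False, U has shape (m, k) and Vh has shape (k, n) with k = min(m, n), instead of the full (m, m) and (n, n) factors. A small illustrative example:

    # Illustrative only: reduced vs. full SVD output shapes.
    import jax.numpy as jnp

    a = jnp.ones((5, 3))
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    print(u.shape, s.shape, vh.shape)   # (5, 3) (3,) (3, 3)

    u_full, _, vh_full = jnp.linalg.svd(a, full_matrices=True)
    print(u_full.shape, vh_full.shape)  # (5, 5) (3, 3)

Previously the GPU lowering materialized the full factors and sliced them; this commit asks cuSOLVER for the reduced factors directly.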


@@ -379,7 +379,7 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
 // Returns the workspace size and a descriptor for a gesvdj operation.
 std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
                                                 int batch, int m, int n,
-                                                bool compute_uv) {
+                                                bool compute_uv, int econ) {
   CusolverType type = DtypeToCusolverType(dtype);
   auto h = SolverHandlePool::Borrow();
   JAX_THROW_IF_ERROR(h.status());
@@ -395,28 +395,28 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
   switch (type) {
     case CusolverType::F32:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
       break;
     case CusolverType::F64:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
       break;
     case CusolverType::C64:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
       break;
     case CusolverType::C128:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
@@ -454,8 +454,8 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
       break;
     }
   }
-  return {lwork,
-          PackDescriptor(GesvdjDescriptor{type, batch, m, n, lwork, jobz})};
+  return {lwork, PackDescriptor(
+                     GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})};
 }
 
 py::dict Registrations() {
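
Note on the new econ argument (a rough sketch, not a jaxlib API): cuSOLVER's gesvdj returns V rather than V^H, and with econ != 0 both U and V keep only min(m, n) columns, which is exactly the set of output shapes the Python lowering below requests. A hypothetical helper mirroring that bookkeeping:

    # Hypothetical shape helper mirroring the descriptor's econ flag.
    def gesvdj_output_shapes(m, n, econ):
      k = min(m, n)
      u_shape = (m, k if econ else m)   # U
      v_shape = (n, k if econ else n)   # V (not V^H)
      return u_shape, (k,), v_shape

    print(gesvdj_output_shapes(5, 3, econ=1))  # ((5, 3), (3,), (3, 3))
    print(gesvdj_output_shapes(5, 3, econ=0))  # ((5, 5), (3,), (3, 3))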


@@ -312,11 +312,14 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
   singular_vals_dtype = np.dtype(_real_type(dtype))
 
   if m < 32 and n < 32:
+    # The batched kernel doesn't support "econ" mode.
+    econ = not full_matrices and b == 1
     lwork, opaque = _cusolver.build_gesvdj_descriptor(
-        np.dtype(dtype), b, m, n, compute_uv)
+        np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
     scalar_layout = tuple(range(num_bd - 1, -1, -1))
     vector_layout = (num_bd,) + scalar_layout
     matrix_layout = (num_bd, num_bd + 1) + scalar_layout
+    k = min(m, n)
     out = _ops.CustomCallWithLayout(
         c, b"cusolver_gesvdj",
         operands=(a,),
@@ -324,8 +327,8 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
             _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
             _Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
                                vector_layout),
-            _Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
-            _Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
+            _Shape.array_shape(dtype, batch_dims + (m, k if econ else m), matrix_layout),
+            _Shape.array_shape(dtype, batch_dims + (n, k if econ else n), matrix_layout),
             _Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
             _Shape.array_shape(dtype, (lwork,), (0,)),
         )),
@@ -342,12 +345,18 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
     vt = _ops.Transpose(v, tuple(range(num_bd)) + (num_bd + 1, num_bd))
     if np.issubdtype(dtype, np.complexfloating):
       vt = _ops.Conj(vt)
+    if not full_matrices and not econ:
+      u = _ops.Slice(u, (0,) * len(dims), batch_dims + (m, min(m, n)),
+                     (1,) * len(dims))
+      vt = _ops.Slice(vt, (0,) * len(dims), batch_dims + (min(m, n), n),
+                      (1,) * len(dims))
   elif m < n:
     lwork, opaque = _cusolver.build_gesvd_descriptor(
         np.dtype(dtype), b, n, m, compute_uv, full_matrices)
     scalar_layout = tuple(range(num_bd - 1, -1, -1))
     vector_layout = (num_bd,) + scalar_layout
     matrix_layout = (num_bd + 1, num_bd) + scalar_layout
+    k = n if full_matrices else m
     out = _ops.CustomCallWithLayout(
         c, b"cusolver_gesvd",
         operands=(a,),
@@ -355,7 +364,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
             _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
             _Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
                                vector_layout),
-            _Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
+            _Shape.array_shape(dtype, batch_dims + (k, n), matrix_layout),
             _Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
             _Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
             _Shape.array_shape(dtype, (lwork,), (0,)),
@@ -377,6 +386,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
     scalar_layout = tuple(range(num_bd - 1, -1, -1))
     vector_layout = (num_bd,) + scalar_layout
     matrix_layout = (num_bd, num_bd + 1) + scalar_layout
+    k = m if full_matrices else n
     out = _ops.CustomCallWithLayout(
         c, b"cusolver_gesvd",
         operands=(a,),
@@ -384,7 +394,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
             _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
             _Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
                                vector_layout),
-            _Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
+            _Shape.array_shape(dtype, batch_dims + (m, k), matrix_layout),
             _Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
             _Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
             _Shape.array_shape(dtype, (lwork,), (0,)),
@@ -399,9 +409,4 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
     u = _ops.GetTupleElement(out, 2)
     vt = _ops.GetTupleElement(out, 3)
     info = _ops.GetTupleElement(out, 4)
-    if not full_matrices:
-      u = _ops.Slice(u, (0,) * len(dims), batch_dims + (m, min(m, n)),
-                     (1,) * len(dims))
-      vt = _ops.Slice(vt, (0,) * len(dims), batch_dims + (min(m, n), n),
-                      (1,) * len(dims))
   return s, u, vt, info
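
The post-hoc slicing deleted here is no longer needed because the custom call now declares the thin output shapes directly. As a reminder of why the thin factors are sufficient, a NumPy sanity sketch (illustrative only, not part of the lowering):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((6, 4))
    # Thin SVD: u is (6, 4) and vt is (4, 4); the (6, 6) factor is never formed.
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    assert np.allclose(u @ np.diag(s) @ vt, a)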


@@ -629,6 +629,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
                                         cudaMemcpyDeviceToDevice, stream)));
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
+  int64_t k = d.jobu == 'A' ? d.m : d.n;
   switch (d.type) {
     case CusolverType::F32: {
       float* a = static_cast<float*>(buffers[1]);
@@ -642,7 +643,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
            /*rwork=*/nullptr, info)));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
-       u += d.m * d.m;
+       u += d.m * k;
        vt += d.n * d.n;
        ++info;
      }
@@ -660,7 +661,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
            /*rwork=*/nullptr, info)));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
-       u += d.m * d.m;
+       u += d.m * k;
        vt += d.n * d.n;
        ++info;
      }
@@ -677,7 +678,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
            static_cast<cuComplex*>(work), d.lwork, /*rwork=*/nullptr, info)));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
-       u += d.m * d.m;
+       u += d.m * k;
        vt += d.n * d.n;
        ++info;
      }
@@ -695,7 +696,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
            /*rwork=*/nullptr, info)));
        a += d.m * d.n;
        s += std::min(d.m, d.n);
-       u += d.m * d.m;
+       u += d.m * k;
        vt += d.n * d.n;
        ++info;
      }
@@ -743,7 +744,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
      float* u = static_cast<float*>(buffers[3]);
      float* v = static_cast<float*>(buffers[4]);
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj(
-         handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+         handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<float*>(work), d.lwork, info, params)));
      break;
    }
@@ -753,7 +754,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
      double* u = static_cast<double*>(buffers[3]);
      double* v = static_cast<double*>(buffers[4]);
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj(
-         handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+         handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<double*>(work), d.lwork, info, params)));
      break;
    }
@@ -763,7 +764,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
      cuComplex* u = static_cast<cuComplex*>(buffers[3]);
      cuComplex* v = static_cast<cuComplex*>(buffers[4]);
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj(
-         handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+         handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<cuComplex*>(work), d.lwork, info, params)));
      break;
    }
@@ -773,7 +774,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
      cuDoubleComplex* u = static_cast<cuDoubleComplex*>(buffers[3]);
      cuDoubleComplex* v = static_cast<cuDoubleComplex*>(buffers[4]);
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj(
-         handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+         handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<cuDoubleComplex*>(work), d.lwork, info, params)));
      break;
    }
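
Note on the stride change in Gesvd_: with jobu == 'A' each batch's U occupies m*m elements, while with the thin 'S' job it occupies m*k with k = min(m, n); the lowering above transposes wide matrices so this kernel sees m >= n, hence k == n. A hypothetical Python illustration of the per-batch offsets (not jaxlib code):

    # Hypothetical: offsets of each batch's U inside the flat output buffer.
    def u_batch_offsets(batch, m, n, jobu):
      k = m if jobu == 'A' else n   # mirrors: int64_t k = d.jobu == 'A' ? d.m : d.n;
      return [i * m * k for i in range(batch)]

    print(u_batch_offsets(3, 6, 4, 'S'))  # [0, 24, 48]
    print(u_batch_offsets(3, 6, 4, 'A'))  # [0, 36, 72]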


@@ -124,6 +124,7 @@ struct GesvdjDescriptor {
   int batch, m, n;
   int lwork;
   cusolverEigMode_t jobz;
+  int econ;
 };
 
 void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
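
For reference, a conceptual picture of the fields the opaque gesvdj descriptor now carries (a sketch with illustrative values only; the real struct is defined and packed on the C++ side and passed to the kernel as bytes):

    from dataclasses import dataclass

    @dataclass
    class GesvdjDescriptorSketch:   # hypothetical mirror of GesvdjDescriptor
      type: int                     # CusolverType
      batch: int
      m: int
      n: int
      lwork: int
      jobz: int                     # cusolverEigMode_t
      econ: int                     # new field: nonzero => economy-size SVD

    # e.g. a single 5x3 economy-size problem (values illustrative)
    print(GesvdjDescriptorSketch(type=0, batch=1, m=5, n=3, lwork=1024, jobz=1, econ=1))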