[JAX:GPU] Implement the full_matrices=False case of SVD without generating the full matrices and then slicing.

PiperOrigin-RevId: 432425681

commit 7d02949d24
parent f6a5f0dca2
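For context, full_matrices=False requests the "economy" SVD, whose factors have min(m, n) columns/rows instead of square m-by-m and n-by-n matrices. A minimal NumPy sketch (illustrative only, not part of this change) of the shape difference the GPU lowering can now produce directly instead of computing the full factors and slicing them:

import numpy as np

a = np.random.randn(5, 3).astype(np.float32)  # m=5, n=3, k=min(m, n)=3

u_full, s, vt_full = np.linalg.svd(a, full_matrices=True)
assert u_full.shape == (5, 5) and vt_full.shape == (3, 3)

u_econ, s, vt_econ = np.linalg.svd(a, full_matrices=False)
assert u_econ.shape == (5, 3) and vt_econ.shape == (3, 3)  # only k columns of U

# Both forms reconstruct A; the economy form never materializes the extra
# columns of U that the old lowering generated and then sliced away.
assert np.allclose(u_econ @ np.diag(s) @ vt_econ, a, atol=1e-4)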
@@ -379,7 +379,7 @@ std::pair<int, py::bytes> BuildGesvdDescriptor(const py::dtype& dtype, int b,
 // Returns the workspace size and a descriptor for a gesvdj operation.
 std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
                                                 int batch, int m, int n,
-                                                bool compute_uv) {
+                                                bool compute_uv, int econ) {
   CusolverType type = DtypeToCusolverType(dtype);
   auto h = SolverHandlePool::Borrow();
   JAX_THROW_IF_ERROR(h.status());
@@ -395,28 +395,28 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
   switch (type) {
     case CusolverType::F32:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
       break;
     case CusolverType::F64:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
       break;
     case CusolverType::C64:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
       break;
     case CusolverType::C128:
       JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj_bufferSize(
-          handle.get(), jobz, /*econ=*/0, m, n,
+          handle.get(), jobz, econ, m, n,
           /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
           /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr,
           /*ldv=*/n, &lwork, params)));
@@ -454,8 +454,8 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
       break;
     }
   }
-  return {lwork,
-          PackDescriptor(GesvdjDescriptor{type, batch, m, n, lwork, jobz})};
+  return {lwork, PackDescriptor(
+                     GesvdjDescriptor{type, batch, m, n, lwork, jobz, econ})};
 }
 
 py::dict Registrations() {
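The descriptor packed above now carries the econ flag through to the kernel. A hypothetical Python illustration of that extra field follows; the real serialization is the C++ PackDescriptor over the GesvdjDescriptor struct extended at the bottom of this diff, and the field layout and helper name here are invented for illustration only.

import struct

def pack_gesvdj_descriptor(cusolver_type, batch, m, n, lwork, jobz, econ):
    # Field order mirrors GesvdjDescriptor: type, batch, m, n, lwork, jobz,
    # plus the new trailing `econ` int added by this change. Packing jobz
    # (a cusolverEigMode_t enum) as a plain int is an assumption made for
    # this sketch.
    return struct.pack("<7i", cusolver_type, batch, m, n, lwork, jobz, econ)

# e.g. a single 5x3 problem in economy mode (econ=1):
opaque = pack_gesvdj_descriptor(0, 1, 5, 3, 1024, 1, 1)
assert len(opaque) == 7 * 4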
@@ -312,11 +312,14 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
   singular_vals_dtype = np.dtype(_real_type(dtype))
 
   if m < 32 and n < 32:
+    # The batched kernel doesn't support "econ" mode.
+    econ = not full_matrices and b == 1
     lwork, opaque = _cusolver.build_gesvdj_descriptor(
-        np.dtype(dtype), b, m, n, compute_uv)
+        np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
     scalar_layout = tuple(range(num_bd - 1, -1, -1))
     vector_layout = (num_bd,) + scalar_layout
     matrix_layout = (num_bd, num_bd + 1) + scalar_layout
+    k = min(m, n)
     out = _ops.CustomCallWithLayout(
         c, b"cusolver_gesvdj",
         operands=(a,),
@@ -324,8 +327,8 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
         _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
         _Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
                            vector_layout),
-        _Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
-        _Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
+        _Shape.array_shape(dtype, batch_dims + (m, k if econ else m), matrix_layout),
+        _Shape.array_shape(dtype, batch_dims + (n, k if econ else n), matrix_layout),
         _Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
         _Shape.array_shape(dtype, (lwork,), (0,)),
     )),
@@ -342,12 +345,18 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
     vt = _ops.Transpose(v, tuple(range(num_bd)) + (num_bd + 1, num_bd))
     if np.issubdtype(dtype, np.complexfloating):
       vt = _ops.Conj(vt)
+    if not full_matrices and not econ:
+      u = _ops.Slice(u, (0,) * len(dims), batch_dims + (m, min(m, n)),
+                     (1,) * len(dims))
+      vt = _ops.Slice(vt, (0,) * len(dims), batch_dims + (min(m, n), n),
+                      (1,) * len(dims))
   elif m < n:
     lwork, opaque = _cusolver.build_gesvd_descriptor(
         np.dtype(dtype), b, n, m, compute_uv, full_matrices)
     scalar_layout = tuple(range(num_bd - 1, -1, -1))
     vector_layout = (num_bd,) + scalar_layout
     matrix_layout = (num_bd + 1, num_bd) + scalar_layout
+    k = n if full_matrices else m
     out = _ops.CustomCallWithLayout(
         c, b"cusolver_gesvd",
         operands=(a,),
@@ -355,7 +364,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
         _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
         _Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
                            vector_layout),
-        _Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
+        _Shape.array_shape(dtype, batch_dims + (k, n), matrix_layout),
         _Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
         _Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
         _Shape.array_shape(dtype, (lwork,), (0,)),
@@ -377,6 +386,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
     scalar_layout = tuple(range(num_bd - 1, -1, -1))
     vector_layout = (num_bd,) + scalar_layout
     matrix_layout = (num_bd, num_bd + 1) + scalar_layout
+    k = m if full_matrices else n
     out = _ops.CustomCallWithLayout(
         c, b"cusolver_gesvd",
         operands=(a,),
@@ -384,7 +394,7 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
         _Shape.array_shape(dtype, batch_dims + (m, n), matrix_layout),
         _Shape.array_shape(singular_vals_dtype, batch_dims + (min(m, n),),
                            vector_layout),
-        _Shape.array_shape(dtype, batch_dims + (m, m), matrix_layout),
+        _Shape.array_shape(dtype, batch_dims + (m, k), matrix_layout),
         _Shape.array_shape(dtype, batch_dims + (n, n), matrix_layout),
         _Shape.array_shape(np.dtype(np.int32), batch_dims, scalar_layout),
         _Shape.array_shape(dtype, (lwork,), (0,)),
@@ -399,9 +409,4 @@ def gesvd(c, a, full_matrices=True, compute_uv=True):
   u = _ops.GetTupleElement(out, 2)
   vt = _ops.GetTupleElement(out, 3)
   info = _ops.GetTupleElement(out, 4)
-  if not full_matrices:
-    u = _ops.Slice(u, (0,) * len(dims), batch_dims + (m, min(m, n)),
-                   (1,) * len(dims))
-    vt = _ops.Slice(vt, (0,) * len(dims), batch_dims + (min(m, n), n),
-                    (1,) * len(dims))
   return s, u, vt, info
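The econ/k bookkeeping above decides which output shapes the custom calls declare. A small standalone sketch of that logic for the gesvdj (small-matrix) path, mirroring the lines added in this hunk; the helper name is invented for illustration:

def gesvdj_uv_shapes(b, m, n, full_matrices):
    # The batched gesvdj kernel doesn't support "econ" mode, so the reduced
    # factors can only be requested directly when the batch size is 1.
    econ = not full_matrices and b == 1
    k = min(m, n)
    u_shape = (m, k if econ else m)
    v_shape = (n, k if econ else n)
    return econ, u_shape, v_shape

# Unbatched reduced SVD: gesvdj writes the economy factors directly.
assert gesvdj_uv_shapes(1, 5, 3, False) == (True, (5, 3), (3, 3))
# Batched reduced SVD: full factors come back and are sliced afterwards
# (the `if not full_matrices and not econ` branch above).
assert gesvdj_uv_shapes(4, 5, 3, False) == (False, (5, 5), (3, 3))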
@@ -629,6 +629,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
                                 cudaMemcpyDeviceToDevice, stream)));
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
+  int64_t k = d.jobu == 'A' ? d.m : d.n;
   switch (d.type) {
     case CusolverType::F32: {
       float* a = static_cast<float*>(buffers[1]);
@@ -642,7 +643,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
             /*rwork=*/nullptr, info)));
         a += d.m * d.n;
         s += std::min(d.m, d.n);
-        u += d.m * d.m;
+        u += d.m * k;
         vt += d.n * d.n;
         ++info;
       }
@@ -660,7 +661,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
             /*rwork=*/nullptr, info)));
         a += d.m * d.n;
         s += std::min(d.m, d.n);
-        u += d.m * d.m;
+        u += d.m * k;
         vt += d.n * d.n;
         ++info;
       }
@@ -677,7 +678,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
             static_cast<cuComplex*>(work), d.lwork, /*rwork=*/nullptr, info)));
         a += d.m * d.n;
         s += std::min(d.m, d.n);
-        u += d.m * d.m;
+        u += d.m * k;
         vt += d.n * d.n;
         ++info;
       }
@@ -695,7 +696,7 @@ static absl::Status Gesvd_(cudaStream_t stream, void** buffers,
             /*rwork=*/nullptr, info)));
         a += d.m * d.n;
         s += std::min(d.m, d.n);
-        u += d.m * d.m;
+        u += d.m * k;
         vt += d.n * d.n;
         ++info;
       }
@@ -743,7 +744,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
       float* u = static_cast<float*>(buffers[3]);
       float* v = static_cast<float*>(buffers[4]);
       JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj(
-          handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+          handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<float*>(work), d.lwork, info, params)));
       break;
     }
@@ -753,7 +754,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
       double* u = static_cast<double*>(buffers[3]);
       double* v = static_cast<double*>(buffers[4]);
       JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj(
-          handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+          handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<double*>(work), d.lwork, info, params)));
       break;
     }
@@ -763,7 +764,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
       cuComplex* u = static_cast<cuComplex*>(buffers[3]);
      cuComplex* v = static_cast<cuComplex*>(buffers[4]);
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj(
-          handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+          handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<cuComplex*>(work), d.lwork, info, params)));
      break;
    }
@@ -773,7 +774,7 @@ static absl::Status Gesvdj_(cudaStream_t stream, void** buffers,
      cuDoubleComplex* u = static_cast<cuDoubleComplex*>(buffers[3]);
      cuDoubleComplex* v = static_cast<cuDoubleComplex*>(buffers[4]);
      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj(
-          handle.get(), d.jobz, /*econ=*/0, d.m, d.n, a, d.m, s, u, d.m, v,
+          handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v,
          d.n, static_cast<cuDoubleComplex*>(work), d.lwork, info, params)));
      break;
    }
@@ -124,6 +124,7 @@ struct GesvdjDescriptor {
   int batch, m, n;
   int lwork;
   cusolverEigMode_t jobz;
+  int econ;
 };
 
 void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
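A quick end-to-end check one might run after this change (illustrative, not part of the commit; requires a CUDA-enabled jaxlib build): the reduced factors should come back from jnp.linalg.svd with economy shapes and still reconstruct the input.

import numpy as np
import jax.numpy as jnp

a = jnp.asarray(np.random.randn(5, 3).astype(np.float32))
u, s, vt = jnp.linalg.svd(a, full_matrices=False)
# Economy shapes: U is (m, k), s is (k,), Vt is (k, n) with k = min(m, n).
assert u.shape == (5, 3) and s.shape == (3,) and vt.shape == (3, 3)
# The reduced factorization still reproduces A up to float32 round-off.
assert np.allclose(u @ jnp.diag(s) @ vt, a, atol=1e-4)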