[ROCm] Upgrade to ROCm 5.3 and associated enhancements

This commit is contained in:
Rohit Santhanam 2022-10-01 03:39:55 +00:00
parent 849f837b6a
commit b815ac9d8e
9 changed files with 199 additions and 20 deletions

View File

@ -1,10 +1,10 @@
FROM ubuntu:focal
MAINTAINER Reza Rahimi <reza.rahimi@amd.com>
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.2/
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.3/
ARG ROCM_BUILD_NAME=ubuntu
ARG ROCM_BUILD_NUM=main
ARG ROCM_PATH=/opt/rocm-5.2.0
ARG ROCM_PATH=/opt/rocm-5.3.0
ARG DEBIAN_FRONTEND=noninteractive
ENV HOME /root/
@ -80,7 +80,8 @@ RUN add-apt-repository ppa:deadsnakes/ppa && \
apt update && \
apt install -y python3.9-dev \
python3-pip \
python3.9-distutils
python3.9-distutils \
python-is-python3
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1

View File

@ -106,6 +106,7 @@ def main(args):
if __name__ == '__main__':
os.environ['HSA_TOOLS_LIB'] = "libroctracer64.so"
parser = argparse.ArgumentParser()
parser.add_argument("-p",
"--parallel",

View File

@ -376,7 +376,7 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
return out[:3]
cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
rocm_syevd = partial(_syevd_mhlo, "hip", _hipsolver, False)
rocm_syevd = partial(_syevd_mhlo, "hip", _hipsolver, True)
def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,

View File

@ -293,6 +293,74 @@ std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
}
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.
// Returns the workspace size and a descriptor for a syevj_batched operation.
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
bool lower, int batch, int n) {
HipsolverType type = DtypeToHipsolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
hipsolverSyevjInfo_t params;
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
hipsolverFillMode_t uplo =
lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
if (batch == 1) {
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params)));
break;
}
} else {
switch (type) {
case HipsolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case HipsolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case HipsolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
case HipsolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevjBatched_bufferSize(
handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
/*W=*/nullptr, &lwork, params, batch)));
break;
}
}
return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})};
}
// Singular value decomposition using QR algorithm: gesvd
// Returns the workspace size and a descriptor for a gesvd operation.
@ -343,8 +411,7 @@ py::dict Registrations() {
dict["hipsolver_geqrf"] = EncapsulateFunction(Geqrf);
dict["hipsolver_orgqr"] = EncapsulateFunction(Orgqr);
dict["hipsolver_syevd"] = EncapsulateFunction(Syevd);
// dict["cusolver_syevj"] = EncapsulateFunction(Syevj); not supported by
// ROCm yet
dict["hipsolver_syevj"] = EncapsulateFunction(Syevj);
dict["hipsolver_gesvd"] = EncapsulateFunction(Gesvd);
// dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); not supported by
// ROCm yet
@ -358,8 +425,7 @@ PYBIND11_MODULE(_hipsolver, m) {
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
// m.def("build_syevj_descriptor", &BuildSyevjDescriptor); not supported by
// ROCm yet
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
// m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); not supported by
// ROCm yet

View File

@ -145,7 +145,6 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers,
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
handle.get(), d.uplo, d.n, static_cast<hipDoubleComplex**>(workspace), d.n,
reinterpret_cast<hipDoubleComplex*>(static_cast<hipDoubleComplex**>(workspace) +
@ -268,9 +267,6 @@ static absl::Status Geqrf_(hipStream_t stream, void** buffers,
}
int* info = static_cast<int*>(buffers[3]);
// TODO(rocm): workaround for unset devinfo. See SWDEV-317485
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipMemsetAsync(info, 0, sizeof(int) * d.batch, stream)));
void* workspace = buffers[4];
switch (d.type) {
@ -358,9 +354,6 @@ static absl::Status Orgqr_(hipStream_t stream, void** buffers,
}
int* info = static_cast<int*>(buffers[3]);
// TODO(rocm): workaround for unset devinfo. See SWDEV-317485
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(hipMemsetAsync(info, 0, sizeof(int) * d.batch, stream)));
void* workspace = buffers[4];
switch (d.type) {
@ -513,7 +506,115 @@ void Syevd(hipStream_t stream, void** buffers, const char* opaque,
}
}
// TODO(rocm): add Syevj_ apis when support from hipsolver is ready
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.
absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len) {
auto s = UnpackDescriptor<SyevjDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SyevjDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
buffers[1], buffers[0],
SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
hipMemcpyDeviceToDevice, stream)));
}
hipsolverSyevjInfo_t params;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
int* info = static_cast<int*>(buffers[3]);
void* work = buffers[4];
if (d.batch == 1) {
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<float*>(work), d.lwork, info, params)));
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<double*>(work), d.lwork, info, params)));
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipFloatComplex*>(work), d.lwork, info, params)));
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipDoubleComplex*>(work), d.lwork, info, params)));
break;
}
}
} else {
switch (d.type) {
case HipsolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<float*>(work), d.lwork, info, params, d.batch)));
break;
}
case HipsolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<double*>(work), d.lwork, info, params, d.batch)));
break;
}
case HipsolverType::C64: {
hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
float* w = static_cast<float*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched(
handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipFloatComplex*>(work), d.lwork, info, params, d.batch)));
break;
}
case HipsolverType::C128: {
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
double* w = static_cast<double*>(buffers[2]);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
hipsolverZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
static_cast<hipDoubleComplex*>(work),
d.lwork, info, params, d.batch)));
break;
}
}
}
return absl::OkStatus();
}
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Syevj_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
// Singular value decomposition using QR algorithm: gesvd
static absl::Status Gesvd_(hipStream_t stream, void** buffers,

View File

@ -92,6 +92,19 @@ struct SyevdDescriptor {
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.
struct SyevjDescriptor {
HipsolverType type;
hipsolverFillMode_t uplo;
int batch, n;
int lwork;
};
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// Singular value decomposition using QR algorithm: gesvd
struct GesvdDescriptor {

View File

@ -384,7 +384,6 @@ class F32LobpcgTest(LobpcgTest):
self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32)
@parameterized.named_parameters(_make_concrete_cases(f64=False))
@jtu.skip_on_devices("rocm") # see SWDEV-321073
def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol):
self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32)

View File

@ -453,7 +453,6 @@ class cuSparseTest(jtu.JaxTestCase):
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
@jtu.skip_on_devices("rocm") # TODO(rocm): see SWDEV-328107
def test_coo_sorted_indices_gpu_lowerings(self):
dtype = jnp.float32
@ -1143,7 +1142,6 @@ class BCOOTest(jtu.JaxTestCase):
[(5, 3), (5, 2), [0], [0]],
]
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
@jtu.skip_on_devices("rocm")
def test_bcoo_dot_general_cusparse(
self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting):
rng = jtu.rand_small(self.rng())
@ -1275,7 +1273,6 @@ class BCOOTest(jtu.JaxTestCase):
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
@jtu.skip_on_devices("rocm") # TODO(rocm): see SWDEV-328107
def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
"""Tests bcoo dot general with out-of-bound and unsorted indices."""

View File

@ -208,6 +208,7 @@ class SvdTest(jtu.JaxTestCase):
for m, n, r, c in zip([2, 4, 8], [4, 4, 6], [1, 0, 1], [1, 0, 1])
for dtype in jtu.dtypes.floating
])
@jtu.skip_on_devices("rocm")
def testSvdOnTinyElement(self, m, n, r, c, dtype):
"""Tests SVD on matrix of zeros and close-to-zero entries."""
a = jnp.zeros((m, n), dtype=dtype)