mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[ROCm] Upgrade to ROCm 5.3 and associated enhancements
This commit is contained in:
parent
849f837b6a
commit
b815ac9d8e
@ -1,10 +1,10 @@
|
||||
FROM ubuntu:focal
|
||||
MAINTAINER Reza Rahimi <reza.rahimi@amd.com>
|
||||
|
||||
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.2/
|
||||
ARG ROCM_DEB_REPO=http://repo.radeon.com/rocm/apt/5.3/
|
||||
ARG ROCM_BUILD_NAME=ubuntu
|
||||
ARG ROCM_BUILD_NUM=main
|
||||
ARG ROCM_PATH=/opt/rocm-5.2.0
|
||||
ARG ROCM_PATH=/opt/rocm-5.3.0
|
||||
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ENV HOME /root/
|
||||
@ -80,7 +80,8 @@ RUN add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt update && \
|
||||
apt install -y python3.9-dev \
|
||||
python3-pip \
|
||||
python3.9-distutils
|
||||
python3.9-distutils \
|
||||
python-is-python3
|
||||
|
||||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
|
||||
|
||||
|
@ -106,6 +106,7 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['HSA_TOOLS_LIB'] = "libroctracer64.so"
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-p",
|
||||
"--parallel",
|
||||
|
@ -376,7 +376,7 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
return out[:3]
|
||||
|
||||
cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
|
||||
rocm_syevd = partial(_syevd_mhlo, "hip", _hipsolver, False)
|
||||
rocm_syevd = partial(_syevd_mhlo, "hip", _hipsolver, True)
|
||||
|
||||
|
||||
def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
|
@ -293,6 +293,74 @@ std::pair<int, py::bytes> BuildSyevdDescriptor(const py::dtype& dtype,
|
||||
return {lwork, PackDescriptor(SyevdDescriptor{type, uplo, b, n, lwork})};
|
||||
}
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
// Returns the workspace size and a descriptor for a syevj_batched operation.
|
||||
// Queries hipSOLVER for the workspace size of a Jacobi symmetric/Hermitian
// eigendecomposition (syevj/heevj) and packs the call parameters into an
// opaque descriptor.
//
// Args:
//   dtype: NumPy dtype of the matrices; mapped via DtypeToHipsolverType.
//   lower: selects HIPSOLVER_FILL_MODE_LOWER when true, otherwise
//          HIPSOLVER_FILL_MODE_UPPER.
//   batch: number of matrices; batch == 1 uses the non-batched size query,
//          any other value uses the *Batched size query.
//   n:     matrix dimension; also passed as the leading dimension (lda).
//
// Returns:
//   {lwork, packed SyevjDescriptor{type, uplo, batch, n, lwork}}.
//
// Errors from borrowing a solver handle or from any hipSOLVER call are
// surfaced through JAX_THROW_IF_ERROR.
std::pair<int, py::bytes> BuildSyevjDescriptor(const py::dtype& dtype,
                                               bool lower, int batch, int n) {
  HipsolverType type = DtypeToHipsolverType(dtype);
  auto h = SolverHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  // Initialized so the value is well-defined even if no switch case below
  // matches `type`.
  int lwork = 0;
  hipsolverSyevjInfo_t params;
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
  // Destroys the syevj info object on every exit path, including throws.
  std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
      params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
  hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
  hipsolverFillMode_t uplo =
      lower ? HIPSOLVER_FILL_MODE_LOWER : HIPSOLVER_FILL_MODE_UPPER;
  if (batch == 1) {
    switch (type) {
      case HipsolverType::F32:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params)));
        break;
      case HipsolverType::F64:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params)));
        break;
      case HipsolverType::C64:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params)));
        break;
      case HipsolverType::C128:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params)));
        break;
    }
  } else {
    switch (type) {
      case HipsolverType::F32:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch)));
        break;
      case HipsolverType::F64:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch)));
        break;
      case HipsolverType::C64:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch)));
        break;
      case HipsolverType::C128:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(hipsolverZheevjBatched_bufferSize(
            handle.get(), jobz, uplo, n, /*A=*/nullptr, /*lda=*/n,
            /*W=*/nullptr, &lwork, params, batch)));
        break;
    }
  }
  return {lwork, PackDescriptor(SyevjDescriptor{type, uplo, batch, n, lwork})};
}
|
||||
|
||||
// Singular value decomposition using QR algorithm: gesvd
|
||||
|
||||
// Returns the workspace size and a descriptor for a gesvd operation.
|
||||
@ -343,8 +411,7 @@ py::dict Registrations() {
|
||||
dict["hipsolver_geqrf"] = EncapsulateFunction(Geqrf);
|
||||
dict["hipsolver_orgqr"] = EncapsulateFunction(Orgqr);
|
||||
dict["hipsolver_syevd"] = EncapsulateFunction(Syevd);
|
||||
// dict["cusolver_syevj"] = EncapsulateFunction(Syevj); not supported by
|
||||
// ROCm yet
|
||||
dict["hipsolver_syevj"] = EncapsulateFunction(Syevj);
|
||||
dict["hipsolver_gesvd"] = EncapsulateFunction(Gesvd);
|
||||
// dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); not supported by
|
||||
// ROCm yet
|
||||
@ -358,8 +425,7 @@ PYBIND11_MODULE(_hipsolver, m) {
|
||||
m.def("build_geqrf_descriptor", &BuildGeqrfDescriptor);
|
||||
m.def("build_orgqr_descriptor", &BuildOrgqrDescriptor);
|
||||
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
|
||||
// m.def("build_syevj_descriptor", &BuildSyevjDescriptor); not supported by
|
||||
// ROCm yet
|
||||
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
|
||||
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
|
||||
// m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor); not supported by
|
||||
// ROCm yet
|
||||
|
@ -145,7 +145,6 @@ static absl::Status Potrf_(hipStream_t stream, void** buffers,
|
||||
break;
|
||||
}
|
||||
case HipsolverType::C128: {
|
||||
hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZpotrfBatched(
|
||||
handle.get(), d.uplo, d.n, static_cast<hipDoubleComplex**>(workspace), d.n,
|
||||
reinterpret_cast<hipDoubleComplex*>(static_cast<hipDoubleComplex**>(workspace) +
|
||||
@ -268,9 +267,6 @@ static absl::Status Geqrf_(hipStream_t stream, void** buffers,
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
// TODO(rocm): workaround for unset devinfo. See SWDEV-317485
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(hipMemsetAsync(info, 0, sizeof(int) * d.batch, stream)));
|
||||
|
||||
void* workspace = buffers[4];
|
||||
switch (d.type) {
|
||||
@ -358,9 +354,6 @@ static absl::Status Orgqr_(hipStream_t stream, void** buffers,
|
||||
}
|
||||
|
||||
int* info = static_cast<int*>(buffers[3]);
|
||||
// TODO(rocm): workaround for unset devinfo. See SWDEV-317485
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(hipMemsetAsync(info, 0, sizeof(int) * d.batch, stream)));
|
||||
|
||||
void* workspace = buffers[4];
|
||||
switch (d.type) {
|
||||
@ -513,7 +506,115 @@ void Syevd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(rocm): add Syevj_ apis when support from hipsolver is ready
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
// Runs a Jacobi symmetric/Hermitian eigendecomposition (syevj/heevj) on
// `stream`, as described by the packed SyevjDescriptor in `opaque`.
//
// Buffer layout, as used by this function:
//   buffers[0]: input matrices
//   buffers[1]: matrix buffer handed to hipSOLVER as `A`; the input is copied
//               here first when it is a different buffer
//   buffers[2]: `W` argument of the solver (real-valued, also for complex A)
//   buffers[3]: `info` argument of the solver
//   buffers[4]: workspace of d.lwork elements
//
// Returns OkStatus on success, or the first failing status from unpacking the
// descriptor, borrowing a solver handle, or any HIP/hipSOLVER call.
absl::Status Syevj_(hipStream_t stream, void** buffers, const char* opaque,
                    size_t opaque_len) {
  auto s = UnpackDescriptor<SyevjDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const SyevjDescriptor& d = **s;
  auto h = SolverHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;
  // The solver is given buffers[1] as A, so seed it with the input unless the
  // two buffers are already the same.
  if (buffers[1] != buffers[0]) {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipMemcpyAsync(
        buffers[1], buffers[0],
        SizeOfHipsolverType(d.type) * static_cast<std::int64_t>(d.batch) *
            static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
        hipMemcpyDeviceToDevice, stream)));
  }
  hipsolverSyevjInfo_t params;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCreateSyevjInfo(&params)));
  // Destroys the syevj info object on every return path.
  std::unique_ptr<void, void (*)(hipsolverSyevjInfo_t)> params_cleanup(
      params, [](hipsolverSyevjInfo_t p) { hipsolverDestroySyevjInfo(p); });
  hipsolverEigMode_t jobz = HIPSOLVER_EIG_MODE_VECTOR;
  int* info = static_cast<int*>(buffers[3]);
  void* work = buffers[4];
  if (d.batch == 1) {
    switch (d.type) {
      case HipsolverType::F32: {
        float* a = static_cast<float*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<float*>(work), d.lwork, info, params)));
        break;
      }
      case HipsolverType::F64: {
        double* a = static_cast<double*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<double*>(work), d.lwork, info, params)));
        break;
      }
      case HipsolverType::C64: {
        hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<hipFloatComplex*>(work), d.lwork, info, params)));
        break;
      }
      case HipsolverType::C128: {
        hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverZheevj(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<hipDoubleComplex*>(work), d.lwork, info, params)));
        break;
      }
    }
  } else {
    switch (d.type) {
      case HipsolverType::F32: {
        float* a = static_cast<float*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverSsyevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<float*>(work), d.lwork, info, params, d.batch)));
        break;
      }
      case HipsolverType::F64: {
        double* a = static_cast<double*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverDsyevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<double*>(work), d.lwork, info, params, d.batch)));
        break;
      }
      case HipsolverType::C64: {
        hipFloatComplex* a = static_cast<hipFloatComplex*>(buffers[1]);
        float* w = static_cast<float*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(hipsolverCheevjBatched(
            handle.get(), jobz, d.uplo, d.n, a, d.n, w,
            static_cast<hipFloatComplex*>(work), d.lwork, info, params,
            d.batch)));
        break;
      }
      case HipsolverType::C128: {
        hipDoubleComplex* a = static_cast<hipDoubleComplex*>(buffers[1]);
        double* w = static_cast<double*>(buffers[2]);
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
            hipsolverZheevjBatched(handle.get(), jobz, d.uplo, d.n, a, d.n, w,
                                   static_cast<hipDoubleComplex*>(work),
                                   d.lwork, info, params, d.batch)));
        break;
      }
    }
  }
  return absl::OkStatus();
}
|
||||
|
||||
// XLA custom-call entry point for the Jacobi eigendecomposition. Delegates to
// Syevj_ and, if that returns a non-OK status, reports its message through
// `status`; on success `status` is left untouched.
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status) {
  const absl::Status result = Syevj_(stream, buffers, opaque, opaque_len);
  if (result.ok()) {
    return;
  }
  // Materialize the message once so the pointer and length refer to the same
  // NUL-terminated string.
  const std::string message(result.message());
  XlaCustomCallStatusSetFailure(status, message.c_str(), message.length());
}
|
||||
|
||||
// Singular value decomposition using QR algorithm: gesvd
|
||||
|
||||
static absl::Status Gesvd_(hipStream_t stream, void** buffers,
|
||||
|
@ -92,6 +92,19 @@ struct SyevdDescriptor {
|
||||
void Syevd(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
|
||||
// Supports batches of matrices up to size 32.
|
||||
|
||||
struct SyevjDescriptor {
|
||||
HipsolverType type;
|
||||
hipsolverFillMode_t uplo;
|
||||
int batch, n;
|
||||
int lwork;
|
||||
};
|
||||
|
||||
void Syevj(hipStream_t stream, void** buffers, const char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status);
|
||||
|
||||
// Singular value decomposition using QR algorithm: gesvd
|
||||
|
||||
struct GesvdDescriptor {
|
||||
|
@ -384,7 +384,6 @@ class F32LobpcgTest(LobpcgTest):
|
||||
self.checkLobpcgConsistency(matrix_name, n, k, m, tol, jnp.float32)
|
||||
|
||||
@parameterized.named_parameters(_make_concrete_cases(f64=False))
|
||||
@jtu.skip_on_devices("rocm") # see SWDEV-321073
|
||||
def testLobpcgMonotonicityF32(self, matrix_name, n, k, m, tol):
|
||||
self.checkLobpcgMonotonicity(matrix_name, n, k, m, tol, jnp.float32)
|
||||
|
||||
|
@ -453,7 +453,6 @@ class cuSparseTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
|
||||
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
|
||||
@jtu.skip_on_devices("rocm") # TODO(rocm): see SWDEV-328107
|
||||
def test_coo_sorted_indices_gpu_lowerings(self):
|
||||
dtype = jnp.float32
|
||||
|
||||
@ -1143,7 +1142,6 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
[(5, 3), (5, 2), [0], [0]],
|
||||
]
|
||||
for dtype in jtu.dtypes.floating + jtu.dtypes.complex))
|
||||
@jtu.skip_on_devices("rocm")
|
||||
def test_bcoo_dot_general_cusparse(
|
||||
self, lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting):
|
||||
rng = jtu.rand_small(self.rng())
|
||||
@ -1275,7 +1273,6 @@ class BCOOTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(not GPU_LOWERING_ENABLED, "test requires cusparse/hipsparse")
|
||||
@unittest.skipIf(jtu.device_under_test() != "gpu", "test requires GPU")
|
||||
@jtu.skip_on_devices("rocm") # TODO(rocm): see SWDEV-328107
|
||||
def test_bcoo_dot_general_oob_and_unsorted_indices_cusparse(self):
|
||||
"""Tests bcoo dot general with out-of-bound and unsorted indices."""
|
||||
|
||||
|
@ -208,6 +208,7 @@ class SvdTest(jtu.JaxTestCase):
|
||||
for m, n, r, c in zip([2, 4, 8], [4, 4, 6], [1, 0, 1], [1, 0, 1])
|
||||
for dtype in jtu.dtypes.floating
|
||||
])
|
||||
@jtu.skip_on_devices("rocm")
|
||||
def testSvdOnTinyElement(self, m, n, r, c, dtype):
|
||||
"""Tests SVD on matrix of zeros and close-to-zero entries."""
|
||||
a = jnp.zeros((m, n), dtype=dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user