Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
[JAX:GPU] Generalize the gesvdj kernel to iterate over the unbatched Jacobi kernel in cases where we cannot use the batched kernel.

If gesvdj() is preferable to gesvd() in the absence of a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() calls over a loop of gesvd() calls.

PiperOrigin-RevId: 582279549
This commit is contained in:
parent ef9075159a
commit 95e2d3fc2b
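In practical terms, the new dispatch is: use cuSOLVER's batched Jacobi solver (gesvdjBatched) only when there really is a batch (batch > 1), the matrices are at most 32x32, and no economy-size ("econ") factorization is requested, since the batched kernel supports neither larger matrices nor econ mode; in every other case, loop the unbatched gesvdj over the batch. A minimal sketch of that logic for the float case follows. The function SvdJacobiDispatch and its argument list are invented for illustration (the real kernels also cover F64/C64/C128 and wrap errors with JAX_RETURN_IF_ERROR), and the sketch assumes work/lwork were sized with the matching cusolverDn*_bufferSize query, as in the descriptor hunk below.

// A minimal sketch (float case only) of the dispatch this commit introduces.
// SvdJacobiDispatch and its signature are hypothetical, not jaxlib's API.
#include <algorithm>

#include <cusolverDn.h>

cusolverStatus_t SvdJacobiDispatch(cusolverDnHandle_t handle,
                                   cusolverEigMode_t jobz, int batch, int econ,
                                   int m, int n, float* a, float* s, float* u,
                                   float* v, float* work, int lwork, int* info,
                                   gesvdjInfo_t params) {
  // The batched Jacobi kernel only supports m, n <= 32 and has no "econ"
  // mode, so every other case falls back to a loop over the unbatched kernel.
  if (batch <= 1 || m > 32 || n > 32 || econ) {
    int k = std::min(m, n);
    for (int i = 0; i < batch; ++i) {
      cusolverStatus_t status =
          cusolverDnSgesvdj(handle, jobz, econ, m, n, a, /*lda=*/m, s, u,
                            /*ldu=*/m, v, /*ldv=*/n, work, lwork, info, params);
      if (status != CUSOLVER_STATUS_SUCCESS) return status;
      a += m * n;               // advance to the next input matrix
      s += k;                   // next block of singular values
      u += m * (econ ? k : m);  // next U block (economy-sized if econ)
      v += (econ ? k : n) * n;  // next V block
      ++info;                   // one status entry per matrix
    }
    return CUSOLVER_STATUS_SUCCESS;
  }
  // Small, genuinely batched, non-econ case: a single batched call.
  return cusolverDnSgesvdjBatched(handle, jobz, m, n, a, /*lda=*/m, s, u,
                                  /*ldu=*/m, v, /*ldv=*/n, work, lwork, info,
                                  params, batch);
}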
@@ -355,7 +355,7 @@ std::pair<int, nb::bytes> BuildGesvdjDescriptor(const dtype& dtype, int batch,
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(&params)));
   std::unique_ptr<gesvdjInfo, void (*)(gesvdjInfo*)> params_cleanup(
       params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); });
-  if (batch == 1) {
+  if (batch <= 1 || m > 32 || n > 32 || econ) {
     switch (type) {
       case SolverType::F32:
         JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize(
@@ -463,15 +463,14 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers,
     // the batch is passed as a second operand
     gpuMemcpyAsync((void*)&batch,
                    reinterpret_cast<const std::int64_t*>(buffers[1]),
-                   sizeof(batch), gpuMemcpyDeviceToHost,
-                   stream);
+                   sizeof(batch), gpuMemcpyDeviceToHost, stream);
     JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
     output_idx = 2;
   }
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
       buffers[output_idx], buffers[0],
-      SizeOfSolverType(d.type) * batch *
-          static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
+      SizeOfSolverType(d.type) * batch * static_cast<std::int64_t>(d.n) *
+          static_cast<std::int64_t>(d.n),
       gpuMemcpyDeviceToDevice, stream)));
   gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
   int* info = static_cast<int*>(buffers[output_idx + 2]);
@@ -662,11 +661,13 @@ static absl::Status Gesvd_(gpuStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
-      buffers[1], buffers[0],
-      SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-      gpuMemcpyDeviceToDevice, stream)));
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        gpuMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
   int64_t k = d.jobu == 'A' ? d.m : d.n;
@@ -767,27 +768,37 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
-      buffers[1], buffers[0],
-      SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-      gpuMemcpyDeviceToDevice, stream)));
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        gpuMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
   gesvdjInfo_t params;
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(&params)));
   std::unique_ptr<gesvdjInfo, void (*)(gesvdjInfo*)> params_cleanup(
       params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); });
-  if (d.batch == 1) {
+  if (d.batch <= 1 || d.m > 32 || d.n > 32 || d.econ) {
+    int k = std::min(d.m, d.n);
     switch (d.type) {
       case SolverType::F32: {
         float* a = static_cast<float*>(buffers[1]);
         float* s = static_cast<float*>(buffers[2]);
         float* u = static_cast<float*>(buffers[3]);
         float* v = static_cast<float*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<float*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<float*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
       case SolverType::F64: {
@@ -795,9 +806,16 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
         double* s = static_cast<double*>(buffers[2]);
         double* u = static_cast<double*>(buffers[3]);
         double* v = static_cast<double*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<double*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<double*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
       case SolverType::C64: {
@@ -805,9 +823,16 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
         float* s = static_cast<float*>(buffers[2]);
         gpuComplex* u = static_cast<gpuComplex*>(buffers[3]);
         gpuComplex* v = static_cast<gpuComplex*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<gpuComplex*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<gpuComplex*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
       case SolverType::C128: {
@@ -815,9 +840,16 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
         double* s = static_cast<double*>(buffers[2]);
         gpuDoubleComplex* u = static_cast<gpuDoubleComplex*>(buffers[3]);
         gpuDoubleComplex* v = static_cast<gpuDoubleComplex*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<gpuDoubleComplex*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<gpuDoubleComplex*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
     }
@@ -22,15 +22,15 @@ limitations under the License.
 
 #if defined(JAX_GPU_CUDA)
 
-#include "third_party/gpus/cuda/include/cuComplex.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "third_party/gpus/cuda/include/cufft.h"
-#include "third_party/gpus/cuda/include/cusolverDn.h"
-#include "third_party/gpus/cuda/include/cusparse.h"
-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
-#include "third_party/gpus/cudnn/cudnn.h"
+#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cuComplex.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cublas_v2.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cuda.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cufft.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cusolverDn.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cusparse.h"  // IWYU pragma: export
+#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: export
 
 // Some sparse functionality is only available in CUSPARSE 11.3 or newer.
 #define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
@@ -74,7 +74,7 @@ typedef CUevent gpuEvent_t;
 typedef CUfunction gpuFunction_t;
 typedef cudnnHandle_t gpudnnHandle_t;
 typedef cudnnStatus_t gpudnnStatus_t;
 typedef CUmodule gpuModule_t;
 typedef cusolverDnHandle_t gpusolverDnHandle_t;
 typedef cusolverStatus_t gpusolverStatus_t;
 typedef cusolverEigMode_t gpusolverEigMode_t;
@@ -266,19 +266,24 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpuInit cuInit
 #define gpuLaunchKernel cuLaunchKernel
 #define gpuMemcpyDtoHAsync cuMemcpyDtoHAsync
 #define gpuMemcpyHtoDAsync cuMemcpyHtoDAsync
 #define gpuMemsetD8Async cuMemsetD8Async
 #define gpuModuleLoadData cuModuleLoadData
 #define gpuModuleGetFunction cuModuleGetFunction
 #define gpuModuleUnload cuModuleUnload
 #define gpuStreamGetCtx cuStreamGetCtx
 
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR \
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN \
+  CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR \
+  CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR
 #define GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
-#define GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
+#define GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES \
+  CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
 #define GPU_EVENT_DEFAULT CU_EVENT_DEFAULT
 
 #define gpuGetLastError cudaGetLastError
@@ -293,9 +298,9 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 
 namespace jax::JAX_GPU_NAMESPACE {
 namespace {
 constexpr uint32_t kNumThreadsPerWarp = 32;
-}
 }
+}  // namespace jax::JAX_GPU_NAMESPACE
 
 #elif defined(JAX_GPU_HIP)
 
@@ -483,7 +488,6 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
 #define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
 
-
 #define gpuGetLastError hipGetLastError
 #define gpuGetErrorString hipGetErrorString
 #define gpuMemcpyAsync hipMemcpyAsync
@@ -514,21 +518,26 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpuModuleUnload hipModuleUnload
 #define gpuMemsetD8Async hipMemsetD8Async
 #define gpuMemcpyDtoHAsync hipMemcpyDtoHAsync
 #define gpuMemcpyHtoDAsync hipMemcpyHtoDAsync
 #define gpuMemsetD8Async hipMemsetD8Async
 
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR hipDeviceAttributeComputeCapabilityMajor
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR hipDeviceAttributeComputeCapabilityMinor
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN hipDeviceAttributeMaxSharedMemoryPerBlock
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR hipDeviceAttributeMaxBlocksPerMultiProcessor
-#define GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
+  hipDeviceAttributeComputeCapabilityMajor
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR \
+  hipDeviceAttributeComputeCapabilityMinor
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN \
+  hipDeviceAttributeMaxSharedMemoryPerBlock
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR \
+  hipDeviceAttributeMaxBlocksPerMultiProcessor
+#define GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES \
+  HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
 #define GPU_EVENT_DEFAULT hipEventDefault
 
 namespace jax::JAX_GPU_NAMESPACE {
 namespace {
 constexpr uint32_t kNumThreadsPerWarp = 64;
-}
 }
+}  // namespace jax::JAX_GPU_NAMESPACE
 
 #else  // defined(GPU vendor)
 #error "Either JAX_GPU_CUDA or JAX_GPU_HIP must be defined"
@@ -374,11 +374,10 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
   # outperform gesvd for small-moderate matrices, e.g., see:
   # https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
   # slide 5.
-  if have_jacobi_solver and (
-      (b == 1 and m <= 1024 and n <= 1024) or (m <= 32 and n <= 32)
-  ):
-    # The batched kernel doesn't support "econ" mode.
-    econ = not full_matrices and b == 1
+  if have_jacobi_solver and m <= 1024 and n <= 1024:
+    # The gesvdjbatched kernel doesn't support "econ" mode. We will use that
+    # kernel only if b > 1 and m <= 32 and n <= 32.
+    econ = not full_matrices and (b <= 1 or m > 32 or n > 32)
     lwork, opaque = gpu_solver.build_gesvdj_descriptor(
         np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
     k = min(m, n)
@@ -626,7 +626,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
 
     tol = 80 * jnp.finfo(dtype).eps
     reconstruction_tol = 2 * tol
-    unitariness_tol = tol
+    unitariness_tol = 3 * tol
 
     a, = args_maker()
     if hermitian: