[JAX:GPU] Generalize the gesvdj kernel to loop over the unbatched Jacobi kernel in cases where we cannot use the batched kernel.

If gesvdj() is preferable to gesvd() in the absence of a batch dimension, then even when there is a batch dimension we should prefer a loop of gesvdj() calls to a loop of gesvd() calls; see the dispatch-policy sketch below.

PiperOrigin-RevId: 582279549
Author: Peter Hawkins, 2023-11-14 04:51:45 -08:00 (committed by jax authors)
parent ef9075159a
commit 95e2d3fc2b
5 changed files with 101 additions and 61 deletions
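
In summary, the change results in the following kernel-selection policy for SVD on GPU. The sketch below is illustrative only (the enum and helper are hypothetical, not jaxlib's actual API); it assumes the cuSOLVER entry points named in the diffs: gesvdjBatched (batched Jacobi, limited to m, n <= 32 with no economy mode), gesvdj (single-matrix Jacobi), and gesvd (QR-based).

// Illustrative sketch of the dispatch policy after this change; the helper
// and enum are hypothetical, not part of jaxlib.
enum class SvdKernel { kGesvdjBatched, kGesvdjLoop, kGesvdLoop };

SvdKernel ChooseSvdKernel(int batch, int m, int n, bool econ) {
  // Jacobi SVD only wins for small-to-moderate matrices (see the slide
  // referenced in _gesvd_hlo below).
  if (m > 1024 || n > 1024) return SvdKernel::kGesvdLoop;
  // The batched Jacobi kernel only handles m, n <= 32 and no "econ" mode.
  if (batch > 1 && m <= 32 && n <= 32 && !econ) {
    return SvdKernel::kGesvdjBatched;
  }
  // New in this change: otherwise iterate the single-matrix Jacobi kernel
  // over the batch instead of falling back to a loop of gesvd.
  return SvdKernel::kGesvdjLoop;
}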


@@ -355,7 +355,7 @@ std::pair<int, nb::bytes> BuildGesvdjDescriptor(const dtype& dtype, int batch,
   JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(&params)));
   std::unique_ptr<gesvdjInfo, void (*)(gesvdjInfo*)> params_cleanup(
       params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); });
-  if (batch == 1) {
+  if (batch <= 1 || m > 32 || n > 32 || econ) {
    switch (type) {
      case SolverType::F32:
        JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize(
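
The widened condition makes the descriptor builder compute a single-matrix workspace bound whenever the batched kernel is unusable (batch of one, matrices larger than 32x32, or economy mode, which the batched kernel does not support). Below is a hedged sketch of the two cuSOLVER size queries being selected between, for F32; per cuSOLVER's documentation the matrix arguments are not dereferenced by the *_bufferSize queries, so null pointers are passed here. Variable names mirror the surrounding code but the snippet is illustrative, not the actual jaxlib source.

// Sketch only: the two workspace queries the branch above chooses between.
int lwork = 0;
if (batch <= 1 || m > 32 || n > 32 || econ) {
  // Single-matrix query; this is the only variant with an econ parameter.
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj_bufferSize(
      handle, jobz, econ, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
      /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params)));
} else {
  // Batched query: takes a batch size, has no econ parameter, and is only
  // valid for m, n <= 32.
  JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdjBatched_bufferSize(
      handle, jobz, m, n, /*A=*/nullptr, /*lda=*/m, /*S=*/nullptr,
      /*U=*/nullptr, /*ldu=*/m, /*V=*/nullptr, /*ldv=*/n, &lwork, params,
      batch)));
}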


@@ -463,15 +463,14 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers,
     // the batch is passed as a second operand
     gpuMemcpyAsync((void*)&batch,
                    reinterpret_cast<const std::int64_t*>(buffers[1]),
-                   sizeof(batch), gpuMemcpyDeviceToHost,
-                   stream);
+                   sizeof(batch), gpuMemcpyDeviceToHost, stream);
     JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
     output_idx = 2;
   }
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
       buffers[output_idx], buffers[0],
-      SizeOfSolverType(d.type) * batch *
-          static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.n),
+      SizeOfSolverType(d.type) * batch * static_cast<std::int64_t>(d.n) *
+          static_cast<std::int64_t>(d.n),
       gpuMemcpyDeviceToDevice, stream)));
   gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR;
   int* info = static_cast<int*>(buffers[output_idx + 2]);
@@ -662,11 +661,13 @@ static absl::Status Gesvd_(gpuStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
-      buffers[1], buffers[0],
-      SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-      gpuMemcpyDeviceToDevice, stream)));
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        gpuMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
   int64_t k = d.jobu == 'A' ? d.m : d.n;
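
The new guard skips the operand copy when the input and output already share a device buffer, presumably because XLA can donate (alias) the operand to the output; a self-copy would be redundant work at best. A minimal sketch of the idea, with illustrative names (not jaxlib's API):

// Illustrative helper: copy the operand into the in/out buffer only when
// the two pointers refer to distinct device allocations.
absl::Status CopyOperandIfNeeded(void* dst, const void* src,
                                 std::int64_t nbytes, gpuStream_t stream) {
  if (dst != src) {
    return JAX_AS_STATUS(
        gpuMemcpyAsync(dst, src, nbytes, gpuMemcpyDeviceToDevice, stream));
  }
  return absl::OkStatus();
}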
@@ -767,27 +768,37 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
   auto h = SolverHandlePool::Borrow(stream);
   JAX_RETURN_IF_ERROR(h.status());
   auto& handle = *h;
-  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
-      buffers[1], buffers[0],
-      SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
-          static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
-      gpuMemcpyDeviceToDevice, stream)));
+  if (buffers[1] != buffers[0]) {
+    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
+        buffers[1], buffers[0],
+        SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
+            static_cast<std::int64_t>(d.m) * static_cast<std::int64_t>(d.n),
+        gpuMemcpyDeviceToDevice, stream)));
+  }
   int* info = static_cast<int*>(buffers[5]);
   void* work = buffers[6];
   gesvdjInfo_t params;
   JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCreateGesvdjInfo(&params)));
   std::unique_ptr<gesvdjInfo, void (*)(gesvdjInfo*)> params_cleanup(
       params, [](gesvdjInfo* p) { cusolverDnDestroyGesvdjInfo(p); });
-  if (d.batch == 1) {
+  if (d.batch <= 1 || d.m > 32 || d.n > 32 || d.econ) {
+    int k = std::min(d.m, d.n);
     switch (d.type) {
       case SolverType::F32: {
         float* a = static_cast<float*>(buffers[1]);
         float* s = static_cast<float*>(buffers[2]);
         float* u = static_cast<float*>(buffers[3]);
         float* v = static_cast<float*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<float*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnSgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<float*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
       case SolverType::F64: {
@@ -795,9 +806,16 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
         double* a = static_cast<double*>(buffers[1]);
         double* s = static_cast<double*>(buffers[2]);
         double* u = static_cast<double*>(buffers[3]);
         double* v = static_cast<double*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<double*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnDgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<double*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
       case SolverType::C64: {
@@ -805,9 +823,16 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
         gpuComplex* a = static_cast<gpuComplex*>(buffers[1]);
         float* s = static_cast<float*>(buffers[2]);
         gpuComplex* u = static_cast<gpuComplex*>(buffers[3]);
         gpuComplex* v = static_cast<gpuComplex*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<gpuComplex*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnCgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<gpuComplex*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
       case SolverType::C128: {
@@ -815,9 +840,16 @@ static absl::Status Gesvdj_(gpuStream_t stream, void** buffers,
         gpuDoubleComplex* a = static_cast<gpuDoubleComplex*>(buffers[1]);
         double* s = static_cast<double*>(buffers[2]);
         gpuDoubleComplex* u = static_cast<gpuDoubleComplex*>(buffers[3]);
         gpuDoubleComplex* v = static_cast<gpuDoubleComplex*>(buffers[4]);
-        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj(
-            handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
-            static_cast<gpuDoubleComplex*>(work), d.lwork, info, params)));
+        for (int i = 0; i < d.batch; ++i) {
+          JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cusolverDnZgesvdj(
+              handle.get(), d.jobz, d.econ, d.m, d.n, a, d.m, s, u, d.m, v, d.n,
+              static_cast<gpuDoubleComplex*>(work), d.lwork, info, params)));
+          a += d.m * d.n;
+          s += k;
+          u += d.m * (d.econ ? k : d.m);
+          v += (d.econ ? k : d.n) * d.n;
+          ++info;
+        }
         break;
       }
     }
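
For reference, the per-matrix increments in the loops above encode the packed, column-major layout of the batch: with k = min(m, n), each input matrix occupies m*n elements, s holds k singular values per matrix, u has leading dimension m with (econ ? k : m) columns, and v has leading dimension n with (econ ? k : n) columns; info is one int per matrix. A small sketch of the same bookkeeping as a hypothetical helper (not jaxlib code):

#include <algorithm>
#include <cstdint>

// Element strides between consecutive matrices in the packed batch buffers,
// matching the pointer arithmetic in the loops above.
struct GesvdjStrides {
  std::int64_t a;  // input: m * n elements per matrix, column-major
  std::int64_t s;  // singular values: k per matrix
  std::int64_t u;  // left vectors: ldu = m, (econ ? k : m) columns
  std::int64_t v;  // right vectors: ldv = n, (econ ? k : n) columns
};

GesvdjStrides MakeGesvdjStrides(std::int64_t m, std::int64_t n, bool econ) {
  const std::int64_t k = std::min(m, n);
  return {m * n, k, m * (econ ? k : m), n * (econ ? k : n)};
}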


@@ -22,15 +22,15 @@ limitations under the License.
 #if defined(JAX_GPU_CUDA)
 
-#include "third_party/gpus/cuda/include/cuComplex.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "third_party/gpus/cuda/include/cufft.h"
-#include "third_party/gpus/cuda/include/cusolverDn.h"
-#include "third_party/gpus/cuda/include/cusparse.h"
-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
-#include "third_party/gpus/cudnn/cudnn.h"
+#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cuComplex.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cublas_v2.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cuda.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cuda_runtime_api.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cufft.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cusolverDn.h"  // IWYU pragma: export
+#include "third_party/gpus/cuda/include/cusparse.h"  // IWYU pragma: export
+#include "third_party/gpus/cudnn/cudnn.h"  // IWYU pragma: export
 
 // Some sparse functionality is only available in CUSPARSE 11.3 or newer.
 #define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
@@ -74,7 +74,7 @@ typedef CUevent gpuEvent_t;
 typedef CUfunction gpuFunction_t;
 typedef cudnnHandle_t gpudnnHandle_t;
 typedef cudnnStatus_t gpudnnStatus_t;
 typedef CUmodule gpuModule_t;
 typedef cusolverDnHandle_t gpusolverDnHandle_t;
 typedef cusolverStatus_t gpusolverStatus_t;
 typedef cusolverEigMode_t gpusolverEigMode_t;
@@ -266,19 +266,24 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpuInit cuInit
 #define gpuLaunchKernel cuLaunchKernel
 #define gpuMemcpyDtoHAsync cuMemcpyDtoHAsync
 #define gpuMemcpyHtoDAsync cuMemcpyHtoDAsync
 #define gpuMemsetD8Async cuMemsetD8Async
 #define gpuModuleLoadData cuModuleLoadData
 #define gpuModuleGetFunction cuModuleGetFunction
 #define gpuModuleUnload cuModuleUnload
 #define gpuStreamGetCtx cuStreamGetCtx
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR \
+  CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN \
+  CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR \
+  CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR
 #define GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
-#define GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
+#define GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES \
+  CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
 #define GPU_EVENT_DEFAULT CU_EVENT_DEFAULT
 
 #define gpuGetLastError cudaGetLastError
@@ -293,9 +298,9 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
 namespace jax::JAX_GPU_NAMESPACE {
 namespace {
 constexpr uint32_t kNumThreadsPerWarp = 32;
-}
 }
+}  // namespace jax::JAX_GPU_NAMESPACE
 #elif defined(JAX_GPU_HIP)
@@ -483,7 +488,6 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
 #define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
 
 #define gpuGetLastError hipGetLastError
 #define gpuGetErrorString hipGetErrorString
 #define gpuMemcpyAsync hipMemcpyAsync
@@ -514,21 +518,26 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
 #define gpuModuleUnload hipModuleUnload
 #define gpuMemsetD8Async hipMemsetD8Async
 #define gpuMemcpyDtoHAsync hipMemcpyDtoHAsync
 #define gpuMemcpyHtoDAsync hipMemcpyHtoDAsync
 #define gpuMemsetD8Async hipMemsetD8Async
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR hipDeviceAttributeComputeCapabilityMajor
-#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR hipDeviceAttributeComputeCapabilityMinor
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN hipDeviceAttributeMaxSharedMemoryPerBlock
-#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR hipDeviceAttributeMaxBlocksPerMultiProcessor
-#define GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR \
+  hipDeviceAttributeComputeCapabilityMajor
+#define GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR \
+  hipDeviceAttributeComputeCapabilityMinor
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN \
+  hipDeviceAttributeMaxSharedMemoryPerBlock
+#define GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR \
+  hipDeviceAttributeMaxBlocksPerMultiProcessor
+#define GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES \
+  HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
 #define GPU_EVENT_DEFAULT hipEventDefault
 
 namespace jax::JAX_GPU_NAMESPACE {
 namespace {
 constexpr uint32_t kNumThreadsPerWarp = 64;
-}
 }
+}  // namespace jax::JAX_GPU_NAMESPACE
 #else  // defined(GPU vendor)
 #error "Either JAX_GPU_CUDA or JAX_GPU_HIP must be defined"


@@ -374,11 +374,10 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
   # outperform gesvd for small-moderate matrices, e.g., see:
   # https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
   # slide 5.
-  if have_jacobi_solver and (
-      (b == 1 and m <= 1024 and n <= 1024) or (m <= 32 and n <= 32)
-  ):
-    # The batched kernel doesn't support "econ" mode.
-    econ = not full_matrices and b == 1
+  if have_jacobi_solver and m <= 1024 and n <= 1024:
+    # The gesvdjbatched kernel doesn't support "econ" mode. We will use that
+    # kernel only if b > 1 and m <= 32 and n <= 32.
+    econ = not full_matrices and (b <= 1 or m > 32 or n > 32)
     lwork, opaque = gpu_solver.build_gesvdj_descriptor(
         np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
     k = min(m, n)


@@ -626,7 +626,7 @@ class NumpyLinalgTest(jtu.JaxTestCase):
     tol = 80 * jnp.finfo(dtype).eps
     reconstruction_tol = 2 * tol
-    unitariness_tol = tol
+    unitariness_tol = 3 * tol
 
     a, = args_maker()
     if hermitian: