Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.

Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
This commit is contained in:
Peter Hawkins 2022-11-10 13:15:44 -08:00 committed by jax authors
parent 74b136e62c
commit 352b042fe9
9 changed files with 305 additions and 44 deletions

View File

@ -213,6 +213,7 @@ Linear algebra operators (jax.lax.linalg)
schur
svd
triangular_solve
tridiagonal
tridiagonal_solve
Argument classes

View File

@ -43,6 +43,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lib import lapack
from jax._src.lib import mlir_api_version
from jax._src.lib import version as jaxlib_version
from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
@ -1929,6 +1930,8 @@ ad.primitive_jvps[schur_p] = _schur_jvp_rule
def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
"""Reduces a square matrix to upper Hessenberg form.
Currently implemented on CPU only.
Args:
a: A floating point or complex square matrix or batch of matrices.
@ -1942,13 +1945,15 @@ def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
return hessenberg_p.bind(a)
def _hessenberg_abstract_eval(a):
if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128):
raise TypeError("hessenberg requires a.dtype to be float32, float64, "
f"complex64, or complex128, got {a.dtype}.")
if a.ndim < 2:
msg = "hessenberg requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
raise TypeError("hessenberg requires a.ndim to be at least 2, got "
f"{a.ndim}.")
if a.shape[-1] != a.shape[-2]:
msg = ("hessenberg requires the last two dimensions of a to be equal "
"in size, got a.shape of {}.")
raise TypeError(msg.format(a.shape))
raise TypeError("hessenberg requires the last two dimensions of a to be "
f"equal in size, got a.shape of {a.shape}.")
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
hessenberg_p = Primitive("hessenberg")
@ -1995,46 +2000,61 @@ mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')
# tridiagonal: Upper Hessenberg reduction
def tridiagonal(a: ArrayLike, *, lower=True) -> Tuple[Array, Array]:
def tridiagonal(a: ArrayLike, *, lower=True
) -> Tuple[Array, Array, Array, Array]:
"""Reduces a symmetric/Hermitian matrix to tridiagonal form.
Currently implemented on CPU and GPU only.
Args:
a: A floating point or complex matrix or batch of matrices.
lower: Describes which triangle of the input matrices to use.
The other triangle is ignored and not accessed.
Returns:
A ``(a, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
A ``(a, d, e, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
matrix (or batch of matrices) ``a`` contain the tridiagonal representation,
and elements below the first subdiagonal contain the elementary Householder
reflectors. If ``lower=False`` the diagonal and first superdiagonal of the
reflectors, where additionally ``d`` contains the diagonal of the matrix and ``e`` contains
the first subdiagonal.If ``lower=False`` the diagonal and first superdiagonal of the
matrix contains the tridiagonal representation, and elements above the first
superdiagonal contain the elementary Householder reflectors.
``taus`` contains the scalar factors of the elementary Householder
reflectors.
superdiagonal contain the elementary Householder reflectors, where
additionally ``d`` contains the diagonal of the matrix and ``e`` contains the
first superdiagonal. ``taus`` contains the scalar factors of the elementary
Householder reflectors.
"""
arr, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
arr, d, e, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
nan = arr.dtype.type(jnp.nan)
if jnp.issubdtype(arr.dtype, np.complexfloating):
nan = nan + arr.dtype.type(jnp.nan * 1j)
arr = jnp.where((info == 0)[..., None, None], arr, nan)
real_type = jnp.finfo(arr.dtype).dtype.type
d = jnp.where((info == 0)[..., None], d, real_type(jnp.nan))
e = jnp.where((info == 0)[..., None], e, real_type(jnp.nan))
taus = jnp.where((info == 0)[..., None], taus, nan)
return arr, taus
return arr, d, e, taus
def _tridiagonal_abstract_eval(a, *, lower):
if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128):
raise TypeError("tridiagonal requires a.dtype to be float32, float64, "
f"complex64, or complex128, got {a.dtype}.")
if a.ndim < 2:
msg = "tridiagonal requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
raise TypeError("tridiagonal requires a.ndim to be at least 2, got "
f"{a.ndim}.")
if a.shape[-1] != a.shape[-2]:
msg = ("tridiagonal requires the last two dimensions of a to be equal "
"in size, got a.shape of {}.")
raise TypeError(msg.format(a.shape))
raise TypeError("tridiagonal requires the last two dimensions of a to be "
f"equal in size, got a.shape of {a.shape}.")
if a.shape[-1] == 0:
msg = ("tridiagonal requires the last two dimensions of a to be non-zero, "
"got a.shape of {}.")
raise TypeError(msg.format(a.shape))
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
ShapedArray(a.shape[:-2], np.int32)]
raise TypeError("tridiagonal requires the last two dimensions of a to be "
f"non-zero, got a.shape of {a.shape}.")
real_dtype = jnp.finfo(a.dtype).dtype
return [
a,
ShapedArray(a.shape[:-2] + (a.shape[-1],), real_dtype),
ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), real_dtype),
ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
ShapedArray(a.shape[:-2], np.int32)
]
tridiagonal_p = Primitive("tridiagonal")
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))
@ -2049,18 +2069,21 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
def _tridiagonal_cpu_mhlo(ctx, a, *, lower):
def _tridiagonal_cpu_gpu_mhlo(sytrd_impl, ctx, a, *, lower):
a_aval, = ctx.avals_in
# TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
if not hasattr(lapack, "sytrd_mhlo"):
raise RuntimeError("Tridiagonal reduction on CPU requires jaxlib 0.3.25 or "
"newer")
a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower)
return a, d, e, taus, info
a, d, e, taus, info = lapack.sytrd_mhlo(a_aval.dtype, a, lower=lower)
del d, e
return a, taus, info
mlir.register_lowering(tridiagonal_p, _tridiagonal_cpu_mhlo, platform='cpu')
if jaxlib_version >= (0, 3, 25):
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, lapack.sytrd_mhlo),
platform='cpu')
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.cuda_sytrd),
platform='cuda')
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.rocm_sytrd),
platform='rocm')

View File

@ -513,6 +513,41 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,
#endif // JAX_GPU_CUDA
// Returns the workspace size and a descriptor for a geqrf operation.
std::pair<int, py::bytes> BuildSytrdDescriptor(const py::dtype& dtype,
bool lower, int b, int n) {
SolverType type = DtypeToSolverType(dtype);
auto h = SolverHandlePool::Borrow();
JAX_THROW_IF_ERROR(h.status());
auto& handle = *h;
int lwork;
gpusolverFillMode_t uplo =
lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
switch (type) {
case SolverType::F32:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd_bufferSize(
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
break;
case SolverType::F64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd_bufferSize(
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
break;
case SolverType::C64:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd_bufferSize(
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
break;
case SolverType::C128:
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd_bufferSize(
handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
/*E=*/nullptr, /*tau=*/nullptr, &lwork)));
break;
}
return {lwork, PackDescriptor(SytrdDescriptor{type, uplo, b, n, n, lwork})};
}
py::dict Registrations() {
py::dict dict;
dict[JAX_GPU_PREFIX "solver_potrf"] = EncapsulateFunction(Potrf);
@ -522,6 +557,7 @@ py::dict Registrations() {
dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd);
dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj);
dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd);
dict[JAX_GPU_PREFIX "solver_sytrd"] = EncapsulateFunction(Sytrd);
#ifdef JAX_GPU_CUDA
dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr);
@ -539,6 +575,7 @@ PYBIND11_MODULE(_solver, m) {
m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
m.def("build_sytrd_descriptor", &BuildSytrdDescriptor);
#ifdef JAX_GPU_CUDA
m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor);

View File

@ -973,5 +973,107 @@ void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque,
#endif // JAX_GPU_CUDA
// sytrd/hetrd: symmetric (Hermitian) tridiagonal reduction
static absl::Status Sytrd_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<SytrdDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const SytrdDescriptor& d = **s;
auto h = SolverHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
if (buffers[1] != buffers[0]) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
buffers[1], buffers[0],
SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.lda),
gpuMemcpyDeviceToDevice, stream)));
}
int* info = static_cast<int*>(buffers[5]);
void* workspace = buffers[6];
switch (d.type) {
case SolverType::F32: {
float* a = static_cast<float*>(buffers[1]);
float* d_out = static_cast<float*>(buffers[2]);
float* e_out = static_cast<float*>(buffers[3]);
float* tau = static_cast<float*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd(
handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
static_cast<float*>(workspace), d.lwork, info)));
a += d.lda * d.n;
d_out += d.n;
e_out += d.n - 1;
tau += d.n - 1;
++info;
}
break;
}
case SolverType::F64: {
double* a = static_cast<double*>(buffers[1]);
double* d_out = static_cast<double*>(buffers[2]);
double* e_out = static_cast<double*>(buffers[3]);
double* tau = static_cast<double*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd(
handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
static_cast<double*>(workspace), d.lwork, info)));
a += d.lda * d.n;
d_out += d.n;
e_out += d.n - 1;
tau += d.n - 1;
++info;
}
break;
}
case SolverType::C64: {
gpuComplex* a = static_cast<gpuComplex*>(buffers[1]);
float* d_out = static_cast<float*>(buffers[2]);
float* e_out = static_cast<float*>(buffers[3]);
gpuComplex* tau = static_cast<gpuComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd(
handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
static_cast<gpuComplex*>(workspace), d.lwork, info)));
a += d.lda * d.n;
d_out += d.n;
e_out += d.n - 1;
tau += d.n - 1;
++info;
}
break;
}
case SolverType::C128: {
gpuDoubleComplex* a = static_cast<gpuDoubleComplex*>(buffers[1]);
double* d_out = static_cast<double*>(buffers[2]);
double* e_out = static_cast<double*>(buffers[3]);
gpuDoubleComplex* tau = static_cast<gpuDoubleComplex*>(buffers[4]);
for (int i = 0; i < d.batch; ++i) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd(
handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
static_cast<gpuDoubleComplex*>(workspace), d.lwork, info)));
a += d.lda * d.n;
d_out += d.n;
e_out += d.n - 1;
tau += d.n - 1;
++info;
}
break;
}
}
return absl::OkStatus();
}
void Sytrd(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = Sytrd_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

View File

@ -161,10 +161,19 @@ struct GesvdjDescriptor {
void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
// sytrd/hetrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
struct SytrdDescriptor {
SolverType type;
gpusolverFillMode_t uplo;
int batch, n, lda, lwork;
};
void Sytrd(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
#endif // JAX_GPU_CUDA
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_CUSOLVER_KERNELS_H_

View File

@ -185,6 +185,14 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
cusolverDnCgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
cusolverDnZgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnSsytrd_bufferSize cusolverDnSsytrd_bufferSize
#define gpusolverDnDsytrd_bufferSize cusolverDnDsytrd_bufferSize
#define gpusolverDnChetrd_bufferSize cusolverDnChetrd_bufferSize
#define gpusolverDnZhetrd_bufferSize cusolverDnZhetrd_bufferSize
#define gpusolverDnSsytrd cusolverDnSsytrd
#define gpusolverDnDsytrd cusolverDnDsytrd
#define gpusolverDnChetrd cusolverDnChetrd
#define gpusolverDnZhetrd cusolverDnZhetrd
#define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER
@ -397,6 +405,14 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
hipsolverCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
hipsolverZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define gpusolverDnSsytrd_bufferSize hipsolverDnSsytrd_bufferSize
#define gpusolverDnDsytrd_bufferSize hipsolverDnDsytrd_bufferSize
#define gpusolverDnChetrd_bufferSize hipsolverDnChetrd_bufferSize
#define gpusolverDnZhetrd_bufferSize hipsolverDnZhetrd_bufferSize
#define gpusolverDnSsytrd hipsolverDnSsytrd
#define gpusolverDnDsytrd hipsolverDnDsytrd
#define gpusolverDnChetrd hipsolverDnChetrd
#define gpusolverDnZhetrd hipsolverDnZhetrd
#define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER

View File

@ -468,3 +468,75 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_mhlo, "hip", _hipsolver, False)
def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
"""sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n, (m, n)
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d
lwork, opaque = gpu_solver.build_sytrd_descriptor(dtype, lower, b, n)
if np.issubdtype(dtype, np.floating):
diag_type = a_type.element_type
elif dtype == np.complex64:
diag_type = ir.F32Type.get()
elif dtype == np.complex128:
diag_type = ir.F64Type.get()
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
a, d, e, taus, info, _ = custom_call(
f"{platform}solver_sytrd",
[
a.type,
ir.RankedTensorType.get(batch_dims + (n,), diag_type),
ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0},
)
# Workaround for NVIDIA partners bug #3865118: sytrd returns an incorrect "1"
# in the first element of the superdiagonal in the `a` matrix in the
# lower=False case. The correct result is returned in the `e` vector so we can
# simply copy it back to where it needs to be:
intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
if not lower and platform == "cu" and m > 1:
start = (0,) * len(batch_dims) + (0,)
end = batch_dims + (1,)
s = mhlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
s = mhlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
# The diagonals are always real; convert to complex if needed.
s = mhlo.ConvertOp(
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
offsets = tuple(mhlo.ConstantOp(intattr(i))
for i in ((0,) * len(batch_dims) + (0, 1)))
a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
return a, d, e, taus, info
cuda_sytrd = partial(_sytrd_mhlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_mhlo, "hip", _hipsolver)

View File

@ -708,11 +708,11 @@ def sytrd_mhlo(dtype, a, *, lower):
elif dtype == np.complex64:
fn = b"lapack_chetrd"
lwork = _lapack.lapack_chetrd_workspace(n, n)
diag_type = ir.ComplexType.get(ir.F32Type.get())
diag_type = ir.F32Type.get()
elif dtype == np.complex128:
fn = b"lapack_zhetrd"
lwork = _lapack.lapack_zhetrd_workspace(n, n)
diag_type = ir.ComplexType.get(ir.F64Type.get())
diag_type = ir.F64Type.get()
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

View File

@ -991,7 +991,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
p, l, u = jsp.linalg.lu(x)
self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
rtol={np.float32: 1e-3, np.float64: 1e-12,
np.complex64: 1e-3, np.complex128: 1e-12})
np.complex64: 1e-3, np.complex128: 1e-12},
atol={np.float32: 1e-5})
self._CompileAndCheck(jsp.linalg.lu, args_maker)
def testLuOfSingularMatrix(self):
@ -1307,19 +1308,20 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self._CompileAndCheck(jsp_func, args_maker)
@jtu.sample_product(
shape=[(1, 1), (4, 4), (10, 10)],
shape=[(1, 1), (2, 2, 2), (4, 4), (10, 10), (2, 5, 5)],
dtype=float_types + complex_types,
lower=[False, True],
)
@unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
@jtu.skip_on_devices("gpu", "tpu")
@jtu.skip_on_devices("tpu")
def testTridiagonal(self, shape, dtype, lower):
rng = jtu.rand_default(self.rng())
def jax_func(a):
return lax.linalg.tridiagonal(a, lower=lower)
@partial(np.vectorize, otypes=(dtype, dtype),
signature='(n,n)->(n,n),(k)')
real_dtype = jnp.finfo(dtype).dtype
@partial(np.vectorize, otypes=(dtype, real_dtype, real_dtype, dtype),
signature='(n,n)->(n,n),(n),(k),(k)')
def sp_func(a):
if dtype == np.float32:
c, d, e, tau, info = scipy.linalg.lapack.ssytrd(a, lower=lower)
@ -1331,12 +1333,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
c, d, e, tau, info = scipy.linalg.lapack.zhetrd(a, lower=lower)
else:
assert False, dtype
del d, e
assert info == 0
return c, tau
return c, d, e, tau
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-5, atol=1e-5,
self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4,
check_dtypes=False)