Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 12:56:07 +00:00
Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it, we can make this change without breaking compatibility. Also fix wrong dtypes for the d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
This commit is contained in: parent 74b136e62c, commit 352b042fe9
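As a quick illustration of the new contract (a hedged sketch, not part of the change itself; it assumes a backend where the tridiagonal lowering is available, e.g. CPU with jaxlib 0.3.25 or newer, or a CUDA/ROCm GPU):

import jax.numpy as jnp
from jax import lax

# A small symmetric matrix; batches of matrices work the same way.
a = jnp.array([[4.0, 1.0, 0.5],
               [1.0, 3.0, 2.0],
               [0.5, 2.0, 5.0]], dtype=jnp.float32)

# lax.linalg.tridiagonal now returns (arr, d, e, taus) rather than (arr, taus).
arr, d, e, taus = lax.linalg.tridiagonal(a, lower=True)
# d: diagonal of the tridiagonal form, shape (..., n), real dtype
# e: first sub/superdiagonal, shape (..., n - 1), real dtype
# taus: scalar factors of the elementary Householder reflectors, shape (..., n - 1)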
@@ -213,6 +213,7 @@ Linear algebra operators (jax.lax.linalg)
  schur
  svd
  triangular_solve
  tridiagonal
  tridiagonal_solve

Argument classes

@@ -43,6 +43,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lib import lapack
from jax._src.lib import mlir_api_version
from jax._src.lib import version as jaxlib_version

from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver

@@ -1929,6 +1930,8 @@ ad.primitive_jvps[schur_p] = _schur_jvp_rule
def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
  """Reduces a square matrix to upper Hessenberg form.

  Currently implemented on CPU only.

  Args:
    a: A floating point or complex square matrix or batch of matrices.

@@ -1942,13 +1945,15 @@ def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
  return hessenberg_p.bind(a)

def _hessenberg_abstract_eval(a):
  if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128):
    raise TypeError("hessenberg requires a.dtype to be float32, float64, "
                    f"complex64, or complex128, got {a.dtype}.")
  if a.ndim < 2:
    msg = "hessenberg requires a.ndim to be at least 2, got {}."
    raise TypeError(msg.format(a.ndim))
    raise TypeError("hessenberg requires a.ndim to be at least 2, got "
                    f"{a.ndim}.")
  if a.shape[-1] != a.shape[-2]:
    msg = ("hessenberg requires the last two dimensions of a to be equal "
           "in size, got a.shape of {}.")
    raise TypeError(msg.format(a.shape))
    raise TypeError("hessenberg requires the last two dimensions of a to be "
                    f"equal in size, got a.shape of {a.shape}.")
  return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]

hessenberg_p = Primitive("hessenberg")

@@ -1995,46 +2000,61 @@ mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')

# tridiagonal: Reduction of a symmetric (Hermitian) matrix to tridiagonal form

def tridiagonal(a: ArrayLike, *, lower=True) -> Tuple[Array, Array]:
def tridiagonal(a: ArrayLike, *, lower=True
                ) -> Tuple[Array, Array, Array, Array]:
  """Reduces a symmetric/Hermitian matrix to tridiagonal form.

  Currently implemented on CPU and GPU only.

  Args:
    a: A floating point or complex matrix or batch of matrices.
    lower: Describes which triangle of the input matrices to use.
      The other triangle is ignored and not accessed.

  Returns:
    A ``(a, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
    A ``(a, d, e, taus)`` tuple. If ``lower=True``, the diagonal and first subdiagonal of
    the matrix (or batch of matrices) ``a`` contain the tridiagonal representation,
    and elements below the first subdiagonal contain the elementary Householder
    reflectors. If ``lower=False`` the diagonal and first superdiagonal of the
    reflectors, where additionally ``d`` contains the diagonal of the matrix and ``e`` contains
    the first subdiagonal. If ``lower=False`` the diagonal and first superdiagonal of the
    matrix contain the tridiagonal representation, and elements above the first
    superdiagonal contain the elementary Householder reflectors.
    ``taus`` contains the scalar factors of the elementary Householder
    reflectors.
    superdiagonal contain the elementary Householder reflectors, where
    additionally ``d`` contains the diagonal of the matrix and ``e`` contains the
    first superdiagonal. ``taus`` contains the scalar factors of the elementary
    Householder reflectors.
  """
  arr, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
  arr, d, e, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
  nan = arr.dtype.type(jnp.nan)
  if jnp.issubdtype(arr.dtype, np.complexfloating):
    nan = nan + arr.dtype.type(jnp.nan * 1j)
  arr = jnp.where((info == 0)[..., None, None], arr, nan)
  real_type = jnp.finfo(arr.dtype).dtype.type
  d = jnp.where((info == 0)[..., None], d, real_type(jnp.nan))
  e = jnp.where((info == 0)[..., None], e, real_type(jnp.nan))
  taus = jnp.where((info == 0)[..., None], taus, nan)
  return arr, taus
  return arr, d, e, taus

def _tridiagonal_abstract_eval(a, *, lower):
  if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128):
    raise TypeError("tridiagonal requires a.dtype to be float32, float64, "
                    f"complex64, or complex128, got {a.dtype}.")
  if a.ndim < 2:
    msg = "tridiagonal requires a.ndim to be at least 2, got {}."
    raise TypeError(msg.format(a.ndim))
    raise TypeError("tridiagonal requires a.ndim to be at least 2, got "
                    f"{a.ndim}.")
  if a.shape[-1] != a.shape[-2]:
    msg = ("tridiagonal requires the last two dimensions of a to be equal "
           "in size, got a.shape of {}.")
    raise TypeError(msg.format(a.shape))
    raise TypeError("tridiagonal requires the last two dimensions of a to be "
                    f"equal in size, got a.shape of {a.shape}.")
  if a.shape[-1] == 0:
    msg = ("tridiagonal requires the last two dimensions of a to be non-zero, "
           "got a.shape of {}.")
    raise TypeError(msg.format(a.shape))
  return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
          ShapedArray(a.shape[:-2], np.int32)]
    raise TypeError("tridiagonal requires the last two dimensions of a to be "
                    f"non-zero, got a.shape of {a.shape}.")
  real_dtype = jnp.finfo(a.dtype).dtype
  return [
      a,
      ShapedArray(a.shape[:-2] + (a.shape[-1],), real_dtype),
      ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), real_dtype),
      ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
      ShapedArray(a.shape[:-2], np.int32)
  ]

tridiagonal_p = Primitive("tridiagonal")
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))

@@ -2049,18 +2069,21 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):

batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule

def _tridiagonal_cpu_mhlo(ctx, a, *, lower):
def _tridiagonal_cpu_gpu_mhlo(sytrd_impl, ctx, a, *, lower):
  a_aval, = ctx.avals_in
  # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
  if not hasattr(lapack, "sytrd_mhlo"):
    raise RuntimeError("Tridiagonal reduction on CPU requires jaxlib 0.3.25 or "
                       "newer")
  a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower)
  return a, d, e, taus, info

  a, d, e, taus, info = lapack.sytrd_mhlo(a_aval.dtype, a, lower=lower)
  del d, e
  return a, taus, info

mlir.register_lowering(tridiagonal_p, _tridiagonal_cpu_mhlo, platform='cpu')
if jaxlib_version >= (0, 3, 25):
  mlir.register_lowering(
      tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, lapack.sytrd_mhlo),
      platform='cpu')
  mlir.register_lowering(
      tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.cuda_sytrd),
      platform='cuda')
  mlir.register_lowering(
      tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.rocm_sytrd),
      platform='rocm')

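To make the new ``d`` and ``e`` outputs concrete, here is a small hypothetical helper (an illustrative sketch, not part of this change) that assembles the dense tridiagonal matrix they describe; for Hermitian input the reduced matrix is real symmetric, so the same ``e`` fills both off-diagonals:

import jax.numpy as jnp

def assemble_tridiagonal(d, e):
  # d: shape (n,), the diagonal; e: shape (n - 1,), the sub/superdiagonal.
  # Both are real even when the original matrix is complex Hermitian.
  return jnp.diag(d) + jnp.diag(e, k=1) + jnp.diag(e, k=-1)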
@@ -513,6 +513,41 @@ std::pair<int, py::bytes> BuildGesvdjDescriptor(const py::dtype& dtype,

#endif  // JAX_GPU_CUDA

// Returns the workspace size and a descriptor for a sytrd operation.
std::pair<int, py::bytes> BuildSytrdDescriptor(const py::dtype& dtype,
                                               bool lower, int b, int n) {
  SolverType type = DtypeToSolverType(dtype);
  auto h = SolverHandlePool::Borrow();
  JAX_THROW_IF_ERROR(h.status());
  auto& handle = *h;
  int lwork;
  gpusolverFillMode_t uplo =
      lower ? GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER;
  switch (type) {
    case SolverType::F32:
      JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd_bufferSize(
          handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
          /*E=*/nullptr, /*tau=*/nullptr, &lwork)));
      break;
    case SolverType::F64:
      JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd_bufferSize(
          handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
          /*E=*/nullptr, /*tau=*/nullptr, &lwork)));
      break;
    case SolverType::C64:
      JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd_bufferSize(
          handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
          /*E=*/nullptr, /*tau=*/nullptr, &lwork)));
      break;
    case SolverType::C128:
      JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd_bufferSize(
          handle.get(), uplo, n, /*A=*/nullptr, /*lda=*/n, /*D=*/nullptr,
          /*E=*/nullptr, /*tau=*/nullptr, &lwork)));
      break;
  }
  return {lwork, PackDescriptor(SytrdDescriptor{type, uplo, b, n, n, lwork})};
}

py::dict Registrations() {
  py::dict dict;
  dict[JAX_GPU_PREFIX "solver_potrf"] = EncapsulateFunction(Potrf);

@@ -522,6 +557,7 @@ py::dict Registrations() {
  dict[JAX_GPU_PREFIX "solver_syevd"] = EncapsulateFunction(Syevd);
  dict[JAX_GPU_PREFIX "solver_syevj"] = EncapsulateFunction(Syevj);
  dict[JAX_GPU_PREFIX "solver_gesvd"] = EncapsulateFunction(Gesvd);
  dict[JAX_GPU_PREFIX "solver_sytrd"] = EncapsulateFunction(Sytrd);

#ifdef JAX_GPU_CUDA
  dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr);

@@ -539,6 +575,7 @@ PYBIND11_MODULE(_solver, m) {
  m.def("build_syevd_descriptor", &BuildSyevdDescriptor);
  m.def("build_syevj_descriptor", &BuildSyevjDescriptor);
  m.def("build_gesvd_descriptor", &BuildGesvdDescriptor);
  m.def("build_sytrd_descriptor", &BuildSytrdDescriptor);
#ifdef JAX_GPU_CUDA
  m.def("build_csrlsvqr_descriptor", &BuildCsrlsvqrDescriptor);
  m.def("build_gesvdj_descriptor", &BuildGesvdjDescriptor);

@@ -973,5 +973,107 @@ void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque,

#endif  // JAX_GPU_CUDA

// sytrd/hetrd: symmetric (Hermitian) tridiagonal reduction

static absl::Status Sytrd_(gpuStream_t stream, void** buffers,
                           const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<SytrdDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const SytrdDescriptor& d = **s;
  auto h = SolverHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;
  if (buffers[1] != buffers[0]) {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync(
        buffers[1], buffers[0],
        SizeOfSolverType(d.type) * static_cast<std::int64_t>(d.batch) *
            static_cast<std::int64_t>(d.n) * static_cast<std::int64_t>(d.lda),
        gpuMemcpyDeviceToDevice, stream)));
  }

  int* info = static_cast<int*>(buffers[5]);
  void* workspace = buffers[6];
  switch (d.type) {
    case SolverType::F32: {
      float* a = static_cast<float*>(buffers[1]);
      float* d_out = static_cast<float*>(buffers[2]);
      float* e_out = static_cast<float*>(buffers[3]);
      float* tau = static_cast<float*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSsytrd(
            handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
            static_cast<float*>(workspace), d.lwork, info)));
        a += d.lda * d.n;
        d_out += d.n;
        e_out += d.n - 1;
        tau += d.n - 1;
        ++info;
      }
      break;
    }
    case SolverType::F64: {
      double* a = static_cast<double*>(buffers[1]);
      double* d_out = static_cast<double*>(buffers[2]);
      double* e_out = static_cast<double*>(buffers[3]);
      double* tau = static_cast<double*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnDsytrd(
            handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
            static_cast<double*>(workspace), d.lwork, info)));
        a += d.lda * d.n;
        d_out += d.n;
        e_out += d.n - 1;
        tau += d.n - 1;
        ++info;
      }
      break;
    }
    case SolverType::C64: {
      gpuComplex* a = static_cast<gpuComplex*>(buffers[1]);
      float* d_out = static_cast<float*>(buffers[2]);
      float* e_out = static_cast<float*>(buffers[3]);
      gpuComplex* tau = static_cast<gpuComplex*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnChetrd(
            handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
            static_cast<gpuComplex*>(workspace), d.lwork, info)));
        a += d.lda * d.n;
        d_out += d.n;
        e_out += d.n - 1;
        tau += d.n - 1;
        ++info;
      }
      break;
    }
    case SolverType::C128: {
      gpuDoubleComplex* a = static_cast<gpuDoubleComplex*>(buffers[1]);
      double* d_out = static_cast<double*>(buffers[2]);
      double* e_out = static_cast<double*>(buffers[3]);
      gpuDoubleComplex* tau = static_cast<gpuDoubleComplex*>(buffers[4]);
      for (int i = 0; i < d.batch; ++i) {
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnZhetrd(
            handle.get(), d.uplo, d.n, a, d.lda, d_out, e_out, tau,
            static_cast<gpuDoubleComplex*>(workspace), d.lwork, info)));
        a += d.lda * d.n;
        d_out += d.n;
        e_out += d.n - 1;
        tau += d.n - 1;
        ++info;
      }
      break;
    }
  }
  return absl::OkStatus();
}

void Sytrd(gpuStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = Sytrd_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
                                  s.message().length());
  }
}

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax

@@ -161,10 +161,19 @@ struct GesvdjDescriptor {
void Gesvdj(gpuStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

// sytrd/hetrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
struct SytrdDescriptor {
  SolverType type;
  gpusolverFillMode_t uplo;
  int batch, n, lda, lwork;
};

void Sytrd(gpuStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

#endif  // JAX_GPU_CUDA

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax

#endif  // JAXLIB_CUSOLVER_KERNELS_H_

@@ -185,6 +185,14 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
  cusolverDnCgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
  cusolverDnZgesvd_bufferSize(h, m, n, lwork)
#define gpusolverDnSsytrd_bufferSize cusolverDnSsytrd_bufferSize
#define gpusolverDnDsytrd_bufferSize cusolverDnDsytrd_bufferSize
#define gpusolverDnChetrd_bufferSize cusolverDnChetrd_bufferSize
#define gpusolverDnZhetrd_bufferSize cusolverDnZhetrd_bufferSize
#define gpusolverDnSsytrd cusolverDnSsytrd
#define gpusolverDnDsytrd cusolverDnDsytrd
#define gpusolverDnChetrd cusolverDnChetrd
#define gpusolverDnZhetrd cusolverDnZhetrd

#define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER

@@ -397,6 +405,14 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
  hipsolverCgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \
  hipsolverZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork)
#define gpusolverDnSsytrd_bufferSize hipsolverDnSsytrd_bufferSize
#define gpusolverDnDsytrd_bufferSize hipsolverDnDsytrd_bufferSize
#define gpusolverDnChetrd_bufferSize hipsolverDnChetrd_bufferSize
#define gpusolverDnZhetrd_bufferSize hipsolverDnZhetrd_bufferSize
#define gpusolverDnSsytrd hipsolverDnSsytrd
#define gpusolverDnDsytrd hipsolverDnDsytrd
#define gpusolverDnChetrd hipsolverDnChetrd
#define gpusolverDnZhetrd hipsolverDnZhetrd

#define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER
#define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER

@@ -468,3 +468,75 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,

cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_mhlo, "hip", _hipsolver, False)


def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
  """sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n, (m, n)
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  lwork, opaque = gpu_solver.build_sytrd_descriptor(dtype, lower, b, n)
  if np.issubdtype(dtype, np.floating):
    diag_type = a_type.element_type
  elif dtype == np.complex64:
    diag_type = ir.F32Type.get()
  elif dtype == np.complex128:
    diag_type = ir.F64Type.get()
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  a, d, e, taus, info, _ = custom_call(
      f"{platform}solver_sytrd",
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (n,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
          ir.RankedTensorType.get([lwork], a_type.element_type),
      ],
      [a],
      backend_config=opaque,
      operand_layouts=[layout],
      result_layouts=[
          layout,
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          tuple(range(num_bd - 1, -1, -1)),
          [0],
      ],
      operand_output_aliases={0: 0},
  )
  # Workaround for NVIDIA partners bug #3865118: sytrd returns an incorrect "1"
  # in the first element of the superdiagonal in the `a` matrix in the
  # lower=False case. The correct result is returned in the `e` vector so we can
  # simply copy it back to where it needs to be:
  intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
  if not lower and platform == "cu" and m > 1:
    start = (0,) * len(batch_dims) + (0,)
    end = batch_dims + (1,)
    s = mhlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
    s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
    s = mhlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
    # The diagonals are always real; convert to complex if needed.
    s = mhlo.ConvertOp(
        ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
    offsets = tuple(mhlo.ConstantOp(intattr(i))
                    for i in ((0,) * len(batch_dims) + (0, 1)))
    a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result

  return a, d, e, taus, info

cuda_sytrd = partial(_sytrd_mhlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_mhlo, "hip", _hipsolver)

@@ -708,11 +708,11 @@ def sytrd_mhlo(dtype, a, *, lower):
  elif dtype == np.complex64:
    fn = b"lapack_chetrd"
    lwork = _lapack.lapack_chetrd_workspace(n, n)
    diag_type = ir.ComplexType.get(ir.F32Type.get())
    diag_type = ir.F32Type.get()
  elif dtype == np.complex128:
    fn = b"lapack_zhetrd"
    lwork = _lapack.lapack_zhetrd_workspace(n, n)
    diag_type = ir.ComplexType.get(ir.F64Type.get())
    diag_type = ir.F64Type.get()
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

@@ -991,7 +991,8 @@ class ScipyLinalgTest(jtu.JaxTestCase):
    p, l, u = jsp.linalg.lu(x)
    self.assertAllClose(x, np.matmul(p, np.matmul(l, u)),
                        rtol={np.float32: 1e-3, np.float64: 1e-12,
                              np.complex64: 1e-3, np.complex128: 1e-12})
                              np.complex64: 1e-3, np.complex128: 1e-12},
                        atol={np.float32: 1e-5})
    self._CompileAndCheck(jsp.linalg.lu, args_maker)

  def testLuOfSingularMatrix(self):

@@ -1307,19 +1308,20 @@ class ScipyLinalgTest(jtu.JaxTestCase):
    self._CompileAndCheck(jsp_func, args_maker)

  @jtu.sample_product(
    shape=[(1, 1), (4, 4), (10, 10)],
    shape=[(1, 1), (2, 2, 2), (4, 4), (10, 10), (2, 5, 5)],
    dtype=float_types + complex_types,
    lower=[False, True],
  )
  @unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
  @jtu.skip_on_devices("gpu", "tpu")
  @jtu.skip_on_devices("tpu")
  def testTridiagonal(self, shape, dtype, lower):
    rng = jtu.rand_default(self.rng())
    def jax_func(a):
      return lax.linalg.tridiagonal(a, lower=lower)

    @partial(np.vectorize, otypes=(dtype, dtype),
             signature='(n,n)->(n,n),(k)')
    real_dtype = jnp.finfo(dtype).dtype
    @partial(np.vectorize, otypes=(dtype, real_dtype, real_dtype, dtype),
             signature='(n,n)->(n,n),(n),(k),(k)')
    def sp_func(a):
      if dtype == np.float32:
        c, d, e, tau, info = scipy.linalg.lapack.ssytrd(a, lower=lower)

@@ -1331,12 +1333,11 @@ class ScipyLinalgTest(jtu.JaxTestCase):
        c, d, e, tau, info = scipy.linalg.lapack.zhetrd(a, lower=lower)
      else:
        assert False, dtype
      del d, e
      assert info == 0
      return c, tau
      return c, d, e, tau

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-5, atol=1e-5,
    self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4,
                            check_dtypes=False)

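Outside the test suite, a rough way to sanity-check ``d`` and ``e`` is to compare eigenvalues, since the tridiagonal form is unitarily similar to the input (a sketch under the same jaxlib/backend assumptions as above; the tolerances are arbitrary):

import numpy as np
import scipy.linalg
from jax import lax

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6)).astype(np.float32)
a = x + x.T  # symmetrize

_, d, e, _ = lax.linalg.tridiagonal(a, lower=True)
# The tridiagonal form has the same spectrum as the original matrix.
np.testing.assert_allclose(
    scipy.linalg.eigh_tridiagonal(np.asarray(d), np.asarray(e), eigvals_only=True),
    np.linalg.eigvalsh(a), rtol=1e-4, atol=1e-4)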