Add support for Hessenberg and tridiagonal matrix reductions on CPU.

* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
This commit is contained in:
Peter Hawkins 2022-11-09 06:23:22 -08:00 committed by jax authors
parent 30637d052b
commit 1cead779a3
14 changed files with 650 additions and 40 deletions

View File

@ -9,8 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.3.25
* Changes
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
* {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires
jaxlib > 0.3.24.
* New functions {func}`jax.lax.linalg.hessenberg`,
{func}`jax.lax.linalg.tridiagonal`, and
{func}`jax.lax.linalg.householder_product` were added. Hessenberg and
tridiagonal reductions are supported on CPU only.
## jaxlib 0.3.25
* Changes
* Added support for upper Hessenberg and tridiagonal reductions on CPU.
## jax 0.3.24 (Nov 4, 2022)
* Changes

View File

@ -205,7 +205,9 @@ Linear algebra operators (jax.lax.linalg)
cholesky
eig
eigh
hessenberg
lu
householder_product
qdwh
qr
schur

View File

@ -30,6 +30,7 @@ jax.scipy.linalg
expm
expm_frechet
funm
hessenberg
inv
lu
lu_factor

View File

@ -753,11 +753,11 @@ mlir.register_lowering(
platform='tpu')
triangular_solve_dtype_rule = partial(
_triangular_solve_dtype_rule = partial(
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
'triangular_solve')
def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
if a.ndim < 2:
msg = "triangular_solve requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
@ -778,7 +778,7 @@ def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
raise TypeError(msg.format(a.shape, b.shape))
return b.shape
def triangular_solve_jvp_rule_a(
def _triangular_solve_jvp_rule_a(
g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
unit_diagonal):
m, n = b.shape[-2:]
@ -811,7 +811,7 @@ def triangular_solve_jvp_rule_a(
else:
return dot(ans, a_inverse(g_a)) # X (∂A A^{-1})
def triangular_solve_transpose_rule(
def _triangular_solve_transpose_rule(
cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
unit_diagonal):
# Triangular solve is nonlinear in its first argument and linear in its second
@ -827,7 +827,7 @@ def triangular_solve_transpose_rule(
return [None, cotangent_b]
def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
lower, transpose_a, conjugate_a,
unit_diagonal):
x, y = batched_args
@ -856,13 +856,13 @@ def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
unit_diagonal=unit_diagonal), 0
triangular_solve_p = standard_primitive(
triangular_solve_shape_rule, triangular_solve_dtype_rule,
_triangular_solve_shape_rule, _triangular_solve_dtype_rule,
'triangular_solve')
ad.defjvp2(triangular_solve_p,
triangular_solve_jvp_rule_a,
_triangular_solve_jvp_rule_a,
lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule
ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule
def _triangular_solve_lowering(
@ -1363,9 +1363,9 @@ mlir.register_lowering(
platform='rocm')
# orgqr: product of elementary Householder reflectors
# householder_product: product of elementary Householder reflectors
def orgqr(a: ArrayLike, taus: ArrayLike) -> Array:
def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
"""Product of elementary Householder reflectors.
Args:
@ -1378,31 +1378,34 @@ def orgqr(a: ArrayLike, taus: ArrayLike) -> Array:
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
containing the products of the elementary Householder reflectors.
"""
return orgqr_p.bind(a, taus)
return householder_product_p.bind(a, taus)
def _orgqr_abstract_eval(a, taus):
def _householder_product_abstract_eval(a, taus):
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
raise NotImplementedError("Unsupported aval in orgqr_abstract_eval: "
raise NotImplementedError("Unsupported aval in householder_product_abstract_eval: "
f"{a.aval} {taus.aval}")
if a.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
raise ValueError("Argument to Householder product must have ndims >= 2")
*batch_dims, m, n = a.shape
*taus_batch_dims, k = taus.shape
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
raise ValueError(f"Type mismatch for orgqr: a={a} taus={taus}")
raise ValueError(f"Type mismatch for Householder product: a={a} taus={taus}")
if m < n:
raise ValueError("Householder product inputs must have at least as many "
f"rows as columns, got shape {a.shape}")
return a
def _orgqr_batching_rule(batched_args, batch_dims):
def _householder_product_batching_rule(batched_args, batch_dims):
a, taus = batched_args
b_a, b_taus, = batch_dims
return orgqr(batching.moveaxis(a, b_a, 0),
return householder_product(batching.moveaxis(a, b_a, 0),
batching.moveaxis(taus, b_taus, 0)), (0,)
def _orgqr_translation_rule(ctx, avals_in, avals_out, a, taus):
def _householder_product_translation_rule(ctx, avals_in, avals_out, a, taus):
return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]
def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
a_aval, _ = ctx.avals_in
*batch_dims, m, n = a_aval.shape
@ -1420,22 +1423,23 @@ def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
return [a]
orgqr_p = Primitive('orgqr')
orgqr_p.def_impl(partial(xla.apply_primitive, orgqr_p))
orgqr_p.def_abstract_eval(_orgqr_abstract_eval)
batching.primitive_batchers[orgqr_p] = _orgqr_batching_rule
xla.register_translation(orgqr_p, _orgqr_translation_rule)
householder_product_p = Primitive('householder_product')
householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p))
householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
xla.register_translation(householder_product_p, _householder_product_translation_rule)
mlir.register_lowering(
orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo),
platform='cpu')
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
platform='cuda')
mlir.register_lowering(
orgqr_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
platform='rocm')
@ -1492,13 +1496,13 @@ def _qr_lowering(a, *, full_matrices):
r, taus = geqrf(a)
if m < n:
q = orgqr(r[..., :m, :m], taus)
q = householder_product(r[..., :m, :m], taus)
elif full_matrices:
pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
q = lax.pad(r, lax_internal._zero(r), pads)
q = orgqr(q, taus)
q = householder_product(q, taus)
else:
q = orgqr(r, taus)
q = householder_product(r, taus)
r = r[..., :n, :n]
r = jnp.triu(r)
return q, r
@ -1920,6 +1924,146 @@ batching.primitive_batchers[schur_p] = _schur_batching_rule
ad.primitive_jvps[schur_p] = _schur_jvp_rule
# hessenberg: Upper Hessenberg reduction
def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
"""Reduces a square matrix to upper Hessenberg form.
Args:
a: A floating point or complex square matrix or batch of matrices.
Returns:
A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of ``a``
contain the upper Hessenberg matrix, and the elements below the first
subdiagonal contain the Householder reflectors. ``taus`` contains the scalar
factors of the elementary Householder reflectors; its trailing dimension is
one less than the matrix size.
"""
# Thin wrapper: all validation happens in the primitive's abstract-eval rule.
return hessenberg_p.bind(a)
def _hessenberg_abstract_eval(a):
if a.ndim < 2:
msg = "hessenberg requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
if a.shape[-1] != a.shape[-2]:
msg = ("hessenberg requires the last two dimensions of a to be equal "
"in size, got a.shape of {}.")
raise TypeError(msg.format(a.shape))
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
# Primitive backing `hessenberg`. It produces two outputs (the packed matrix
# and the Householder scalar factors), hence `multiple_results = True`.
hessenberg_p = Primitive("hessenberg")
hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p))
hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
hessenberg_p.multiple_results = True
def _hessenberg_batching_rule(batched_args, batch_dims):
  """Batching (vmap) rule for ``hessenberg_p``.

  The mapped axis is moved to the front, where the primitive treats it as one
  more leading batch dimension.

  Args:
    batched_args: One-element sequence holding the batched operand.
    batch_dims: One-element sequence holding the operand's mapped axis.

  Returns:
    A ``(outputs, out_dims)`` pair with both outputs batched along axis 0.
  """
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  # ``hessenberg_p`` has multiple results, so one output batch dimension is
  # reported per result (as done for other multi-output linalg primitives)
  # rather than a single scalar.
  return hessenberg_p.bind(x), (0, 0)


batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
def _hessenberg_cpu_mhlo(ctx, a):
  """CPU lowering of ``hessenberg_p`` built on ``lapack.gehrd_mhlo``.

  Both outputs are replaced by NaNs for every batch element whose ``info``
  value returned by the LAPACK call is non-zero.
  """
  # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
  if not hasattr(lapack, "gehrd_mhlo"):
    raise RuntimeError("Hessenberg reduction on CPU requires jaxlib 0.3.25 or "
                       "newer")
  (operand_aval,) = ctx.avals_in
  batch_dims = operand_aval.shape[:-2]
  packed, taus, info = lapack.gehrd_mhlo(operand_aval.dtype, a)

  # One boolean per batch element: True where the reduction succeeded.
  info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
  ok = mlir.compare_mhlo(info, mlir.full_like_aval(0, info_aval),
                         "EQ", "SIGNED")

  def _nan_where_failed(value, out_aval, trailing_ones):
    # Broadcast `ok` over the trailing (matrix or vector) dimensions and use
    # it to choose between the computed value and NaN.
    mask_type = ir.RankedTensorType.get(batch_dims + trailing_ones,
                                        ir.IntegerType.get_signless(1))
    mask = mhlo.BroadcastInDimOp(
        mask_type, ok,
        mlir.dense_int_elements(range(len(batch_dims)))).result
    return _broadcasting_select_mhlo(mask, value, _nan_like_mhlo(out_aval))

  return [
      _nan_where_failed(packed, ctx.avals_out[0], (1, 1)),
      _nan_where_failed(taus, ctx.avals_out[1], (1,)),
  ]


mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')
# tridiagonal: reduction of a symmetric (Hermitian) matrix to tridiagonal form
def tridiagonal(a: ArrayLike, *, lower=True) -> Tuple[Array, Array]:
  """Reduces a symmetric/Hermitian matrix to tridiagonal form.

  Args:
    a: A floating point or complex matrix or batch of matrices.
    lower: Describes which triangle of the input matrices to use.
      The other triangle is ignored and not accessed.

  Returns:
    A ``(a, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal
    of matrix (or batch of matrices) ``a`` contain the tridiagonal
    representation, and elements below the first subdiagonal contain the
    elementary Householder reflectors. If ``lower=False``, the diagonal and
    first superdiagonal of the matrix contain the tridiagonal representation,
    and elements above the first superdiagonal contain the elementary
    Householder reflectors. ``taus`` contains the scalar factors of the
    elementary Householder reflectors. Batch elements for which the primitive
    reports a non-zero ``info`` are filled with NaNs in both outputs.
  """
  arr, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)

  # Build a NaN scalar of the output dtype; complex outputs get NaN in both
  # the real and imaginary parts.
  scalar_type = arr.dtype.type
  nan = scalar_type(jnp.nan)
  if jnp.issubdtype(arr.dtype, np.complexfloating):
    nan = nan + scalar_type(jnp.nan * 1j)

  succeeded = info == 0
  arr = jnp.where(succeeded[..., None, None], arr, nan)
  taus = jnp.where(succeeded[..., None], taus, nan)
  return arr, taus
def _tridiagonal_abstract_eval(a, *, lower):
if a.ndim < 2:
msg = "tridiagonal requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
if a.shape[-1] != a.shape[-2]:
msg = ("tridiagonal requires the last two dimensions of a to be equal "
"in size, got a.shape of {}.")
raise TypeError(msg.format(a.shape))
if a.shape[-1] == 0:
msg = ("tridiagonal requires the last two dimensions of a to be non-zero, "
"got a.shape of {}.")
raise TypeError(msg.format(a.shape))
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
ShapedArray(a.shape[:-2], np.int32)]
# Primitive backing `tridiagonal`. Its abstract-eval rule yields three outputs
# (packed matrix, scalar factors, and an int32 info array), so it is marked as
# having multiple results.
tridiagonal_p = Primitive("tridiagonal")
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))
tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
tridiagonal_p.multiple_results = True
def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
  """Batching (vmap) rule for ``tridiagonal_p``.

  The mapped axis is moved to the front, where the primitive treats it as one
  more leading batch dimension.

  Args:
    batched_args: One-element sequence holding the batched operand.
    batch_dims: One-element sequence holding the operand's mapped axis.
    lower: Which triangle of the input the reduction reads; forwarded
      unchanged to the primitive.

  Returns:
    A ``(outputs, out_dims)`` pair with all three primitive outputs
    (matrix, scalar factors, info) batched along axis 0.
  """
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  # Bind the primitive directly rather than calling `tridiagonal`: the public
  # wrapper drops `info` and so returns only two of the three outputs the
  # primitive declares. `lower` must be forwarded, otherwise a vmapped call
  # with lower=False would silently read the lower triangle.
  return tridiagonal_p.bind(x, lower=lower), (0, 0, 0)


batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
def _tridiagonal_cpu_mhlo(ctx, a, *, lower):
  """CPU lowering of ``tridiagonal_p`` built on ``lapack.sytrd_mhlo``.

  Returns the packed matrix, the scalar factors, and the info array; the
  separate diagonal and off-diagonal outputs of the LAPACK wrapper are
  discarded.
  """
  (operand_aval,) = ctx.avals_in
  # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
  if not hasattr(lapack, "sytrd_mhlo"):
    raise RuntimeError("Tridiagonal reduction on CPU requires jaxlib 0.3.25 or "
                       "newer")
  packed, _diag, _offdiag, taus, info = lapack.sytrd_mhlo(
      operand_aval.dtype, a, lower=lower)
  return packed, taus, info


mlir.register_lowering(tridiagonal_p, _tridiagonal_cpu_mhlo, platform='cpu')
# Utilities
def _nan_like_mhlo(aval):

View File

@ -1001,3 +1001,34 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Arra
return T, Z
return lax.fori_loop(1, N, _rsf2scf_iter, (T, Z))
@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = False,
               check_finite: bool = True) -> Array: ...

@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
               check_finite: bool = True) -> Tuple[Array, Array]: ...

# The user-facing docstring is supplied by `_wraps` from the SciPy function,
# so the implementation is documented with comments only.
@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
               check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
  # Accepted for signature compatibility only; neither option has any effect.
  del overwrite_a, check_finite

  n = jnp.shape(a)[-1]
  if n == 0:
    # Empty matrices never reach the primitive; return empty results directly.
    if not calc_q:
      return jnp.zeros_like(a)
    return jnp.zeros_like(a), jnp.zeros_like(a)

  a_out, taus = lax_linalg.hessenberg(a)
  # Keep the upper triangle and first subdiagonal; everything below it stores
  # the Householder reflectors rather than entries of H.
  h = jnp.triu(a_out, -1)
  if not calc_q:
    return h

  # The reflectors live in the trailing (n-1) x (n-1) block, so form their
  # product there and embed it as the lower-right block of
  # [[1, 0], [0, Q']].
  reflectors = a_out[..., 1:, :-1]
  q_inner = lax_linalg.householder_product(reflectors, taus)
  batch_dims = a_out.shape[:-2]
  out_dtype = a_out.dtype
  top_left = jnp.ones(batch_dims + (1, 1), dtype=out_dtype)
  top_right = jnp.zeros(batch_dims + (1, n - 1), dtype=out_dtype)
  bottom_left = jnp.zeros(batch_dims + (n - 1, 1), dtype=out_dtype)
  q = jnp.block([[top_left, top_right],
                 [bottom_left, q_inner]])
  return h, q

View File

@ -1253,7 +1253,9 @@ tf_not_yet_impl = [
"lu_pivots_to_permutation",
"xla_pmap",
"geqrf",
"orgqr",
"householder_product",
"hessenberg",
"tridiagonal",
"eigh_jacobi",
]

View File

@ -19,15 +19,21 @@ from jax._src.lax.linalg import (
eig_p,
eigh,
eigh_p,
hessenberg,
hessenberg_p,
lu,
lu_p,
lu_pivots_to_permutation,
householder_product,
householder_product_p,
qr,
qr_p,
svd,
svd_p,
triangular_solve,
triangular_solve_p,
tridiagonal,
tridiagonal_p,
tridiagonal_solve,
tridiagonal_solve_p,
schur,

View File

@ -22,6 +22,7 @@ from jax._src.scipy.linalg import (
eigh_tridiagonal as eigh_tridiagonal,
expm as expm,
expm_frechet as expm_frechet,
hessenberg as hessenberg,
inv as inv,
lu as lu,
lu_factor as lu_factor,

View File

@ -15,8 +15,8 @@ limitations under the License.
#include <complex>
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/cpu/lapack_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
namespace jax {
@ -127,6 +127,27 @@ void GetLapackKernelsFromScipy() {
ComplexGees<std::complex<double>>::fn =
reinterpret_cast<ComplexGees<std::complex<double>>::FnType*>(
lapack_ptr("zgees"));
Gehrd<float>::fn =
reinterpret_cast<Gehrd<float>::FnType*>(lapack_ptr("sgehrd"));
Gehrd<double>::fn =
reinterpret_cast<Gehrd<double>::FnType*>(lapack_ptr("dgehrd"));
Gehrd<std::complex<float>>::fn =
reinterpret_cast<Gehrd<std::complex<float>>::FnType*>(
lapack_ptr("cgehrd"));
Gehrd<std::complex<double>>::fn =
reinterpret_cast<Gehrd<std::complex<double>>::FnType*>(
lapack_ptr("zgehrd"));
Sytrd<float>::fn =
reinterpret_cast<Sytrd<float>::FnType*>(lapack_ptr("ssytrd"));
Sytrd<double>::fn =
reinterpret_cast<Sytrd<double>::FnType*>(lapack_ptr("dsytrd"));
Sytrd<std::complex<float>>::fn =
reinterpret_cast<Sytrd<std::complex<float>>::FnType*>(
lapack_ptr("chetrd"));
Sytrd<std::complex<double>>::fn =
reinterpret_cast<Sytrd<std::complex<double>>::FnType*>(
lapack_ptr("zhetrd"));
initialized = true;
}
@ -185,6 +206,21 @@ py::dict Registrations() {
EncapsulateFunction(ComplexGees<std::complex<float>>::Kernel);
dict["lapack_zgees"] =
EncapsulateFunction(ComplexGees<std::complex<double>>::Kernel);
dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd<float>::Kernel);
dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd<double>::Kernel);
dict["lapack_cgehrd"] =
EncapsulateFunction(Gehrd<std::complex<float>>::Kernel);
dict["lapack_zgehrd"] =
EncapsulateFunction(Gehrd<std::complex<double>>::Kernel);
dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd<float>::Kernel);
dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd<double>::Kernel);
dict["lapack_chetrd"] =
EncapsulateFunction(Sytrd<std::complex<float>>::Kernel);
dict["lapack_zhetrd"] =
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
return dict;
}
@ -211,6 +247,14 @@ PYBIND11_MODULE(_lapack, m) {
m.def("syevd_iwork_size", &SyevdIworkSize);
m.def("heevd_work_size", &HeevdWorkSize);
m.def("heevd_rwork_size", &HeevdRworkSize);
m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace);
m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace);
m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace);
m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace);
m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace);
m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace);
m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace);
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace);
}
} // namespace

View File

@ -775,4 +775,111 @@ template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>;
// Function pointer to the LAPACK ?gehrd routine; null until it is assigned
// elsewhere during initialization.
template <typename T>
typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;
// Custom-call kernel that runs ?gehrd once per batch element.
// Operands (data[]): 0 = n, 1 = ilo, 2 = ihi, 3 = lda, 4 = batch count,
// 5 = lwork, 6 = input matrices.
// Outputs (out_tuple[]): 0 = matrices (reduced in place), 1 = tau,
// 2 = per-batch info, 3 = scratch workspace.
template <typename T>
void Gehrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int32_t n = *reinterpret_cast<int32_t*>(data[0]);
int32_t ilo = *reinterpret_cast<int32_t*>(data[1]);
int32_t ihi = *reinterpret_cast<int32_t*>(data[2]);
int32_t lda = *reinterpret_cast<int32_t*>(data[3]);
int32_t batch = *reinterpret_cast<int32_t*>(data[4]);
int32_t lwork = *reinterpret_cast<int32_t*>(data[5]);
T* a = reinterpret_cast<T*>(data[6]);
void** out = reinterpret_cast<void**>(out_tuple);
T* a_out = reinterpret_cast<T*>(out[0]);
T* tau = reinterpret_cast<T*>(out[1]);
int* info = reinterpret_cast<int*>(out[2]);
T* work = reinterpret_cast<T*>(out[3]);
// The routine is called on a_out, so copy the input over unless the output
// buffer already is the input buffer.
if (a_out != a) {
std::memcpy(a_out, a,
static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
static_cast<int64_t>(n) * sizeof(T));
}
// Element stride between consecutive matrices in the batch.
int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
for (int i = 0; i < batch; ++i) {
fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info);
a_out += a_plus;
// Each batch element owns n - 1 tau entries and one info entry; the
// workspace is shared across iterations.
tau += n - 1;
++info;
}
}
// Queries ?gehrd for its workspace size by calling it with lwork == -1.
// Returns the size reported by the routine, or -1 if the query sets a
// non-zero info.
template <typename T>
int64_t Gehrd<T>::Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
                            lapack_int ihi) {
  T reported_size = 0;
  lapack_int query_lwork = -1;
  lapack_int status = 0;
  fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &reported_size, &query_lwork,
     &status);
  if (status != 0) {
    return -1;
  }
  // The size comes back in a value of type T, so take its real part.
  return static_cast<int64_t>(std::real(reported_size));
}

template struct Gehrd<float>;
template struct Gehrd<double>;
template struct Gehrd<std::complex<float>>;
template struct Gehrd<std::complex<double>>;
// Function pointer to the LAPACK ?sytrd / ?hetrd routine; null until it is
// assigned elsewhere during initialization.
template <typename T>
typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;
// Custom-call kernel that runs ?sytrd / ?hetrd once per batch element.
// Operands (data[]): 0 = n, 1 = lower flag, 2 = lda, 3 = batch count,
// 4 = lwork, 5 = input matrices.
// Outputs (out_tuple[]): 0 = matrices (reduced in place), 1 = d, 2 = e,
// 3 = tau, 4 = per-batch info, 5 = scratch workspace.
template <typename T>
void Sytrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
int32_t n = *reinterpret_cast<int32_t*>(data[0]);
int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
int32_t lda = *reinterpret_cast<int32_t*>(data[2]);
int32_t batch = *reinterpret_cast<int32_t*>(data[3]);
int32_t lwork = *reinterpret_cast<int32_t*>(data[4]);
T* a = reinterpret_cast<T*>(data[5]);
void** out = reinterpret_cast<void**>(out_tuple);
T* a_out = reinterpret_cast<T*>(out[0]);
// d and e are always written as real values, even when T is complex.
typedef typename real_type<T>::type Real;
Real* d = reinterpret_cast<Real*>(out[1]);
Real* e = reinterpret_cast<Real*>(out[2]);
T* tau = reinterpret_cast<T*>(out[3]);
int* info = reinterpret_cast<int*>(out[4]);
T* work = reinterpret_cast<T*>(out[5]);
// The routine is called on a_out, so copy the input over unless the output
// buffer already is the input buffer.
if (a_out != a) {
std::memcpy(a_out, a,
static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
static_cast<int64_t>(n) * sizeof(T));
}
char cuplo = lower ? 'L' : 'U';
// Element stride between consecutive matrices in the batch.
int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
for (int i = 0; i < batch; ++i) {
fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info);
a_out += a_plus;
// Per batch element: n entries of d, n - 1 of e, n - 1 of tau, one info.
// The workspace is shared across iterations.
d += n;
e += n - 1;
tau += n - 1;
++info;
}
}
// Queries ?sytrd / ?hetrd for its workspace size by calling it with
// lwork == -1 (always with uplo = 'L'). Returns the size reported by the
// routine, or -1 if the query sets a non-zero info.
template <typename T>
int64_t Sytrd<T>::Workspace(lapack_int lda, lapack_int n) {
  char uplo = 'L';
  T reported_size = 0;
  lapack_int query_lwork = -1;
  lapack_int status = 0;
  fn(&uplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &reported_size,
     &query_lwork, &status);
  if (status != 0) {
    return -1;
  }
  // The size comes back in a value of type T, so take its real part.
  return static_cast<int64_t>(std::real(reported_size));
}

template struct Sytrd<float>;
template struct Sytrd<double>;
template struct Sytrd<std::complex<float>>;
template struct Sytrd<std::complex<double>>;
} // namespace jax

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <complex>
#include <cstdint>
#include "tensorflow/compiler/xla/service/custom_call_status.h"
// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
@ -88,8 +89,8 @@ struct RealGesdd {
static FnType* fn;
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
static int64_t Workspace(lapack_int m, lapack_int n,
bool job_opt_compute_uv, bool job_opt_full_matrices);
static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
bool job_opt_full_matrices);
};
lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv);
@ -104,11 +105,10 @@ struct ComplexGesdd {
static FnType* fn;
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
static int64_t Workspace(lapack_int m, lapack_int n,
bool job_opt_compute_uv, bool job_opt_full_matrices);
static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
bool job_opt_full_matrices);
};
lapack_int SyevdWorkSize(int64_t n);
lapack_int SyevdIworkSize(int64_t n);
@ -176,6 +176,45 @@ struct ComplexGees {
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
};
// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form.
// FnType is the signature of the underlying LAPACK routine, `fn` points at
// the routine to call, `Kernel` is the XLA custom-call entry point, and
// `Workspace` returns the workspace size to allocate (-1 if the size query
// fails).
template <typename T>
struct Gehrd {
using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a,
lapack_int* lda, T* tau, T* work, lapack_int* lwork,
lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
lapack_int ihi);
};
// Type trait mapping a scalar type to its real counterpart: T maps to itself,
// and std::complex<T> maps to T.
template <typename T>
struct real_type {
typedef T type;
};
template <typename T>
struct real_type<std::complex<T>> {
typedef T type;
};
// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal
// form.
// The `d` and `e` arguments use the real counterpart of T, so they are real
// even for complex T. `Workspace` returns the workspace size to allocate
// (-1 if the size query fails).
template <typename T>
struct Sytrd {
using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
typename real_type<T>::type* d,
typename real_type<T>::type* e,
T* tau, T* work,
lapack_int* lwork, lapack_int* info);
static FnType* fn;
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
static int64_t Workspace(lapack_int lda, lapack_int n);
};
} // namespace jax
#endif // JAXLIB_CPU_LAPACK_KERNELS_H_

View File

@ -66,6 +66,16 @@ jax::RealGees<double>::FnType dgees_;
jax::ComplexGees<std::complex<float>>::FnType cgees_;
jax::ComplexGees<std::complex<double>>::FnType zgees_;
jax::Gehrd<float>::FnType sgehrd_;
jax::Gehrd<double>::FnType dgehrd_;
jax::Gehrd<std::complex<float>>::FnType cgehrd_;
jax::Gehrd<std::complex<double>>::FnType zgehrd_;
jax::Sytrd<float>::FnType ssytrd_;
jax::Sytrd<double>::FnType dsytrd_;
jax::Sytrd<std::complex<float>>::FnType chetrd_;
jax::Sytrd<std::complex<double>>::FnType zhetrd_;
} // extern "C"
namespace jax {
@ -107,6 +117,15 @@ static auto init = []() -> int {
RealGees<double>::fn = dgees_;
ComplexGees<std::complex<float>>::fn = cgees_;
ComplexGees<std::complex<double>>::fn = zgees_;
Gehrd<float>::fn = sgehrd_;
Gehrd<double>::fn = dgehrd_;
Gehrd<std::complex<float>>::fn = cgehrd_;
Gehrd<std::complex<double>>::fn = zgehrd_;
Sytrd<float>::fn = ssytrd_;
Sytrd<double>::fn = dsytrd_;
Sytrd<std::complex<float>>::fn = chetrd_;
Sytrd<std::complex<double>>::fn = zhetrd_;
return 0;
}();

View File

@ -628,3 +628,117 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
return (out[0], out[3], out[4], out[5])
else:
return (out[0], out[3], out[5])
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
def gehrd_mhlo(dtype, a):
  """Emits a custom call that reduces square matrices to Hessenberg form.

  Args:
    dtype: NumPy dtype of the operand; one of float32, float64, complex64 or
      complex128.
    a: MLIR value holding a square matrix or a batch of square matrices.

  Returns:
    The first three results of the custom call: the reduced matrices (aliased
    with the operand), the Householder scalar factors of shape
    ``batch_dims + (n - 1,)``, and an int32 ``info`` array of shape
    ``batch_dims``. The trailing workspace result is dropped.

  Raises:
    NotImplementedError: If ``dtype`` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n, (m, n)
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  # LAPACK requires lda >= max(1, n). Use the same clamped value for the
  # workspace query and for the call itself (identical to `n` when n >= 1).
  lda = max(1, n)
  if dtype == np.float32:
    fn = b"lapack_sgehrd"
    lwork = _lapack.lapack_sgehrd_workspace(lda, n, 1, n)
  elif dtype == np.float64:
    fn = b"lapack_dgehrd"
    lwork = _lapack.lapack_dgehrd_workspace(lda, n, 1, n)
  elif dtype == np.complex64:
    fn = b"lapack_cgehrd"
    lwork = _lapack.lapack_cgehrd_workspace(lda, n, 1, n)
  elif dtype == np.complex128:
    fn = b"lapack_zgehrd"
    lwork = _lapack.lapack_zgehrd_workspace(lda, n, 1, n)
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  # Matrix dimensions are minor-most, in the order (num_bd, num_bd + 1),
  # followed by the batch dimensions in reverse order.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  out = custom_call(
      fn,
      [
        a.type,
        ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
      ],
      # Scalar operands, in order: n, ilo (= 1), ihi (= n), lda, batch count,
      # lwork; the matrix operand comes last.
      [_mhlo_s32(n), _mhlo_s32(1), _mhlo_s32(n), _mhlo_s32(lda),
       _mhlo_s32(b), _mhlo_s32(lwork), a],
      operand_layouts=[[]] * 6 + [layout],
      result_layouts=[
        layout,
        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
        tuple(range(num_bd - 1, -1, -1)),
        [0],
      ],
      # The matrix operand (index 6) is aliased with result 0.
      operand_output_aliases={6: 0},
  )
  return out[:3]
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
def sytrd_mhlo(dtype, a, *, lower):
  """Emits a custom call that reduces symmetric/Hermitian matrices to
  tridiagonal form.

  Args:
    dtype: NumPy dtype of the operand; one of float32, float64, complex64 or
      complex128.
    a: MLIR value holding a square matrix or a batch of square matrices.
    lower: If true, the kernel is told to use the lower triangle ('L');
      otherwise the upper triangle ('U').

  Returns:
    The first five results of the custom call: the reduced matrices (aliased
    with the operand), the diagonal ``d`` of shape ``batch_dims + (n,)``, the
    off-diagonal ``e`` of shape ``batch_dims + (n - 1,)``, the Householder
    scalar factors of shape ``batch_dims + (n - 1,)``, and an int32 ``info``
    array of shape ``batch_dims``. ``d`` and ``e`` are real-valued even for
    complex operands. The trailing workspace result is dropped.

  Raises:
    NotImplementedError: If ``dtype`` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n, (m, n)
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  # `diag_type` is the element type of the d and e outputs. The kernel writes
  # them as real values, so for the complex (hetrd) variants it must be the
  # real type of matching precision, not the complex element type.
  if dtype == np.float32:
    fn = b"lapack_ssytrd"
    lwork = _lapack.lapack_ssytrd_workspace(n, n)
    diag_type = a_type.element_type
  elif dtype == np.float64:
    fn = b"lapack_dsytrd"
    lwork = _lapack.lapack_dsytrd_workspace(n, n)
    diag_type = a_type.element_type
  elif dtype == np.complex64:
    fn = b"lapack_chetrd"
    lwork = _lapack.lapack_chetrd_workspace(n, n)
    diag_type = ir.F32Type.get()
  elif dtype == np.complex128:
    fn = b"lapack_zhetrd"
    lwork = _lapack.lapack_zhetrd_workspace(n, n)
    diag_type = ir.F64Type.get()
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  # Matrix dimensions are minor-most, in the order (num_bd, num_bd + 1),
  # followed by the batch dimensions in reverse order.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  out = custom_call(
      fn,
      [
        a.type,
        ir.RankedTensorType.get(batch_dims + (n,), diag_type),
        ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
        ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
        ir.RankedTensorType.get(batch_dims, i32_type),
        ir.RankedTensorType.get([lwork], a_type.element_type),
      ],
      # Scalar operands, in order: n, lower flag, lda, batch count, lwork;
      # the matrix operand comes last.
      [_mhlo_s32(n), _mhlo_s32(1 if lower else 0), _mhlo_s32(max(1, n)),
       _mhlo_s32(b), _mhlo_s32(lwork), a],
      operand_layouts=[[]] * 5 + [layout],
      result_layouts=[
        layout,
        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
        tuple(range(num_bd - 1, -1, -1)),
        [0],
      ],
      # The matrix operand (index 5) is aliased with result 0.
      operand_output_aliases={5: 0},
  )
  return out[:5]

View File

@ -15,6 +15,7 @@
"""Tests for the LAPAX linear algebra module."""
from functools import partial
import unittest
import numpy as np
import scipy
@ -28,6 +29,7 @@ from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src.lib import version as jaxlib_version
from jax._src import test_util as jtu
from jax.config import config
@ -1244,6 +1246,96 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
self._CompileAndCheck(jsp_func, args_maker)
# Compares lax.linalg.householder_product against SciPy's LAPACK ?orgqr/?ungqr
# wrappers for every reflector count k from 1 up to the number of columns.
@jtu.sample_product(
[dict(shape=shape, k=k)
for shape in [(1, 1), (3, 4, 4), (10, 5)]
# TODO(phawkins): there are some test failures on GPU for k=0
for k in range(1, shape[-1] + 1)],
dtype=float_types + complex_types,
)
def testHouseholderProduct(self, shape, k, dtype):
# np.vectorize applies the per-matrix reference over leading batch dimensions.
@partial(np.vectorize, signature='(m,n),(k)->(m,n)')
def reference_fn(a, taus):
# Pick the LAPACK routine matching the dtype under test.
if dtype == np.float32:
q, _, info = scipy.linalg.lapack.sorgqr(a, taus)
elif dtype == np.float64:
q, _, info = scipy.linalg.lapack.dorgqr(a, taus)
elif dtype == np.complex64:
q, _, info = scipy.linalg.lapack.cungqr(a, taus)
elif dtype == np.complex128:
q, _, info = scipy.linalg.lapack.zungqr(a, taus)
else:
assert False, dtype
assert info == 0, info
return q
rng = jtu.rand_default(self.rng())
# Inputs: a random matrix batch and a matching batch of k scalar factors.
args_maker = lambda: [rng(shape, dtype), rng(shape[:-2] + (k,), dtype)]
tol = {np.float32: 1e-5, np.complex64: 1e-5, np.float64: 1e-12,
np.complex128: 1e-12}
self._CheckAgainstNumpy(reference_fn, lax.linalg.householder_product,
args_maker, rtol=tol, atol=tol)
self._CompileAndCheck(lax.linalg.householder_product, args_maker)
# Compares jax.scipy.linalg.hessenberg against scipy.linalg.hessenberg, with
# and without computing Q. Skipped on GPU/TPU and on jaxlib older than 0.3.25.
@jtu.sample_product(
shape=[(1, 1), (2, 4, 4), (0, 100, 100), (10, 10)],
dtype=float_types + complex_types,
calc_q=[False, True],
)
@jtu.skip_on_devices("gpu", "tpu")
@unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
def testHessenberg(self, shape, dtype, calc_q):
rng = jtu.rand_default(self.rng())
jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)
# The SciPy reference is vectorized over leading batch dimensions; otypes is
# given explicitly for each output.
if calc_q:
sp_func = np.vectorize(partial(scipy.linalg.hessenberg, calc_q=True),
otypes=(dtype, dtype),
signature='(n,n)->(n,n),(n,n)')
else:
sp_func = np.vectorize(scipy.linalg.hessenberg, signature='(n,n)->(n,n)',
otypes=(dtype,))
args_maker = lambda: [rng(shape, dtype)]
# scipy.linalg.hessenberg sometimes returns a float Q matrix for complex
# inputs
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1e-5, atol=1e-5,
check_dtypes=not calc_q)
self._CompileAndCheck(jsp_func, args_maker)
# Compares lax.linalg.tridiagonal against SciPy's LAPACK ?sytrd/?hetrd
# wrappers for both triangles. Skipped on GPU/TPU and on jaxlib older than
# 0.3.25.
@jtu.sample_product(
shape=[(1, 1), (4, 4), (10, 10)],
dtype=float_types + complex_types,
lower=[False, True],
)
@unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
@jtu.skip_on_devices("gpu", "tpu")
def testTridiagonal(self, shape, dtype, lower):
rng = jtu.rand_default(self.rng())
def jax_func(a):
return lax.linalg.tridiagonal(a, lower=lower)
@partial(np.vectorize, otypes=(dtype, dtype),
signature='(n,n)->(n,n),(k)')
def sp_func(a):
# Pick the LAPACK routine matching the dtype under test.
if dtype == np.float32:
c, d, e, tau, info = scipy.linalg.lapack.ssytrd(a, lower=lower)
elif dtype == np.float64:
c, d, e, tau, info = scipy.linalg.lapack.dsytrd(a, lower=lower)
elif dtype == np.complex64:
c, d, e, tau, info = scipy.linalg.lapack.chetrd(a, lower=lower)
elif dtype == np.complex128:
c, d, e, tau, info = scipy.linalg.lapack.zhetrd(a, lower=lower)
else:
assert False, dtype
# Only the packed matrix and the scalar factors are compared.
del d, e
assert info == 0
return c, tau
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-5, atol=1e-5,
check_dtypes=False)
@jtu.sample_product(
n=[1, 4, 5, 20, 50, 100],
dtype=float_types + complex_types,