Add support for Hessenberg and tridiagonal matrix reductions on CPU.

* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
This commit is contained in:
Peter Hawkins 2022-11-09 06:23:22 -08:00 committed by jax authors
parent 30637d052b
commit 1cead779a3
14 changed files with 650 additions and 40 deletions

View File

@ -9,8 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.3.25 ## jax 0.3.25
* Changes * Changes
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option. * {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
* {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires
jaxlib > 0.3.24.
* New functions {func}`jax.lax.linalg.hessenberg`,
{func}`jax.lax.linalg.tridiagonal`, and
{func}`jax.lax.linalg.householder_product` were added. Hessenberg and
tridiagonal reductions are supported on CPU only.
## jaxlib 0.3.25 ## jaxlib 0.3.25
* Changes
* Added support for upper Hessenberg and tridiagonal reductions on CPU.
## jax 0.3.24 (Nov 4, 2022) ## jax 0.3.24 (Nov 4, 2022)
* Changes * Changes

View File

@ -205,7 +205,9 @@ Linear algebra operators (jax.lax.linalg)
cholesky cholesky
eig eig
eigh eigh
hessenberg
lu lu
householder_product
qdwh qdwh
qr qr
schur schur

View File

@ -30,6 +30,7 @@ jax.scipy.linalg
expm expm
expm_frechet expm_frechet
funm funm
hessenberg
inv inv
lu lu
lu_factor lu_factor

View File

@ -753,11 +753,11 @@ mlir.register_lowering(
platform='tpu') platform='tpu')
triangular_solve_dtype_rule = partial( _triangular_solve_dtype_rule = partial(
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex), naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
'triangular_solve') 'triangular_solve')
def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs): def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
if a.ndim < 2: if a.ndim < 2:
msg = "triangular_solve requires a.ndim to be at least 2, got {}." msg = "triangular_solve requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim)) raise TypeError(msg.format(a.ndim))
@ -778,7 +778,7 @@ def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
raise TypeError(msg.format(a.shape, b.shape)) raise TypeError(msg.format(a.shape, b.shape))
return b.shape return b.shape
def triangular_solve_jvp_rule_a( def _triangular_solve_jvp_rule_a(
g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a, g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
unit_diagonal): unit_diagonal):
m, n = b.shape[-2:] m, n = b.shape[-2:]
@ -811,7 +811,7 @@ def triangular_solve_jvp_rule_a(
else: else:
return dot(ans, a_inverse(g_a)) # X (∂A A^{-1}) return dot(ans, a_inverse(g_a)) # X (∂A A^{-1})
def triangular_solve_transpose_rule( def _triangular_solve_transpose_rule(
cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a, cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
unit_diagonal): unit_diagonal):
# Triangular solve is nonlinear in its first argument and linear in its second # Triangular solve is nonlinear in its first argument and linear in its second
@ -827,7 +827,7 @@ def triangular_solve_transpose_rule(
return [None, cotangent_b] return [None, cotangent_b]
def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side, def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
lower, transpose_a, conjugate_a, lower, transpose_a, conjugate_a,
unit_diagonal): unit_diagonal):
x, y = batched_args x, y = batched_args
@ -856,13 +856,13 @@ def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
unit_diagonal=unit_diagonal), 0 unit_diagonal=unit_diagonal), 0
triangular_solve_p = standard_primitive( triangular_solve_p = standard_primitive(
triangular_solve_shape_rule, triangular_solve_dtype_rule, _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
'triangular_solve') 'triangular_solve')
ad.defjvp2(triangular_solve_p, ad.defjvp2(triangular_solve_p,
triangular_solve_jvp_rule_a, _triangular_solve_jvp_rule_a,
lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws)) lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule
def _triangular_solve_lowering( def _triangular_solve_lowering(
@ -1363,9 +1363,9 @@ mlir.register_lowering(
platform='rocm') platform='rocm')
# orgqr: product of elementary Householder reflectors # householder_product: product of elementary Householder reflectors
def orgqr(a: ArrayLike, taus: ArrayLike) -> Array: def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
"""Product of elementary Householder reflectors. """Product of elementary Householder reflectors.
Args: Args:
@ -1378,31 +1378,34 @@ def orgqr(a: ArrayLike, taus: ArrayLike) -> Array:
A batch of orthogonal (unitary) matrices with the same shape as ``a``, A batch of orthogonal (unitary) matrices with the same shape as ``a``,
containing the products of the elementary Householder reflectors. containing the products of the elementary Householder reflectors.
""" """
return orgqr_p.bind(a, taus) return householder_product_p.bind(a, taus)
def _orgqr_abstract_eval(a, taus): def _householder_product_abstract_eval(a, taus):
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray): if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
raise NotImplementedError("Unsupported aval in orgqr_abstract_eval: " raise NotImplementedError("Unsupported aval in householder_product_abstract_eval: "
f"{a.aval} {taus.aval}") f"{a.aval} {taus.aval}")
if a.ndim < 2: if a.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2") raise ValueError("Argument to Householder product must have ndims >= 2")
*batch_dims, m, n = a.shape *batch_dims, m, n = a.shape
*taus_batch_dims, k = taus.shape *taus_batch_dims, k = taus.shape
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n): if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
raise ValueError(f"Type mismatch for orgqr: a={a} taus={taus}") raise ValueError(f"Type mismatch for Householder product: a={a} taus={taus}")
if m < n:
raise ValueError("Householder product inputs must have at least as many "
f"rows as columns, got shape {a.shape}")
return a return a
def _orgqr_batching_rule(batched_args, batch_dims): def _householder_product_batching_rule(batched_args, batch_dims):
a, taus = batched_args a, taus = batched_args
b_a, b_taus, = batch_dims b_a, b_taus, = batch_dims
return orgqr(batching.moveaxis(a, b_a, 0), return householder_product(batching.moveaxis(a, b_a, 0),
batching.moveaxis(taus, b_taus, 0)), (0,) batching.moveaxis(taus, b_taus, 0)), (0,)
def _orgqr_translation_rule(ctx, avals_in, avals_out, a, taus): def _householder_product_translation_rule(ctx, avals_in, avals_out, a, taus):
return [xops.ProductOfElementaryHouseholderReflectors(a, taus)] return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]
def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus): def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
a_aval, _ = ctx.avals_in a_aval, _ = ctx.avals_in
*batch_dims, m, n = a_aval.shape *batch_dims, m, n = a_aval.shape
@ -1420,22 +1423,23 @@ def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
return [a] return [a]
orgqr_p = Primitive('orgqr') householder_product_p = Primitive('householder_product')
orgqr_p.def_impl(partial(xla.apply_primitive, orgqr_p)) householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p))
orgqr_p.def_abstract_eval(_orgqr_abstract_eval) householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
batching.primitive_batchers[orgqr_p] = _orgqr_batching_rule batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
xla.register_translation(orgqr_p, _orgqr_translation_rule) xla.register_translation(householder_product_p, _householder_product_translation_rule)
mlir.register_lowering( mlir.register_lowering(
orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo), householder_product_p,
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo),
platform='cpu') platform='cpu')
mlir.register_lowering( mlir.register_lowering(
orgqr_p, householder_product_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr), partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
platform='cuda') platform='cuda')
mlir.register_lowering( mlir.register_lowering(
orgqr_p, householder_product_p,
partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr), partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
platform='rocm') platform='rocm')
@ -1492,13 +1496,13 @@ def _qr_lowering(a, *, full_matrices):
r, taus = geqrf(a) r, taus = geqrf(a)
if m < n: if m < n:
q = orgqr(r[..., :m, :m], taus) q = householder_product(r[..., :m, :m], taus)
elif full_matrices: elif full_matrices:
pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)] pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
q = lax.pad(r, lax_internal._zero(r), pads) q = lax.pad(r, lax_internal._zero(r), pads)
q = orgqr(q, taus) q = householder_product(q, taus)
else: else:
q = orgqr(r, taus) q = householder_product(r, taus)
r = r[..., :n, :n] r = r[..., :n, :n]
r = jnp.triu(r) r = jnp.triu(r)
return q, r return q, r
@ -1920,6 +1924,146 @@ batching.primitive_batchers[schur_p] = _schur_batching_rule
ad.primitive_jvps[schur_p] = _schur_jvp_rule ad.primitive_jvps[schur_p] = _schur_jvp_rule
# hessenberg: Upper Hessenberg reduction
def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
  """Reduces a square matrix to upper Hessenberg form.

  Args:
    a: A floating point or complex square matrix or batch of matrices.

  Returns:
    A ``(a, taus)`` pair. The upper triangle and first subdiagonal of ``a``
    contain the upper Hessenberg matrix, and the elements below the first
    subdiagonal contain the Householder reflectors. ``taus`` contains the
    scalar factors of the elementary Householder reflectors.
  """
  return hessenberg_p.bind(a)
def _hessenberg_abstract_eval(a):
if a.ndim < 2:
msg = "hessenberg requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
if a.shape[-1] != a.shape[-2]:
msg = ("hessenberg requires the last two dimensions of a to be equal "
"in size, got a.shape of {}.")
raise TypeError(msg.format(a.shape))
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
hessenberg_p = Primitive("hessenberg")
hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p))
hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
hessenberg_p.multiple_results = True
def _hessenberg_batching_rule(batched_args, batch_dims):
  """Batching rule: move the mapped axis to the front and re-apply."""
  (operand,) = batched_args
  (mapped_axis,) = batch_dims
  # With the mapped axis leading, it is just one more batch dimension of the
  # operand, so the outputs are mapped along axis 0.
  return hessenberg(batching.moveaxis(operand, mapped_axis, 0)), 0
batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
def _hessenberg_cpu_mhlo(ctx, a):
  """CPU lowering of ``hessenberg_p`` through ``lapack.gehrd_mhlo``.

  For every batch element whose ``info`` result is non-zero, both outputs are
  replaced by the values produced by ``_nan_like_mhlo``.
  """
  # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
  if not hasattr(lapack, "gehrd_mhlo"):
    raise RuntimeError("Hessenberg reduction on CPU requires jaxlib 0.3.25 or "
                       "newer")
  (operand_aval,) = ctx.avals_in
  batch_shape = operand_aval.shape[:-2]
  batch_axes = range(len(batch_shape))
  reduced, taus, info = lapack.gehrd_mhlo(operand_aval.dtype, a)
  ok = mlir.compare_mhlo(
      info,
      mlir.full_like_aval(0, ShapedArray(batch_shape, np.dtype(np.int32))),
      "EQ", "SIGNED")

  def nan_unless_ok(value, out_aval, trailing_ones):
    # Broadcast the per-batch success flag over the trailing dimensions of
    # ``value`` and select between ``value`` and the NaN fill.
    mask = mhlo.BroadcastInDimOp(
        ir.RankedTensorType.get(batch_shape + trailing_ones,
                                ir.IntegerType.get_signless(1)),
        ok, mlir.dense_int_elements(batch_axes)).result
    return _broadcasting_select_mhlo(mask, value, _nan_like_mhlo(out_aval))

  return [
      nan_unless_ok(reduced, ctx.avals_out[0], (1, 1)),
      nan_unless_ok(taus, ctx.avals_out[1], (1,)),
  ]
mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')
# tridiagonal: Reduction of a symmetric (Hermitian) matrix to tridiagonal form
def tridiagonal(a: ArrayLike, *, lower=True) -> Tuple[Array, Array]:
  """Reduces a symmetric/Hermitian matrix to tridiagonal form.

  Args:
    a: A floating point or complex matrix or batch of matrices.
    lower: Describes which triangle of the input matrices to use.
      The other triangle is ignored and not accessed.

  Returns:
    A ``(a, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal
    of matrix (or batch of matrices) ``a`` contain the tridiagonal
    representation, and elements below the first subdiagonal contain the
    elementary Householder reflectors. If ``lower=False``, the diagonal and
    first superdiagonal contain the tridiagonal representation, and elements
    above the first superdiagonal contain the elementary Householder
    reflectors. ``taus`` contains the scalar factors of the elementary
    Householder reflectors. Batch elements for which the primitive reports a
    non-zero ``info`` are filled with NaN in both outputs.
  """
  reduced, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
  scalar_type = reduced.dtype.type
  nan = scalar_type(jnp.nan)
  if jnp.issubdtype(reduced.dtype, np.complexfloating):
    # Make both the real and the imaginary part NaN for complex outputs.
    nan = nan + scalar_type(jnp.nan * 1j)
  succeeded = info == 0
  reduced = jnp.where(succeeded[..., None, None], reduced, nan)
  taus = jnp.where(succeeded[..., None], taus, nan)
  return reduced, taus
def _tridiagonal_abstract_eval(a, *, lower):
if a.ndim < 2:
msg = "tridiagonal requires a.ndim to be at least 2, got {}."
raise TypeError(msg.format(a.ndim))
if a.shape[-1] != a.shape[-2]:
msg = ("tridiagonal requires the last two dimensions of a to be equal "
"in size, got a.shape of {}.")
raise TypeError(msg.format(a.shape))
if a.shape[-1] == 0:
msg = ("tridiagonal requires the last two dimensions of a to be non-zero, "
"got a.shape of {}.")
raise TypeError(msg.format(a.shape))
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
ShapedArray(a.shape[:-2], np.int32)]
tridiagonal_p = Primitive("tridiagonal")
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))
tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
tridiagonal_p.multiple_results = True
def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
  """Batching rule for ``tridiagonal_p``.

  Moves the mapped axis of the operand to the front, where it acts as one
  more batch dimension, and re-binds the primitive.

  Args:
    batched_args: one-element sequence holding the operand.
    batch_dims: one-element sequence holding the operand's mapped axis.
    lower: which triangle of the input to use; forwarded unchanged.

  Returns:
    The primitive's three outputs ``(a, taus, info)`` and their mapped axes,
    all 0.
  """
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  # Bind the primitive rather than calling the public ``tridiagonal`` wrapper:
  # the primitive has three results, the wrapper returns only two. ``lower``
  # must be forwarded, otherwise a vmapped call would always use the default
  # triangle.
  return tridiagonal_p.bind(x, lower=lower), (0, 0, 0)
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
def _tridiagonal_cpu_mhlo(ctx, a, *, lower):
  """CPU lowering of ``tridiagonal_p`` through ``lapack.sytrd_mhlo``.

  Returns the reduced matrix, ``taus`` and ``info``; the separate diagonal and
  off-diagonal results of the LAPACK call are discarded.
  """
  (operand_aval,) = ctx.avals_in
  # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
  if not hasattr(lapack, "sytrd_mhlo"):
    raise RuntimeError("Tridiagonal reduction on CPU requires jaxlib 0.3.25 or "
                       "newer")
  reduced, _diag, _offdiag, taus, info = lapack.sytrd_mhlo(
      operand_aval.dtype, a, lower=lower)
  return reduced, taus, info
mlir.register_lowering(tridiagonal_p, _tridiagonal_cpu_mhlo, platform='cpu')
# Utilities # Utilities
def _nan_like_mhlo(aval): def _nan_like_mhlo(aval):

View File

@ -1001,3 +1001,34 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Arra
return T, Z return T, Z
return lax.fori_loop(1, N, _rsf2scf_iter, (T, Z)) return lax.fori_loop(1, N, _rsf2scf_iter, (T, Z))
# The ``calc_q=False`` overload carries the default so that a plain
# ``hessenberg(a)`` call resolves to the single-array return type.
@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[False] = False,
               overwrite_a: bool = False,
               check_finite: bool = True) -> Array: ...

@overload
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
               check_finite: bool = True) -> Tuple[Array, Array]: ...


@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
               check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
  # ``overwrite_a`` and ``check_finite`` are accepted for signature
  # compatibility only; they have no effect here.
  del overwrite_a, check_finite
  n = jnp.shape(a)[-1]
  if n == 0:
    # Empty input: skip the primitive and return empty result(s) shaped like
    # ``a``.
    if calc_q:
      return jnp.zeros_like(a), jnp.zeros_like(a)
    else:
      return jnp.zeros_like(a)
  a_out, taus = lax_linalg.hessenberg(a)
  # Entries below the first subdiagonal hold the Householder reflectors, not
  # part of H, so zero them out.
  h = jnp.triu(a_out, -1)
  if calc_q:
    # Multiply out the reflectors stored in ``a_out`` without its first row
    # and last column, giving an (n-1) x (n-1) block.
    q = lax_linalg.householder_product(a_out[..., 1:, :-1], taus)
    batch_dims = a_out.shape[:-2]
    # Embed that block as [[1, 0], [0, q]] to obtain the full n x n matrix.
    q = jnp.block([[jnp.ones(batch_dims + (1, 1), dtype=a_out.dtype),
                    jnp.zeros(batch_dims + (1, n - 1), dtype=a_out.dtype)],
                   [jnp.zeros(batch_dims + (n - 1, 1), dtype=a_out.dtype), q]])
    return h, q
  else:
    return h

View File

@ -1253,7 +1253,9 @@ tf_not_yet_impl = [
"lu_pivots_to_permutation", "lu_pivots_to_permutation",
"xla_pmap", "xla_pmap",
"geqrf", "geqrf",
"orgqr", "householder_product",
"hessenberg",
"tridiagonal",
"eigh_jacobi", "eigh_jacobi",
] ]

View File

@ -19,15 +19,21 @@ from jax._src.lax.linalg import (
eig_p, eig_p,
eigh, eigh,
eigh_p, eigh_p,
hessenberg,
hessenberg_p,
lu, lu,
lu_p, lu_p,
lu_pivots_to_permutation, lu_pivots_to_permutation,
householder_product,
householder_product_p,
qr, qr,
qr_p, qr_p,
svd, svd,
svd_p, svd_p,
triangular_solve, triangular_solve,
triangular_solve_p, triangular_solve_p,
tridiagonal,
tridiagonal_p,
tridiagonal_solve, tridiagonal_solve,
tridiagonal_solve_p, tridiagonal_solve_p,
schur, schur,

View File

@ -22,6 +22,7 @@ from jax._src.scipy.linalg import (
eigh_tridiagonal as eigh_tridiagonal, eigh_tridiagonal as eigh_tridiagonal,
expm as expm, expm as expm,
expm_frechet as expm_frechet, expm_frechet as expm_frechet,
hessenberg as hessenberg,
inv as inv, inv as inv,
lu as lu, lu as lu,
lu_factor as lu_factor, lu_factor as lu_factor,

View File

@ -15,8 +15,8 @@ limitations under the License.
#include <complex> #include <complex>
#include "jaxlib/kernel_pybind11_helpers.h"
#include "jaxlib/cpu/lapack_kernels.h" #include "jaxlib/cpu/lapack_kernels.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h" #include "include/pybind11/pybind11.h"
namespace jax { namespace jax {
@ -127,6 +127,27 @@ void GetLapackKernelsFromScipy() {
ComplexGees<std::complex<double>>::fn = ComplexGees<std::complex<double>>::fn =
reinterpret_cast<ComplexGees<std::complex<double>>::FnType*>( reinterpret_cast<ComplexGees<std::complex<double>>::FnType*>(
lapack_ptr("zgees")); lapack_ptr("zgees"));
Gehrd<float>::fn =
reinterpret_cast<Gehrd<float>::FnType*>(lapack_ptr("sgehrd"));
Gehrd<double>::fn =
reinterpret_cast<Gehrd<double>::FnType*>(lapack_ptr("dgehrd"));
Gehrd<std::complex<float>>::fn =
reinterpret_cast<Gehrd<std::complex<float>>::FnType*>(
lapack_ptr("cgehrd"));
Gehrd<std::complex<double>>::fn =
reinterpret_cast<Gehrd<std::complex<double>>::FnType*>(
lapack_ptr("zgehrd"));
Sytrd<float>::fn =
reinterpret_cast<Sytrd<float>::FnType*>(lapack_ptr("ssytrd"));
Sytrd<double>::fn =
reinterpret_cast<Sytrd<double>::FnType*>(lapack_ptr("dsytrd"));
Sytrd<std::complex<float>>::fn =
reinterpret_cast<Sytrd<std::complex<float>>::FnType*>(
lapack_ptr("chetrd"));
Sytrd<std::complex<double>>::fn =
reinterpret_cast<Sytrd<std::complex<double>>::FnType*>(
lapack_ptr("zhetrd"));
initialized = true; initialized = true;
} }
@ -185,6 +206,21 @@ py::dict Registrations() {
EncapsulateFunction(ComplexGees<std::complex<float>>::Kernel); EncapsulateFunction(ComplexGees<std::complex<float>>::Kernel);
dict["lapack_zgees"] = dict["lapack_zgees"] =
EncapsulateFunction(ComplexGees<std::complex<double>>::Kernel); EncapsulateFunction(ComplexGees<std::complex<double>>::Kernel);
dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd<float>::Kernel);
dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd<double>::Kernel);
dict["lapack_cgehrd"] =
EncapsulateFunction(Gehrd<std::complex<float>>::Kernel);
dict["lapack_zgehrd"] =
EncapsulateFunction(Gehrd<std::complex<double>>::Kernel);
dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd<float>::Kernel);
dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd<double>::Kernel);
dict["lapack_chetrd"] =
EncapsulateFunction(Sytrd<std::complex<float>>::Kernel);
dict["lapack_zhetrd"] =
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
return dict; return dict;
} }
@ -211,6 +247,14 @@ PYBIND11_MODULE(_lapack, m) {
m.def("syevd_iwork_size", &SyevdIworkSize); m.def("syevd_iwork_size", &SyevdIworkSize);
m.def("heevd_work_size", &HeevdWorkSize); m.def("heevd_work_size", &HeevdWorkSize);
m.def("heevd_rwork_size", &HeevdRworkSize); m.def("heevd_rwork_size", &HeevdRworkSize);
m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace);
m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace);
m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace);
m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace);
m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace);
m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace);
m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace);
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace);
} }
} // namespace } // namespace

View File

@ -775,4 +775,111 @@ template struct RealGees<double>;
template struct ComplexGees<std::complex<float>>; template struct ComplexGees<std::complex<float>>;
template struct ComplexGees<std::complex<double>>; template struct ComplexGees<std::complex<double>>;
template <typename T>
typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;
// XLA custom-call kernel that runs the LAPACK routine stored in `fn` on each
// matrix of a batch.
//
// Operands (`data`): [0] n, [1] ilo, [2] ihi, [3] lda, [4] batch, [5] lwork
// (all read as int32 scalars), [6] the input matrices.
// Results (`out_tuple`): [0] matrices, updated in place, [1] tau (n - 1
// entries per matrix), [2] info (one int per matrix), [3] scratch workspace.
template <typename T>
void Gehrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int32_t n = *reinterpret_cast<int32_t*>(data[0]);
  int32_t ilo = *reinterpret_cast<int32_t*>(data[1]);
  int32_t ihi = *reinterpret_cast<int32_t*>(data[2]);
  int32_t lda = *reinterpret_cast<int32_t*>(data[3]);
  int32_t batch = *reinterpret_cast<int32_t*>(data[4]);
  int32_t lwork = *reinterpret_cast<int32_t*>(data[5]);
  T* a = reinterpret_cast<T*>(data[6]);
  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  T* tau = reinterpret_cast<T*>(out[1]);
  int* info = reinterpret_cast<int*>(out[2]);
  T* work = reinterpret_cast<T*>(out[3]);
  // The routine works in place on a_out; copy the input over unless the
  // operand and result share the same buffer.
  if (a_out != a) {
    std::memcpy(a_out, a,
                static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }
  // Number of elements between consecutive matrices of the batch.
  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
  // The same `work` buffer is reused for every batch element.
  for (int i = 0; i < batch; ++i) {
    fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info);
    a_out += a_plus;
    tau += n - 1;
    ++info;
  }
}
// Asks the routine in `fn` for its workspace size by calling it with
// lwork = -1 and null matrix/tau pointers (the LAPACK workspace-query
// convention: the size is reported through `work`).
// Returns the real part of the reported value, or -1 if info is non-zero.
template <typename T>
int64_t Gehrd<T>::Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
                            lapack_int ihi) {
  T work = 0;
  lapack_int lwork = -1;
  lapack_int info = 0;
  fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info);
  return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
}
template struct Gehrd<float>;
template struct Gehrd<double>;
template struct Gehrd<std::complex<float>>;
template struct Gehrd<std::complex<double>>;
template <typename T>
typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;
// XLA custom-call kernel that runs the LAPACK routine stored in `fn` on each
// matrix of a batch.
//
// Operands (`data`): [0] n, [1] lower (non-zero selects uplo 'L', zero
// selects 'U'), [2] lda, [3] batch, [4] lwork (all read as int32 scalars),
// [5] the input matrices.
// Results (`out_tuple`): [0] matrices, updated in place, [1] d (n entries per
// matrix), [2] e (n - 1 entries per matrix), [3] tau (n - 1 entries per
// matrix), [4] info (one int per matrix), [5] scratch workspace.
// d and e use the real type underlying T.
template <typename T>
void Sytrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
  int32_t n = *reinterpret_cast<int32_t*>(data[0]);
  int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
  int32_t lda = *reinterpret_cast<int32_t*>(data[2]);
  int32_t batch = *reinterpret_cast<int32_t*>(data[3]);
  int32_t lwork = *reinterpret_cast<int32_t*>(data[4]);
  T* a = reinterpret_cast<T*>(data[5]);
  void** out = reinterpret_cast<void**>(out_tuple);
  T* a_out = reinterpret_cast<T*>(out[0]);
  typedef typename real_type<T>::type Real;
  Real* d = reinterpret_cast<Real*>(out[1]);
  Real* e = reinterpret_cast<Real*>(out[2]);
  T* tau = reinterpret_cast<T*>(out[3]);
  int* info = reinterpret_cast<int*>(out[4]);
  T* work = reinterpret_cast<T*>(out[5]);
  // The routine works in place on a_out; copy the input over unless the
  // operand and result share the same buffer.
  if (a_out != a) {
    std::memcpy(a_out, a,
                static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
                    static_cast<int64_t>(n) * sizeof(T));
  }
  char cuplo = lower ? 'L' : 'U';
  // Number of elements between consecutive matrices of the batch.
  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
  // The same `work` buffer is reused for every batch element.
  for (int i = 0; i < batch; ++i) {
    fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info);
    a_out += a_plus;
    d += n;
    e += n - 1;
    tau += n - 1;
    ++info;
  }
}
// Asks the routine in `fn` for its workspace size by calling it with
// lwork = -1 and null array pointers (the LAPACK workspace-query convention:
// the size is reported through `work`). The query is always issued with
// uplo = 'L'.
// Returns the real part of the reported value, or -1 if info is non-zero.
template <typename T>
int64_t Sytrd<T>::Workspace(lapack_int lda, lapack_int n) {
  char cuplo = 'L';
  T work = 0;
  lapack_int lwork = -1;
  lapack_int info = 0;
  fn(&cuplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &work, &lwork,
     &info);
  return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
}
template struct Sytrd<float>;
template struct Sytrd<double>;
template struct Sytrd<std::complex<float>>;
template struct Sytrd<std::complex<double>>;
} // namespace jax } // namespace jax

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <complex> #include <complex>
#include <cstdint> #include <cstdint>
#include "tensorflow/compiler/xla/service/custom_call_status.h" #include "tensorflow/compiler/xla/service/custom_call_status.h"
// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either // Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
@ -88,8 +89,8 @@ struct RealGesdd {
static FnType* fn; static FnType* fn;
static void Kernel(void* out, void** data, XlaCustomCallStatus*); static void Kernel(void* out, void** data, XlaCustomCallStatus*);
static int64_t Workspace(lapack_int m, lapack_int n, static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
bool job_opt_compute_uv, bool job_opt_full_matrices); bool job_opt_full_matrices);
}; };
lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv); lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv);
@ -104,11 +105,10 @@ struct ComplexGesdd {
static FnType* fn; static FnType* fn;
static void Kernel(void* out, void** data, XlaCustomCallStatus*); static void Kernel(void* out, void** data, XlaCustomCallStatus*);
static int64_t Workspace(lapack_int m, lapack_int n, static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
bool job_opt_compute_uv, bool job_opt_full_matrices); bool job_opt_full_matrices);
}; };
lapack_int SyevdWorkSize(int64_t n); lapack_int SyevdWorkSize(int64_t n);
lapack_int SyevdIworkSize(int64_t n); lapack_int SyevdIworkSize(int64_t n);
@ -176,6 +176,45 @@ struct ComplexGees {
static void Kernel(void* out, void** data, XlaCustomCallStatus*); static void Kernel(void* out, void** data, XlaCustomCallStatus*);
}; };
// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form.
//
// FnType is the signature of the underlying LAPACK ?gehrd routine; `fn` holds
// the pointer to it. `Kernel` is the XLA custom-call entry point and
// `Workspace` returns the workspace size to allocate for a given problem.
template <typename T>
struct Gehrd {
  using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a,
                      lapack_int* lda, T* tau, T* work, lapack_int* lwork,
                      lapack_int* info);

  static FnType* fn;
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);

  static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
                           lapack_int ihi);
};
// Maps a scalar type to its underlying real type: T for a non-complex T,
// and T for std::complex<T>.
template <typename T>
struct real_type {
  typedef T type;
};
// Specialization that unwraps std::complex<T> to T.
template <typename T>
struct real_type<std::complex<T>> {
  typedef T type;
};
// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal
// form.
//
// FnType is the signature of the underlying LAPACK ?sytrd/?hetrd routine;
// `fn` holds the pointer to it. The diagonal `d` and off-diagonal `e` outputs
// use the real type underlying T. `Kernel` is the XLA custom-call entry point
// and `Workspace` returns the workspace size to allocate for a given problem.
template <typename T>
struct Sytrd {
  using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
                      typename real_type<T>::type* d,
                      typename real_type<T>::type* e,
                      T* tau, T* work,
                      lapack_int* lwork, lapack_int* info);

  static FnType* fn;
  static void Kernel(void* out, void** data, XlaCustomCallStatus*);

  static int64_t Workspace(lapack_int lda, lapack_int n);
};
} // namespace jax } // namespace jax
#endif // JAXLIB_CPU_LAPACK_KERNELS_H_ #endif // JAXLIB_CPU_LAPACK_KERNELS_H_

View File

@ -66,6 +66,16 @@ jax::RealGees<double>::FnType dgees_;
jax::ComplexGees<std::complex<float>>::FnType cgees_; jax::ComplexGees<std::complex<float>>::FnType cgees_;
jax::ComplexGees<std::complex<double>>::FnType zgees_; jax::ComplexGees<std::complex<double>>::FnType zgees_;
jax::Gehrd<float>::FnType sgehrd_;
jax::Gehrd<double>::FnType dgehrd_;
jax::Gehrd<std::complex<float>>::FnType cgehrd_;
jax::Gehrd<std::complex<double>>::FnType zgehrd_;
jax::Sytrd<float>::FnType ssytrd_;
jax::Sytrd<double>::FnType dsytrd_;
jax::Sytrd<std::complex<float>>::FnType chetrd_;
jax::Sytrd<std::complex<double>>::FnType zhetrd_;
} // extern "C" } // extern "C"
namespace jax { namespace jax {
@ -107,6 +117,15 @@ static auto init = []() -> int {
RealGees<double>::fn = dgees_; RealGees<double>::fn = dgees_;
ComplexGees<std::complex<float>>::fn = cgees_; ComplexGees<std::complex<float>>::fn = cgees_;
ComplexGees<std::complex<double>>::fn = zgees_; ComplexGees<std::complex<double>>::fn = zgees_;
Gehrd<float>::fn = sgehrd_;
Gehrd<double>::fn = dgehrd_;
Gehrd<std::complex<float>>::fn = cgehrd_;
Gehrd<std::complex<double>>::fn = zgehrd_;
Sytrd<float>::fn = ssytrd_;
Sytrd<double>::fn = dsytrd_;
Sytrd<std::complex<float>>::fn = chetrd_;
Sytrd<std::complex<double>>::fn = zhetrd_;
return 0; return 0;
}(); }();

View File

@ -628,3 +628,117 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
return (out[0], out[3], out[4], out[5]) return (out[0], out[3], out[4], out[5])
else: else:
return (out[0], out[3], out[5]) return (out[0], out[3], out[5])
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
def gehrd_mhlo(dtype, a):
  """Emits a custom call to the LAPACK ``?gehrd`` kernel for ``dtype``.

  Args:
    dtype: element type of ``a``; one of float32, float64, complex64 or
      complex128.
    a: IR value holding a square matrix or a batch of square matrices.

  Returns:
    The first three results of the custom call: the reduced matrices, ``taus``
    (batch shape followed by ``n - 1``) and the int32 ``info`` codes (batch
    shape). The workspace result is dropped.

  Raises:
    NotImplementedError: if ``dtype`` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n, (m, n)
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_size = 1
  for dim in batch_dims:
    batch_size *= dim

  # Select the LAPACK routine matching the element type.
  for np_type, routine in ((np.float32, "sgehrd"), (np.float64, "dgehrd"),
                           (np.complex64, "cgehrd"),
                           (np.complex128, "zgehrd")):
    if dtype == np_type:
      break
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")
  fn = f"lapack_{routine}".encode()
  lwork = getattr(_lapack, f"lapack_{routine}_workspace")(n, n, 1, n)

  # Matrix dimensions are listed first in the layout, followed by the batch
  # dimensions in reverse order.
  batch_minor_to_major = tuple(range(num_bd - 1, -1, -1))
  layout = (num_bd, num_bd + 1) + batch_minor_to_major
  i32_type = ir.IntegerType.get_signless(32)
  result_types = [
      a.type,
      ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
      ir.RankedTensorType.get(batch_dims, i32_type),
      ir.RankedTensorType.get([lwork], a_type.element_type),
  ]
  # Scalar operands: n, ilo, ihi, lda, batch size, lwork.
  operands = [_mhlo_s32(n), _mhlo_s32(1), _mhlo_s32(n), _mhlo_s32(n),
              _mhlo_s32(batch_size), _mhlo_s32(lwork), a]
  out = custom_call(
      fn,
      result_types,
      operands,
      operand_layouts=[[]] * 6 + [layout],
      result_layouts=[
          layout,
          (num_bd,) + batch_minor_to_major,
          batch_minor_to_major,
          [0],
      ],
      operand_output_aliases={6: 0},
  )
  return out[:3]
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
def sytrd_mhlo(dtype, a, *, lower):
  """Emits a custom call to the LAPACK ``?sytrd``/``?hetrd`` kernel.

  Args:
    dtype: element type of ``a``; one of float32, float64, complex64 or
      complex128.
    a: IR value holding a square matrix or a batch of square matrices.
    lower: whether the lower (``True``) or upper (``False``) triangle of the
      input is used.

  Returns:
    The first five results of the custom call: the reduced matrices, the
    diagonal ``d`` (batch shape followed by ``n``), the off-diagonal ``e``
    (batch shape followed by ``n - 1``), ``taus`` (batch shape followed by
    ``n - 1``) and the int32 ``info`` codes (batch shape). ``d`` and ``e``
    are real-valued even for complex inputs. The workspace result is dropped.

  Raises:
    NotImplementedError: if ``dtype`` is not one of the supported types.
  """
  _initialize()
  a_type = ir.RankedTensorType(a.type)
  dims = a_type.shape
  assert len(dims) >= 2
  m, n = dims[-2:]
  assert m == n, (m, n)
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  b = 1
  for d in batch_dims:
    b *= d

  # ``diag_type`` is the element type of the d and e results. The kernel
  # writes them as real values, so for the complex (Hermitian) routines it
  # must be the real type of matching precision, not the complex type.
  if dtype == np.float32:
    fn = b"lapack_ssytrd"
    lwork = _lapack.lapack_ssytrd_workspace(n, n)
    diag_type = a_type.element_type
  elif dtype == np.float64:
    fn = b"lapack_dsytrd"
    lwork = _lapack.lapack_dsytrd_workspace(n, n)
    diag_type = a_type.element_type
  elif dtype == np.complex64:
    fn = b"lapack_chetrd"
    lwork = _lapack.lapack_chetrd_workspace(n, n)
    diag_type = ir.F32Type.get()
  elif dtype == np.complex128:
    fn = b"lapack_zhetrd"
    lwork = _lapack.lapack_zhetrd_workspace(n, n)
    diag_type = ir.F64Type.get()
  else:
    raise NotImplementedError(f"Unsupported dtype {dtype}")

  # Matrix dimensions are listed first in the layout, followed by the batch
  # dimensions in reverse order.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  i32_type = ir.IntegerType.get_signless(32)
  out = custom_call(
      fn,
      [
          a.type,
          ir.RankedTensorType.get(batch_dims + (n,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
          ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
          ir.RankedTensorType.get(batch_dims, i32_type),
          ir.RankedTensorType.get([lwork], a_type.element_type),
      ],
      # Scalar operands: n, lower flag, lda, batch size, lwork.
      [_mhlo_s32(n), _mhlo_s32(1 if lower else 0), _mhlo_s32(max(1, n)),
       _mhlo_s32(b), _mhlo_s32(lwork), a],
      operand_layouts=[[]] * 5 + [layout],
      result_layouts=[
          layout,
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
          tuple(range(num_bd - 1, -1, -1)),
          [0],
      ],
      operand_output_aliases={5: 0},
  )
  return out[:5]

View File

@ -15,6 +15,7 @@
"""Tests for the LAPAX linear algebra module.""" """Tests for the LAPAX linear algebra module."""
from functools import partial from functools import partial
import unittest
import numpy as np import numpy as np
import scipy import scipy
@ -28,6 +29,7 @@ from jax import jit, grad, jvp, vmap
from jax import lax from jax import lax
from jax import numpy as jnp from jax import numpy as jnp
from jax import scipy as jsp from jax import scipy as jsp
from jax._src.lib import version as jaxlib_version
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax.config import config from jax.config import config
@ -1244,6 +1246,96 @@ class ScipyLinalgTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5) self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
self._CompileAndCheck(jsp_func, args_maker) self._CompileAndCheck(jsp_func, args_maker)
@jtu.sample_product(
  [dict(shape=shape, k=k)
   for shape in [(1, 1), (3, 4, 4), (10, 5)]
   # TODO(phawkins): there are some test failures on GPU for k=0
   for k in range(1, shape[-1] + 1)],
  dtype=float_types + complex_types,
)
def testHouseholderProduct(self, shape, k, dtype):
  # Compares lax.linalg.householder_product with SciPy's LAPACK ?orgqr/?ungqr
  # wrappers, applied matrix-by-matrix over any leading batch dimensions.
  lapack_routines = (
      (np.float32, "sorgqr"),
      (np.float64, "dorgqr"),
      (np.complex64, "cungqr"),
      (np.complex128, "zungqr"),
  )

  def orgqr_single(a, taus):
    # Select the routine whose precision matches the test dtype.
    for np_type, routine_name in lapack_routines:
      if dtype == np_type:
        routine = getattr(scipy.linalg.lapack, routine_name)
        q, _, info = routine(a, taus)
        break
    else:
      assert False, dtype
    assert info == 0, info
    return q

  reference_fn = np.vectorize(orgqr_single, signature='(m,n),(k)->(m,n)')

  sampler = jtu.rand_default(self.rng())
  taus_shape = shape[:-2] + (k,)

  def args_maker():
    return [sampler(shape, dtype), sampler(taus_shape, dtype)]

  tol = {
      np.float32: 1e-5,
      np.complex64: 1e-5,
      np.float64: 1e-12,
      np.complex128: 1e-12,
  }
  self._CheckAgainstNumpy(reference_fn, lax.linalg.householder_product,
                          args_maker, rtol=tol, atol=tol)
  self._CompileAndCheck(lax.linalg.householder_product, args_maker)
@jtu.sample_product(
  shape=[(1, 1), (2, 4, 4), (0, 100, 100), (10, 10)],
  dtype=float_types + complex_types,
  calc_q=[False, True],
)
@jtu.skip_on_devices("gpu", "tpu")
@unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
def testHessenberg(self, shape, dtype, calc_q):
  # Compares jax.scipy.linalg.hessenberg with scipy.linalg.hessenberg, which
  # is lifted over batch dimensions with np.vectorize.
  sampler = jtu.rand_default(self.rng())

  def args_maker():
    return [sampler(shape, dtype)]

  jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)

  if not calc_q:
    sp_func = np.vectorize(scipy.linalg.hessenberg,
                           signature='(n,n)->(n,n)',
                           otypes=(dtype,))
  else:
    scipy_with_q = partial(scipy.linalg.hessenberg, calc_q=True)
    sp_func = np.vectorize(scipy_with_q,
                           otypes=(dtype, dtype),
                           signature='(n,n)->(n,n),(n,n)')

  # scipy.linalg.hessenberg sometimes returns a float Q matrix for complex
  # inputs, so dtypes are only compared when Q is not requested.
  tolerances = dict(rtol=1e-5, atol=1e-5)
  self._CheckAgainstNumpy(sp_func, jsp_func, args_maker,
                          check_dtypes=not calc_q, **tolerances)
  self._CompileAndCheck(jsp_func, args_maker)
@jtu.sample_product(
  shape=[(1, 1), (4, 4), (10, 10)],
  dtype=float_types + complex_types,
  lower=[False, True],
)
@unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
@jtu.skip_on_devices("gpu", "tpu")
def testTridiagonal(self, shape, dtype, lower):
  # Compares lax.linalg.tridiagonal with SciPy's LAPACK ?sytrd/?hetrd
  # wrappers; only the reduced matrix and the taus are compared.
  lapack_routines = (
      (np.float32, "ssytrd"),
      (np.float64, "dsytrd"),
      (np.complex64, "chetrd"),
      (np.complex128, "zhetrd"),
  )
  sampler = jtu.rand_default(self.rng())
  jax_func = partial(lax.linalg.tridiagonal, lower=lower)

  def sytrd_single(a):
    # Select the routine whose precision matches the test dtype.
    for np_type, routine_name in lapack_routines:
      if dtype == np_type:
        routine = getattr(scipy.linalg.lapack, routine_name)
        c, _, _, tau, info = routine(a, lower=lower)
        break
    else:
      assert False, dtype
    assert info == 0
    return c, tau

  sp_func = np.vectorize(sytrd_single, otypes=(dtype, dtype),
                         signature='(n,n)->(n,n),(k)')

  def args_maker():
    return [sampler(shape, dtype)]

  self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-5, atol=1e-5,
                          check_dtypes=False)
@jtu.sample_product( @jtu.sample_product(
n=[1, 4, 5, 20, 50, 100], n=[1, 4, 5, 20, 50, 100],
dtype=float_types + complex_types, dtype=float_types + complex_types,