mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg. * Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction. * Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction. None of these primitives are differentiable at the moment. PiperOrigin-RevId: 487224934
This commit is contained in:
parent
30637d052b
commit
1cead779a3
@ -9,8 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
## jax 0.3.25
|
||||
* Changes
|
||||
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
|
||||
* {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires
|
||||
jaxlib > 0.3.24.
|
||||
* New functions {func}`jax.lax.linalg.hessenberg`,
|
||||
{func}`jax.lax.linalg.tridiagonal`, and
|
||||
{func}`jax.lax.linalg.householder_product` were added. Householder and
|
||||
tridiagonal reductions are supported on CPU only.
|
||||
|
||||
## jaxlib 0.3.25
|
||||
* Changes
|
||||
* Added support for upper Hessenberg and tridiagonal reductions on CPU.
|
||||
|
||||
## jax 0.3.24 (Nov 4, 2022)
|
||||
* Changes
|
||||
|
@ -205,7 +205,9 @@ Linear algebra operators (jax.lax.linalg)
|
||||
cholesky
|
||||
eig
|
||||
eigh
|
||||
hessenberg
|
||||
lu
|
||||
householder_product
|
||||
qdwh
|
||||
qr
|
||||
schur
|
||||
|
@ -30,6 +30,7 @@ jax.scipy.linalg
|
||||
expm
|
||||
expm_frechet
|
||||
funm
|
||||
hessenberg
|
||||
inv
|
||||
lu
|
||||
lu_factor
|
||||
|
@ -753,11 +753,11 @@ mlir.register_lowering(
|
||||
platform='tpu')
|
||||
|
||||
|
||||
triangular_solve_dtype_rule = partial(
|
||||
_triangular_solve_dtype_rule = partial(
|
||||
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
|
||||
'triangular_solve')
|
||||
|
||||
def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
|
||||
def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
|
||||
if a.ndim < 2:
|
||||
msg = "triangular_solve requires a.ndim to be at least 2, got {}."
|
||||
raise TypeError(msg.format(a.ndim))
|
||||
@ -778,7 +778,7 @@ def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
|
||||
raise TypeError(msg.format(a.shape, b.shape))
|
||||
return b.shape
|
||||
|
||||
def triangular_solve_jvp_rule_a(
|
||||
def _triangular_solve_jvp_rule_a(
|
||||
g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
|
||||
unit_diagonal):
|
||||
m, n = b.shape[-2:]
|
||||
@ -811,7 +811,7 @@ def triangular_solve_jvp_rule_a(
|
||||
else:
|
||||
return dot(ans, a_inverse(g_a)) # X (∂A A^{-1})
|
||||
|
||||
def triangular_solve_transpose_rule(
|
||||
def _triangular_solve_transpose_rule(
|
||||
cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
|
||||
unit_diagonal):
|
||||
# Triangular solve is nonlinear in its first argument and linear in its second
|
||||
@ -827,7 +827,7 @@ def triangular_solve_transpose_rule(
|
||||
return [None, cotangent_b]
|
||||
|
||||
|
||||
def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
|
||||
def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
|
||||
lower, transpose_a, conjugate_a,
|
||||
unit_diagonal):
|
||||
x, y = batched_args
|
||||
@ -856,13 +856,13 @@ def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
|
||||
unit_diagonal=unit_diagonal), 0
|
||||
|
||||
triangular_solve_p = standard_primitive(
|
||||
triangular_solve_shape_rule, triangular_solve_dtype_rule,
|
||||
_triangular_solve_shape_rule, _triangular_solve_dtype_rule,
|
||||
'triangular_solve')
|
||||
ad.defjvp2(triangular_solve_p,
|
||||
triangular_solve_jvp_rule_a,
|
||||
_triangular_solve_jvp_rule_a,
|
||||
lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
|
||||
ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
|
||||
batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule
|
||||
ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
|
||||
batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule
|
||||
|
||||
|
||||
def _triangular_solve_lowering(
|
||||
@ -1363,9 +1363,9 @@ mlir.register_lowering(
|
||||
platform='rocm')
|
||||
|
||||
|
||||
# orgqr: product of elementary Householder reflectors
|
||||
# householder_product: product of elementary Householder reflectors
|
||||
|
||||
def orgqr(a: ArrayLike, taus: ArrayLike) -> Array:
|
||||
def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
|
||||
"""Product of elementary Householder reflectors.
|
||||
|
||||
Args:
|
||||
@ -1378,31 +1378,34 @@ def orgqr(a: ArrayLike, taus: ArrayLike) -> Array:
|
||||
A batch of orthogonal (unitary) matrices with the same shape as ``a``,
|
||||
containing the products of the elementary Householder reflectors.
|
||||
"""
|
||||
return orgqr_p.bind(a, taus)
|
||||
return householder_product_p.bind(a, taus)
|
||||
|
||||
|
||||
def _orgqr_abstract_eval(a, taus):
|
||||
def _householder_product_abstract_eval(a, taus):
|
||||
if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
|
||||
raise NotImplementedError("Unsupported aval in orgqr_abstract_eval: "
|
||||
raise NotImplementedError("Unsupported aval in householder_product_abstract_eval: "
|
||||
f"{a.aval} {taus.aval}")
|
||||
if a.ndim < 2:
|
||||
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
||||
raise ValueError("Argument to Householder product must have ndims >= 2")
|
||||
*batch_dims, m, n = a.shape
|
||||
*taus_batch_dims, k = taus.shape
|
||||
if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
|
||||
raise ValueError(f"Type mismatch for orgqr: a={a} taus={taus}")
|
||||
raise ValueError(f"Type mismatch for Householder product: a={a} taus={taus}")
|
||||
if m < n:
|
||||
raise ValueError("Householder product inputs must have at least as many "
|
||||
f"rows as columns, got shape {a.shape}")
|
||||
return a
|
||||
|
||||
def _orgqr_batching_rule(batched_args, batch_dims):
|
||||
def _householder_product_batching_rule(batched_args, batch_dims):
|
||||
a, taus = batched_args
|
||||
b_a, b_taus, = batch_dims
|
||||
return orgqr(batching.moveaxis(a, b_a, 0),
|
||||
return householder_product(batching.moveaxis(a, b_a, 0),
|
||||
batching.moveaxis(taus, b_taus, 0)), (0,)
|
||||
|
||||
def _orgqr_translation_rule(ctx, avals_in, avals_out, a, taus):
|
||||
def _householder_product_translation_rule(ctx, avals_in, avals_out, a, taus):
|
||||
return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]
|
||||
|
||||
def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
|
||||
def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
|
||||
a_aval, _ = ctx.avals_in
|
||||
*batch_dims, m, n = a_aval.shape
|
||||
|
||||
@ -1420,22 +1423,23 @@ def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
|
||||
return [a]
|
||||
|
||||
|
||||
orgqr_p = Primitive('orgqr')
|
||||
orgqr_p.def_impl(partial(xla.apply_primitive, orgqr_p))
|
||||
orgqr_p.def_abstract_eval(_orgqr_abstract_eval)
|
||||
batching.primitive_batchers[orgqr_p] = _orgqr_batching_rule
|
||||
xla.register_translation(orgqr_p, _orgqr_translation_rule)
|
||||
householder_product_p = Primitive('householder_product')
|
||||
householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p))
|
||||
householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
|
||||
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
|
||||
xla.register_translation(householder_product_p, _householder_product_translation_rule)
|
||||
|
||||
mlir.register_lowering(
|
||||
orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
|
||||
householder_product_p,
|
||||
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo),
|
||||
platform='cpu')
|
||||
mlir.register_lowering(
|
||||
orgqr_p,
|
||||
partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
|
||||
householder_product_p,
|
||||
partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
orgqr_p,
|
||||
partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
|
||||
householder_product_p,
|
||||
partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
|
||||
platform='rocm')
|
||||
|
||||
|
||||
@ -1492,13 +1496,13 @@ def _qr_lowering(a, *, full_matrices):
|
||||
|
||||
r, taus = geqrf(a)
|
||||
if m < n:
|
||||
q = orgqr(r[..., :m, :m], taus)
|
||||
q = householder_product(r[..., :m, :m], taus)
|
||||
elif full_matrices:
|
||||
pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
|
||||
q = lax.pad(r, lax_internal._zero(r), pads)
|
||||
q = orgqr(q, taus)
|
||||
q = householder_product(q, taus)
|
||||
else:
|
||||
q = orgqr(r, taus)
|
||||
q = householder_product(r, taus)
|
||||
r = r[..., :n, :n]
|
||||
r = jnp.triu(r)
|
||||
return q, r
|
||||
@ -1920,6 +1924,146 @@ batching.primitive_batchers[schur_p] = _schur_batching_rule
|
||||
ad.primitive_jvps[schur_p] = _schur_jvp_rule
|
||||
|
||||
|
||||
# hessenberg: Upper Hessenberg reduction
|
||||
|
||||
def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
|
||||
"""Reduces a square matrix to upper Hessenberg form.
|
||||
|
||||
Args:
|
||||
a: A floating point or complex square matrix or batch of matrices.
|
||||
|
||||
Returns:
|
||||
A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of ``a``
|
||||
contain the upper Hessenberg matrix, and the elements below the first
|
||||
subdiagonal contain the Householder reflectors. For each Householder
|
||||
reflector ``taus`` contains the scalar factors of the elementary Householder
|
||||
reflectors.
|
||||
"""
|
||||
return hessenberg_p.bind(a)
|
||||
|
||||
def _hessenberg_abstract_eval(a):
|
||||
if a.ndim < 2:
|
||||
msg = "hessenberg requires a.ndim to be at least 2, got {}."
|
||||
raise TypeError(msg.format(a.ndim))
|
||||
if a.shape[-1] != a.shape[-2]:
|
||||
msg = ("hessenberg requires the last two dimensions of a to be equal "
|
||||
"in size, got a.shape of {}.")
|
||||
raise TypeError(msg.format(a.shape))
|
||||
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
|
||||
|
||||
hessenberg_p = Primitive("hessenberg")
|
||||
hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p))
|
||||
hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
|
||||
hessenberg_p.multiple_results = True
|
||||
|
||||
def _hessenberg_batching_rule(batched_args, batch_dims):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
return hessenberg(x), 0
|
||||
|
||||
batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
|
||||
|
||||
def _hessenberg_cpu_mhlo(ctx, a):
|
||||
# TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
|
||||
if not hasattr(lapack, "gehrd_mhlo"):
|
||||
raise RuntimeError("Hessenberg reduction on CPU requires jaxlib 0.3.25 or "
|
||||
"newer")
|
||||
a_aval, = ctx.avals_in
|
||||
batch_dims = a_aval.shape[:-2]
|
||||
a, taus, info = lapack.gehrd_mhlo(a_aval.dtype, a)
|
||||
ok = mlir.compare_mhlo(
|
||||
info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"EQ", "SIGNED")
|
||||
return [
|
||||
_broadcasting_select_mhlo(
|
||||
mhlo.BroadcastInDimOp(
|
||||
ir.RankedTensorType.get(batch_dims + (1, 1),
|
||||
ir.IntegerType.get_signless(1)),
|
||||
ok, mlir.dense_int_elements(range(len(batch_dims)))).result,
|
||||
a, _nan_like_mhlo(ctx.avals_out[0])),
|
||||
_broadcasting_select_mhlo(
|
||||
mhlo.BroadcastInDimOp(
|
||||
ir.RankedTensorType.get(batch_dims + (1,),
|
||||
ir.IntegerType.get_signless(1)),
|
||||
ok, mlir.dense_int_elements(range(len(batch_dims)))).result,
|
||||
taus, _nan_like_mhlo(ctx.avals_out[1])),
|
||||
]
|
||||
|
||||
mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')
|
||||
|
||||
|
||||
# tridiagonal: Upper Hessenberg reduction
|
||||
|
||||
def tridiagonal(a: ArrayLike, *, lower=True) -> Tuple[Array, Array]:
|
||||
"""Reduces a symmetric/Hermitian matrix to tridiagonal form.
|
||||
|
||||
Args:
|
||||
a: A floating point or complex matrix or batch of matrices.
|
||||
lower: Describes which triangle of the input matrices to use.
|
||||
The other triangle is ignored and not accessed.
|
||||
|
||||
Returns:
|
||||
A ``(a, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
|
||||
matrix (or batch of matrices) ``a`` contain the tridiagonal representation,
|
||||
and elements below the first subdiagonal contain the elementary Householder
|
||||
reflectors. If ``lower=False`` the diagonal and first superdiagonal of the
|
||||
matrix contains the tridiagonal representation, and elements above the first
|
||||
superdiagonal contain the elementary Householder reflectors.
|
||||
``taus`` contains the scalar factors of the elementary Householder
|
||||
reflectors.
|
||||
"""
|
||||
arr, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
|
||||
nan = arr.dtype.type(jnp.nan)
|
||||
if jnp.issubdtype(arr.dtype, np.complexfloating):
|
||||
nan = nan + arr.dtype.type(jnp.nan * 1j)
|
||||
arr = jnp.where((info == 0)[..., None, None], arr, nan)
|
||||
taus = jnp.where((info == 0)[..., None], taus, nan)
|
||||
return arr, taus
|
||||
|
||||
def _tridiagonal_abstract_eval(a, *, lower):
|
||||
if a.ndim < 2:
|
||||
msg = "tridiagonal requires a.ndim to be at least 2, got {}."
|
||||
raise TypeError(msg.format(a.ndim))
|
||||
if a.shape[-1] != a.shape[-2]:
|
||||
msg = ("tridiagonal requires the last two dimensions of a to be equal "
|
||||
"in size, got a.shape of {}.")
|
||||
raise TypeError(msg.format(a.shape))
|
||||
if a.shape[-1] == 0:
|
||||
msg = ("tridiagonal requires the last two dimensions of a to be non-zero, "
|
||||
"got a.shape of {}.")
|
||||
raise TypeError(msg.format(a.shape))
|
||||
return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
|
||||
ShapedArray(a.shape[:-2], np.int32)]
|
||||
|
||||
tridiagonal_p = Primitive("tridiagonal")
|
||||
tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))
|
||||
tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
|
||||
tridiagonal_p.multiple_results = True
|
||||
|
||||
def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
|
||||
x, = batched_args
|
||||
bd, = batch_dims
|
||||
x = batching.moveaxis(x, bd, 0)
|
||||
return tridiagonal(x), 0
|
||||
|
||||
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
|
||||
|
||||
def _tridiagonal_cpu_mhlo(ctx, a, *, lower):
|
||||
a_aval, = ctx.avals_in
|
||||
# TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
|
||||
if not hasattr(lapack, "sytrd_mhlo"):
|
||||
raise RuntimeError("Tridiagonal reduction on CPU requires jaxlib 0.3.25 or "
|
||||
"newer")
|
||||
|
||||
a, d, e, taus, info = lapack.sytrd_mhlo(a_aval.dtype, a, lower=lower)
|
||||
del d, e
|
||||
return a, taus, info
|
||||
|
||||
mlir.register_lowering(tridiagonal_p, _tridiagonal_cpu_mhlo, platform='cpu')
|
||||
|
||||
|
||||
|
||||
# Utilities
|
||||
|
||||
def _nan_like_mhlo(aval):
|
||||
|
@ -1001,3 +1001,34 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Arra
|
||||
return T, Z
|
||||
|
||||
return lax.fori_loop(1, N, _rsf2scf_iter, (T, Z))
|
||||
|
||||
@overload
|
||||
def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> Array: ...
|
||||
|
||||
@overload
|
||||
def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> Tuple[Array, Array]: ...
|
||||
|
||||
@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
|
||||
@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
|
||||
def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
|
||||
check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
|
||||
del overwrite_a, check_finite
|
||||
n = jnp.shape(a)[-1]
|
||||
if n == 0:
|
||||
if calc_q:
|
||||
return jnp.zeros_like(a), jnp.zeros_like(a)
|
||||
else:
|
||||
return jnp.zeros_like(a)
|
||||
a_out, taus = lax_linalg.hessenberg(a)
|
||||
h = jnp.triu(a_out, -1)
|
||||
if calc_q:
|
||||
q = lax_linalg.householder_product(a_out[..., 1:, :-1], taus)
|
||||
batch_dims = a_out.shape[:-2]
|
||||
q = jnp.block([[jnp.ones(batch_dims + (1, 1), dtype=a_out.dtype),
|
||||
jnp.zeros(batch_dims + (1, n - 1), dtype=a_out.dtype)],
|
||||
[jnp.zeros(batch_dims + (n - 1, 1), dtype=a_out.dtype), q]])
|
||||
return h, q
|
||||
else:
|
||||
return h
|
||||
|
@ -1253,7 +1253,9 @@ tf_not_yet_impl = [
|
||||
"lu_pivots_to_permutation",
|
||||
"xla_pmap",
|
||||
"geqrf",
|
||||
"orgqr",
|
||||
"householder_product",
|
||||
"hessenberg",
|
||||
"tridiagonal",
|
||||
"eigh_jacobi",
|
||||
]
|
||||
|
||||
|
@ -19,15 +19,21 @@ from jax._src.lax.linalg import (
|
||||
eig_p,
|
||||
eigh,
|
||||
eigh_p,
|
||||
hessenberg,
|
||||
hessenberg_p,
|
||||
lu,
|
||||
lu_p,
|
||||
lu_pivots_to_permutation,
|
||||
householder_product,
|
||||
householder_product_p,
|
||||
qr,
|
||||
qr_p,
|
||||
svd,
|
||||
svd_p,
|
||||
triangular_solve,
|
||||
triangular_solve_p,
|
||||
tridiagonal,
|
||||
tridiagonal_p,
|
||||
tridiagonal_solve,
|
||||
tridiagonal_solve_p,
|
||||
schur,
|
||||
|
@ -22,6 +22,7 @@ from jax._src.scipy.linalg import (
|
||||
eigh_tridiagonal as eigh_tridiagonal,
|
||||
expm as expm,
|
||||
expm_frechet as expm_frechet,
|
||||
hessenberg as hessenberg,
|
||||
inv as inv,
|
||||
lu as lu,
|
||||
lu_factor as lu_factor,
|
||||
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "jaxlib/cpu/lapack_kernels.h"
|
||||
#include "jaxlib/kernel_pybind11_helpers.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
|
||||
namespace jax {
|
||||
@ -127,6 +127,27 @@ void GetLapackKernelsFromScipy() {
|
||||
ComplexGees<std::complex<double>>::fn =
|
||||
reinterpret_cast<ComplexGees<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgees"));
|
||||
Gehrd<float>::fn =
|
||||
reinterpret_cast<Gehrd<float>::FnType*>(lapack_ptr("sgehrd"));
|
||||
Gehrd<double>::fn =
|
||||
reinterpret_cast<Gehrd<double>::FnType*>(lapack_ptr("dgehrd"));
|
||||
Gehrd<std::complex<float>>::fn =
|
||||
reinterpret_cast<Gehrd<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("cgehrd"));
|
||||
Gehrd<std::complex<double>>::fn =
|
||||
reinterpret_cast<Gehrd<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zgehrd"));
|
||||
Sytrd<float>::fn =
|
||||
reinterpret_cast<Sytrd<float>::FnType*>(lapack_ptr("ssytrd"));
|
||||
Sytrd<double>::fn =
|
||||
reinterpret_cast<Sytrd<double>::FnType*>(lapack_ptr("dsytrd"));
|
||||
Sytrd<std::complex<float>>::fn =
|
||||
reinterpret_cast<Sytrd<std::complex<float>>::FnType*>(
|
||||
lapack_ptr("chetrd"));
|
||||
Sytrd<std::complex<double>>::fn =
|
||||
reinterpret_cast<Sytrd<std::complex<double>>::FnType*>(
|
||||
lapack_ptr("zhetrd"));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
|
||||
@ -185,6 +206,21 @@ py::dict Registrations() {
|
||||
EncapsulateFunction(ComplexGees<std::complex<float>>::Kernel);
|
||||
dict["lapack_zgees"] =
|
||||
EncapsulateFunction(ComplexGees<std::complex<double>>::Kernel);
|
||||
|
||||
dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd<float>::Kernel);
|
||||
dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd<double>::Kernel);
|
||||
dict["lapack_cgehrd"] =
|
||||
EncapsulateFunction(Gehrd<std::complex<float>>::Kernel);
|
||||
dict["lapack_zgehrd"] =
|
||||
EncapsulateFunction(Gehrd<std::complex<double>>::Kernel);
|
||||
|
||||
dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd<float>::Kernel);
|
||||
dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd<double>::Kernel);
|
||||
dict["lapack_chetrd"] =
|
||||
EncapsulateFunction(Sytrd<std::complex<float>>::Kernel);
|
||||
dict["lapack_zhetrd"] =
|
||||
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
||||
@ -211,6 +247,14 @@ PYBIND11_MODULE(_lapack, m) {
|
||||
m.def("syevd_iwork_size", &SyevdIworkSize);
|
||||
m.def("heevd_work_size", &HeevdWorkSize);
|
||||
m.def("heevd_rwork_size", &HeevdRworkSize);
|
||||
m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace);
|
||||
m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace);
|
||||
m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace);
|
||||
m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace);
|
||||
m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace);
|
||||
m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace);
|
||||
m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace);
|
||||
m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -775,4 +775,111 @@ template struct RealGees<double>;
|
||||
template struct ComplexGees<std::complex<float>>;
|
||||
template struct ComplexGees<std::complex<double>>;
|
||||
|
||||
template <typename T>
|
||||
typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;
|
||||
|
||||
template <typename T>
|
||||
void Gehrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
|
||||
int32_t n = *reinterpret_cast<int32_t*>(data[0]);
|
||||
int32_t ilo = *reinterpret_cast<int32_t*>(data[1]);
|
||||
int32_t ihi = *reinterpret_cast<int32_t*>(data[2]);
|
||||
int32_t lda = *reinterpret_cast<int32_t*>(data[3]);
|
||||
int32_t batch = *reinterpret_cast<int32_t*>(data[4]);
|
||||
int32_t lwork = *reinterpret_cast<int32_t*>(data[5]);
|
||||
T* a = reinterpret_cast<T*>(data[6]);
|
||||
|
||||
void** out = reinterpret_cast<void**>(out_tuple);
|
||||
T* a_out = reinterpret_cast<T*>(out[0]);
|
||||
T* tau = reinterpret_cast<T*>(out[1]);
|
||||
int* info = reinterpret_cast<int*>(out[2]);
|
||||
T* work = reinterpret_cast<T*>(out[3]);
|
||||
|
||||
if (a_out != a) {
|
||||
std::memcpy(a_out, a,
|
||||
static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
|
||||
static_cast<int64_t>(n) * sizeof(T));
|
||||
}
|
||||
|
||||
int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
|
||||
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info);
|
||||
a_out += a_plus;
|
||||
tau += n - 1;
|
||||
++info;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int64_t Gehrd<T>::Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
|
||||
lapack_int ihi) {
|
||||
T work = 0;
|
||||
lapack_int lwork = -1;
|
||||
lapack_int info = 0;
|
||||
fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
|
||||
}
|
||||
|
||||
template struct Gehrd<float>;
|
||||
template struct Gehrd<double>;
|
||||
template struct Gehrd<std::complex<float>>;
|
||||
template struct Gehrd<std::complex<double>>;
|
||||
|
||||
template <typename T>
|
||||
typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;
|
||||
|
||||
template <typename T>
|
||||
void Sytrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
|
||||
int32_t n = *reinterpret_cast<int32_t*>(data[0]);
|
||||
int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
|
||||
int32_t lda = *reinterpret_cast<int32_t*>(data[2]);
|
||||
int32_t batch = *reinterpret_cast<int32_t*>(data[3]);
|
||||
int32_t lwork = *reinterpret_cast<int32_t*>(data[4]);
|
||||
T* a = reinterpret_cast<T*>(data[5]);
|
||||
|
||||
void** out = reinterpret_cast<void**>(out_tuple);
|
||||
T* a_out = reinterpret_cast<T*>(out[0]);
|
||||
typedef typename real_type<T>::type Real;
|
||||
Real* d = reinterpret_cast<Real*>(out[1]);
|
||||
Real* e = reinterpret_cast<Real*>(out[2]);
|
||||
T* tau = reinterpret_cast<T*>(out[3]);
|
||||
int* info = reinterpret_cast<int*>(out[4]);
|
||||
T* work = reinterpret_cast<T*>(out[5]);
|
||||
|
||||
if (a_out != a) {
|
||||
std::memcpy(a_out, a,
|
||||
static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
|
||||
static_cast<int64_t>(n) * sizeof(T));
|
||||
}
|
||||
|
||||
char cuplo = lower ? 'L' : 'U';
|
||||
|
||||
int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
|
||||
|
||||
for (int i = 0; i < batch; ++i) {
|
||||
fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info);
|
||||
a_out += a_plus;
|
||||
d += n;
|
||||
e += n - 1;
|
||||
tau += n - 1;
|
||||
++info;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int64_t Sytrd<T>::Workspace(lapack_int lda, lapack_int n) {
|
||||
char cuplo = 'L';
|
||||
T work = 0;
|
||||
lapack_int lwork = -1;
|
||||
lapack_int info = 0;
|
||||
fn(&cuplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &work, &lwork,
|
||||
&info);
|
||||
return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
|
||||
}
|
||||
|
||||
template struct Sytrd<float>;
|
||||
template struct Sytrd<double>;
|
||||
template struct Sytrd<std::complex<float>>;
|
||||
template struct Sytrd<std::complex<double>>;
|
||||
|
||||
} // namespace jax
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <complex>
|
||||
#include <cstdint>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/custom_call_status.h"
|
||||
|
||||
// Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
|
||||
@ -88,8 +89,8 @@ struct RealGesdd {
|
||||
static FnType* fn;
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
|
||||
static int64_t Workspace(lapack_int m, lapack_int n,
|
||||
bool job_opt_compute_uv, bool job_opt_full_matrices);
|
||||
static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
|
||||
bool job_opt_full_matrices);
|
||||
};
|
||||
|
||||
lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv);
|
||||
@ -104,11 +105,10 @@ struct ComplexGesdd {
|
||||
static FnType* fn;
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
|
||||
static int64_t Workspace(lapack_int m, lapack_int n,
|
||||
bool job_opt_compute_uv, bool job_opt_full_matrices);
|
||||
static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
|
||||
bool job_opt_full_matrices);
|
||||
};
|
||||
|
||||
|
||||
lapack_int SyevdWorkSize(int64_t n);
|
||||
lapack_int SyevdIworkSize(int64_t n);
|
||||
|
||||
@ -176,6 +176,45 @@ struct ComplexGees {
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
};
|
||||
|
||||
// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form.
|
||||
template <typename T>
|
||||
struct Gehrd {
|
||||
using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a,
|
||||
lapack_int* lda, T* tau, T* work, lapack_int* lwork,
|
||||
lapack_int* info);
|
||||
|
||||
static FnType* fn;
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
|
||||
static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
|
||||
lapack_int ihi);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct real_type {
|
||||
typedef T type;
|
||||
};
|
||||
template <typename T>
|
||||
struct real_type<std::complex<T>> {
|
||||
typedef T type;
|
||||
};
|
||||
|
||||
// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal
|
||||
// form.
|
||||
template <typename T>
|
||||
struct Sytrd {
|
||||
using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
|
||||
typename real_type<T>::type* d,
|
||||
typename real_type<T>::type* e,
|
||||
T* tau, T* work,
|
||||
lapack_int* lwork, lapack_int* info);
|
||||
|
||||
static FnType* fn;
|
||||
static void Kernel(void* out, void** data, XlaCustomCallStatus*);
|
||||
|
||||
static int64_t Workspace(lapack_int lda, lapack_int n);
|
||||
};
|
||||
|
||||
} // namespace jax
|
||||
|
||||
#endif // JAXLIB_CPU_LAPACK_KERNELS_H_
|
||||
|
@ -66,6 +66,16 @@ jax::RealGees<double>::FnType dgees_;
|
||||
jax::ComplexGees<std::complex<float>>::FnType cgees_;
|
||||
jax::ComplexGees<std::complex<double>>::FnType zgees_;
|
||||
|
||||
jax::Gehrd<float>::FnType sgehrd_;
|
||||
jax::Gehrd<double>::FnType dgehrd_;
|
||||
jax::Gehrd<std::complex<float>>::FnType cgehrd_;
|
||||
jax::Gehrd<std::complex<double>>::FnType zgehrd_;
|
||||
|
||||
jax::Sytrd<float>::FnType ssytrd_;
|
||||
jax::Sytrd<double>::FnType dsytrd_;
|
||||
jax::Sytrd<std::complex<float>>::FnType chetrd_;
|
||||
jax::Sytrd<std::complex<double>>::FnType zhetrd_;
|
||||
|
||||
} // extern "C"
|
||||
|
||||
namespace jax {
|
||||
@ -107,6 +117,15 @@ static auto init = []() -> int {
|
||||
RealGees<double>::fn = dgees_;
|
||||
ComplexGees<std::complex<float>>::fn = cgees_;
|
||||
ComplexGees<std::complex<double>>::fn = zgees_;
|
||||
Gehrd<float>::fn = sgehrd_;
|
||||
Gehrd<double>::fn = dgehrd_;
|
||||
Gehrd<std::complex<float>>::fn = cgehrd_;
|
||||
Gehrd<std::complex<double>>::fn = zgehrd_;
|
||||
Sytrd<float>::fn = ssytrd_;
|
||||
Sytrd<double>::fn = dsytrd_;
|
||||
Sytrd<std::complex<float>>::fn = chetrd_;
|
||||
Sytrd<std::complex<double>>::fn = zhetrd_;
|
||||
|
||||
return 0;
|
||||
}();
|
||||
|
||||
|
114
jaxlib/lapack.py
114
jaxlib/lapack.py
@ -628,3 +628,117 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
|
||||
return (out[0], out[3], out[4], out[5])
|
||||
else:
|
||||
return (out[0], out[3], out[5])
|
||||
|
||||
|
||||
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
|
||||
def gehrd_mhlo(dtype, a):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
dims = a_type.shape
|
||||
assert len(dims) >= 2
|
||||
m, n = dims[-2:]
|
||||
assert m == n, (m, n)
|
||||
batch_dims = tuple(dims[:-2])
|
||||
num_bd = len(batch_dims)
|
||||
b = 1
|
||||
for d in batch_dims:
|
||||
b *= d
|
||||
|
||||
if dtype == np.float32:
|
||||
fn = b"lapack_sgehrd"
|
||||
lwork = _lapack.lapack_sgehrd_workspace(n, n, 1, n)
|
||||
elif dtype == np.float64:
|
||||
fn = b"lapack_dgehrd"
|
||||
lwork = _lapack.lapack_dgehrd_workspace(n, n, 1, n)
|
||||
elif dtype == np.complex64:
|
||||
fn = b"lapack_cgehrd"
|
||||
lwork = _lapack.lapack_cgehrd_workspace(n, n, 1, n)
|
||||
elif dtype == np.complex128:
|
||||
fn = b"lapack_zgehrd"
|
||||
lwork = _lapack.lapack_zgehrd_workspace(n, n, 1, n)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dtype {dtype}")
|
||||
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = custom_call(
|
||||
fn,
|
||||
[
|
||||
a.type,
|
||||
ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
|
||||
ir.RankedTensorType.get(batch_dims, i32_type),
|
||||
ir.RankedTensorType.get([lwork], a_type.element_type),
|
||||
],
|
||||
[_mhlo_s32(n), _mhlo_s32(1), _mhlo_s32(n), _mhlo_s32(n), _mhlo_s32(b),
|
||||
_mhlo_s32(lwork), a],
|
||||
operand_layouts=[[]] * 6 + [layout],
|
||||
result_layouts=[
|
||||
layout,
|
||||
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
|
||||
tuple(range(num_bd - 1, -1, -1)),
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={6: 0},
|
||||
)
|
||||
return out[:3]
|
||||
|
||||
|
||||
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
|
||||
def sytrd_mhlo(dtype, a, *, lower):
|
||||
_initialize()
|
||||
a_type = ir.RankedTensorType(a.type)
|
||||
dims = a_type.shape
|
||||
assert len(dims) >= 2
|
||||
m, n = dims[-2:]
|
||||
assert m == n, (m, n)
|
||||
batch_dims = tuple(dims[:-2])
|
||||
num_bd = len(batch_dims)
|
||||
b = 1
|
||||
for d in batch_dims:
|
||||
b *= d
|
||||
|
||||
if dtype == np.float32:
|
||||
fn = b"lapack_ssytrd"
|
||||
lwork = _lapack.lapack_ssytrd_workspace(n, n)
|
||||
diag_type = a_type.element_type
|
||||
elif dtype == np.float64:
|
||||
fn = b"lapack_dsytrd"
|
||||
lwork = _lapack.lapack_dsytrd_workspace(n, n)
|
||||
diag_type = a_type.element_type
|
||||
elif dtype == np.complex64:
|
||||
fn = b"lapack_chetrd"
|
||||
lwork = _lapack.lapack_chetrd_workspace(n, n)
|
||||
diag_type = ir.ComplexType.get(ir.F32Type.get())
|
||||
elif dtype == np.complex128:
|
||||
fn = b"lapack_zhetrd"
|
||||
lwork = _lapack.lapack_zhetrd_workspace(n, n)
|
||||
diag_type = ir.ComplexType.get(ir.F64Type.get())
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dtype {dtype}")
|
||||
|
||||
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = custom_call(
|
||||
fn,
|
||||
[
|
||||
a.type,
|
||||
ir.RankedTensorType.get(batch_dims + (n,), diag_type),
|
||||
ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
|
||||
ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
|
||||
ir.RankedTensorType.get(batch_dims, i32_type),
|
||||
ir.RankedTensorType.get([lwork], a_type.element_type),
|
||||
],
|
||||
[_mhlo_s32(n), _mhlo_s32(1 if lower else 0), _mhlo_s32(max(1, n)),
|
||||
_mhlo_s32(b), _mhlo_s32(lwork), a],
|
||||
operand_layouts=[[]] * 5 + [layout],
|
||||
result_layouts=[
|
||||
layout,
|
||||
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
|
||||
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
|
||||
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
|
||||
tuple(range(num_bd - 1, -1, -1)),
|
||||
[0],
|
||||
],
|
||||
operand_output_aliases={5: 0},
|
||||
)
|
||||
return out[:5]
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""Tests for the LAPAX linear algebra module."""
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
@ -28,6 +29,7 @@ from jax import jit, grad, jvp, vmap
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import scipy as jsp
|
||||
from jax._src.lib import version as jaxlib_version
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax.config import config
|
||||
@ -1244,6 +1246,96 @@ class ScipyLinalgTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
|
||||
self._CompileAndCheck(jsp_func, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, k=k)
|
||||
for shape in [(1, 1), (3, 4, 4), (10, 5)]
|
||||
# TODO(phawkins): there are some test failures on GPU for k=0
|
||||
for k in range(1, shape[-1] + 1)],
|
||||
dtype=float_types + complex_types,
|
||||
)
|
||||
def testHouseholderProduct(self, shape, k, dtype):
|
||||
|
||||
@partial(np.vectorize, signature='(m,n),(k)->(m,n)')
|
||||
def reference_fn(a, taus):
|
||||
if dtype == np.float32:
|
||||
q, _, info = scipy.linalg.lapack.sorgqr(a, taus)
|
||||
elif dtype == np.float64:
|
||||
q, _, info = scipy.linalg.lapack.dorgqr(a, taus)
|
||||
elif dtype == np.complex64:
|
||||
q, _, info = scipy.linalg.lapack.cungqr(a, taus)
|
||||
elif dtype == np.complex128:
|
||||
q, _, info = scipy.linalg.lapack.zungqr(a, taus)
|
||||
else:
|
||||
assert False, dtype
|
||||
assert info == 0, info
|
||||
return q
|
||||
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype), rng(shape[:-2] + (k,), dtype)]
|
||||
tol = {np.float32: 1e-5, np.complex64: 1e-5, np.float64: 1e-12,
|
||||
np.complex128: 1e-12}
|
||||
self._CheckAgainstNumpy(reference_fn, lax.linalg.householder_product,
|
||||
args_maker, rtol=tol, atol=tol)
|
||||
self._CompileAndCheck(lax.linalg.householder_product, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (2, 4, 4), (0, 100, 100), (10, 10)],
|
||||
dtype=float_types + complex_types,
|
||||
calc_q=[False, True],
|
||||
)
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
@unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
|
||||
def testHessenberg(self, shape, dtype, calc_q):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)
|
||||
if calc_q:
|
||||
sp_func = np.vectorize(partial(scipy.linalg.hessenberg, calc_q=True),
|
||||
otypes=(dtype, dtype),
|
||||
signature='(n,n)->(n,n),(n,n)')
|
||||
else:
|
||||
sp_func = np.vectorize(scipy.linalg.hessenberg, signature='(n,n)->(n,n)',
|
||||
otypes=(dtype,))
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
# scipy.linalg.hessenberg sometimes returns a float Q matrix for complex
|
||||
# inputs
|
||||
self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1e-5, atol=1e-5,
|
||||
check_dtypes=not calc_q)
|
||||
self._CompileAndCheck(jsp_func, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1, 1), (4, 4), (10, 10)],
|
||||
dtype=float_types + complex_types,
|
||||
lower=[False, True],
|
||||
)
|
||||
@unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
|
||||
@jtu.skip_on_devices("gpu", "tpu")
|
||||
def testTridiagonal(self, shape, dtype, lower):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
def jax_func(a):
|
||||
return lax.linalg.tridiagonal(a, lower=lower)
|
||||
|
||||
@partial(np.vectorize, otypes=(dtype, dtype),
|
||||
signature='(n,n)->(n,n),(k)')
|
||||
def sp_func(a):
|
||||
if dtype == np.float32:
|
||||
c, d, e, tau, info = scipy.linalg.lapack.ssytrd(a, lower=lower)
|
||||
elif dtype == np.float64:
|
||||
c, d, e, tau, info = scipy.linalg.lapack.dsytrd(a, lower=lower)
|
||||
elif dtype == np.complex64:
|
||||
c, d, e, tau, info = scipy.linalg.lapack.chetrd(a, lower=lower)
|
||||
elif dtype == np.complex128:
|
||||
c, d, e, tau, info = scipy.linalg.lapack.zhetrd(a, lower=lower)
|
||||
else:
|
||||
assert False, dtype
|
||||
del d, e
|
||||
assert info == 0
|
||||
return c, tau
|
||||
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-5, atol=1e-5,
|
||||
check_dtypes=False)
|
||||
|
||||
|
||||
@jtu.sample_product(
|
||||
n=[1, 4, 5, 20, 50, 100],
|
||||
dtype=float_types + complex_types,
|
||||
|
Loading…
x
Reference in New Issue
Block a user