Mirror of https://github.com/ROCm/jax.git

Commit 1cead779a3 (parent 30637d052b)

Add support for Hessenberg and tridiagonal matrix reductions on CPU.

* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as
  jax.lax.linalg.householder_product, since it can be used with some minor
  tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian)
  equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
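The four entry points named above are all exercised in the diff that follows. A minimal usage sketch, assuming a CPU backend and jaxlib >= 0.3.25; the random key and matrix shapes are illustrative only:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg  # ensure the scipy-compatible wrappers are loaded

a = jax.random.normal(jax.random.PRNGKey(0), (5, 5))

# SciPy-style wrapper: Hessenberg matrix H, plus the unitary Q when calc_q=True.
h, q = jax.scipy.linalg.hessenberg(a, calc_q=True)

# Low-level primitive: packed Householder representation and the tau factors.
packed, taus = jax.lax.linalg.hessenberg(a)

# Product of elementary Householder reflectors (formerly lax.linalg.orgqr).
q2 = jax.lax.linalg.householder_product(packed[1:, :-1], taus)

# Symmetric/Hermitian analogue: reduction to tridiagonal form.
sym = (a + a.T) / 2
t_packed, t_taus = jax.lax.linalg.tridiagonal(sym, lower=True)
```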
@@ -9,8 +9,16 @@ Remember to align the itemized text with the first line of an item within a list
 ## jax 0.3.25
 * Changes
   * {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
+  * {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires
+    jaxlib > 0.3.24.
+  * New functions {func}`jax.lax.linalg.hessenberg`,
+    {func}`jax.lax.linalg.tridiagonal`, and
+    {func}`jax.lax.linalg.householder_product` were added. Householder and
+    tridiagonal reductions are supported on CPU only.

 ## jaxlib 0.3.25
+* Changes
+  * Added support for upper Hessenberg and tridiagonal reductions on CPU.

 ## jax 0.3.24 (Nov 4, 2022)
 * Changes
@@ -205,7 +205,9 @@ Linear algebra operators (jax.lax.linalg)
   cholesky
   eig
   eigh
+  hessenberg
   lu
+  householder_product
   qdwh
   qr
   schur
@@ -30,6 +30,7 @@ jax.scipy.linalg
   expm
   expm_frechet
   funm
+  hessenberg
   inv
   lu
   lu_factor
@@ -753,11 +753,11 @@ mlir.register_lowering(
     platform='tpu')


-triangular_solve_dtype_rule = partial(
+_triangular_solve_dtype_rule = partial(
     naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
     'triangular_solve')

-def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
+def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
   if a.ndim < 2:
     msg = "triangular_solve requires a.ndim to be at least 2, got {}."
     raise TypeError(msg.format(a.ndim))
@@ -778,7 +778,7 @@ def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
     raise TypeError(msg.format(a.shape, b.shape))
   return b.shape

-def triangular_solve_jvp_rule_a(
+def _triangular_solve_jvp_rule_a(
     g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
     unit_diagonal):
   m, n = b.shape[-2:]
@@ -811,7 +811,7 @@ def triangular_solve_jvp_rule_a(
   else:
     return dot(ans, a_inverse(g_a))  # X (∂A A^{-1})

-def triangular_solve_transpose_rule(
+def _triangular_solve_transpose_rule(
     cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
     unit_diagonal):
   # Triangular solve is nonlinear in its first argument and linear in its second
@@ -827,7 +827,7 @@ def triangular_solve_transpose_rule(
   return [None, cotangent_b]


-def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
+def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
                                    lower, transpose_a, conjugate_a,
                                    unit_diagonal):
   x, y = batched_args
@@ -856,13 +856,13 @@ def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
                             unit_diagonal=unit_diagonal), 0

 triangular_solve_p = standard_primitive(
-    triangular_solve_shape_rule, triangular_solve_dtype_rule,
+    _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
     'triangular_solve')
 ad.defjvp2(triangular_solve_p,
-           triangular_solve_jvp_rule_a,
+           _triangular_solve_jvp_rule_a,
            lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
-ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
-batching.primitive_batchers[triangular_solve_p] = triangular_solve_batching_rule
+ad.primitive_transposes[triangular_solve_p] = _triangular_solve_transpose_rule
+batching.primitive_batchers[triangular_solve_p] = _triangular_solve_batching_rule


 def _triangular_solve_lowering(
@@ -1363,9 +1363,9 @@ mlir.register_lowering(
     platform='rocm')


-# orgqr: product of elementary Householder reflectors
+# householder_product: product of elementary Householder reflectors

-def orgqr(a: ArrayLike, taus: ArrayLike) -> Array:
+def householder_product(a: ArrayLike, taus: ArrayLike) -> Array:
   """Product of elementary Householder reflectors.

   Args:
@@ -1378,31 +1378,34 @@ def orgqr(a: ArrayLike, taus: ArrayLike) -> Array:
     A batch of orthogonal (unitary) matrices with the same shape as ``a``,
     containing the products of the elementary Householder reflectors.
   """
-  return orgqr_p.bind(a, taus)
+  return householder_product_p.bind(a, taus)


-def _orgqr_abstract_eval(a, taus):
+def _householder_product_abstract_eval(a, taus):
   if not isinstance(a, ShapedArray) or not isinstance(taus, ShapedArray):
-    raise NotImplementedError("Unsupported aval in orgqr_abstract_eval: "
+    raise NotImplementedError("Unsupported aval in householder_product_abstract_eval: "
                               f"{a.aval} {taus.aval}")
   if a.ndim < 2:
-    raise ValueError("Argument to QR decomposition must have ndims >= 2")
+    raise ValueError("Argument to Householder product must have ndims >= 2")
   *batch_dims, m, n = a.shape
   *taus_batch_dims, k = taus.shape
   if a.dtype != taus.dtype or batch_dims != taus_batch_dims or k > min(m, n):
-    raise ValueError(f"Type mismatch for orgqr: a={a} taus={taus}")
+    raise ValueError(f"Type mismatch for Householder product: a={a} taus={taus}")
+  if m < n:
+    raise ValueError("Householder product inputs must have at least as many "
+                     f"rows as columns, got shape {a.shape}")
   return a

-def _orgqr_batching_rule(batched_args, batch_dims):
+def _householder_product_batching_rule(batched_args, batch_dims):
   a, taus = batched_args
   b_a, b_taus, = batch_dims
-  return orgqr(batching.moveaxis(a, b_a, 0),
+  return householder_product(batching.moveaxis(a, b_a, 0),
                batching.moveaxis(taus, b_taus, 0)), (0,)

-def _orgqr_translation_rule(ctx, avals_in, avals_out, a, taus):
+def _householder_product_translation_rule(ctx, avals_in, avals_out, a, taus):
   return [xops.ProductOfElementaryHouseholderReflectors(a, taus)]

-def _orgqr_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
+def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
   a_aval, _ = ctx.avals_in
   *batch_dims, m, n = a_aval.shape

|
|||||||
return [a]
|
return [a]
|
||||||
|
|
||||||
|
|
||||||
orgqr_p = Primitive('orgqr')
|
householder_product_p = Primitive('householder_product')
|
||||||
orgqr_p.def_impl(partial(xla.apply_primitive, orgqr_p))
|
householder_product_p.def_impl(partial(xla.apply_primitive, householder_product_p))
|
||||||
orgqr_p.def_abstract_eval(_orgqr_abstract_eval)
|
householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
|
||||||
batching.primitive_batchers[orgqr_p] = _orgqr_batching_rule
|
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
|
||||||
xla.register_translation(orgqr_p, _orgqr_translation_rule)
|
xla.register_translation(householder_product_p, _householder_product_translation_rule)
|
||||||
|
|
||||||
mlir.register_lowering(
|
mlir.register_lowering(
|
||||||
orgqr_p, partial(_orgqr_cpu_gpu_lowering, lapack.orgqr_mhlo),
|
householder_product_p,
|
||||||
|
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo),
|
||||||
platform='cpu')
|
platform='cpu')
|
||||||
mlir.register_lowering(
|
mlir.register_lowering(
|
||||||
orgqr_p,
|
householder_product_p,
|
||||||
partial(_orgqr_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
|
partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
|
||||||
platform='cuda')
|
platform='cuda')
|
||||||
mlir.register_lowering(
|
mlir.register_lowering(
|
||||||
orgqr_p,
|
householder_product_p,
|
||||||
partial(_orgqr_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
|
partial(_householder_product_cpu_gpu_lowering, gpu_solver.rocm_orgqr),
|
||||||
platform='rocm')
|
platform='rocm')
|
||||||
|
|
||||||
|
|
||||||
@@ -1492,13 +1496,13 @@ def _qr_lowering(a, *, full_matrices):

   r, taus = geqrf(a)
   if m < n:
-    q = orgqr(r[..., :m, :m], taus)
+    q = householder_product(r[..., :m, :m], taus)
   elif full_matrices:
     pads = [(0, 0, 0)] * (len(batch_dims) + 1) + [(0, m - n, 0)]
     q = lax.pad(r, lax_internal._zero(r), pads)
-    q = orgqr(q, taus)
+    q = householder_product(q, taus)
   else:
-    q = orgqr(r, taus)
+    q = householder_product(r, taus)
     r = r[..., :n, :n]
   r = jnp.triu(r)
   return q, r
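The hunk above rewires `_qr_lowering` to build Q from the packed reflectors that `geqrf` produces, via `householder_product`. A sketch of the same recipe through the public wrappers, assuming `geqrf` is exposed in `jax.lax.linalg` alongside `householder_product` (it is in current releases); shapes are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jax.random.normal(jax.random.PRNGKey(1), (6, 4))
r_packed, taus = lax_linalg.geqrf(a)                 # reflectors packed below the diagonal
q = lax_linalg.householder_product(r_packed, taus)   # (6, 4), orthonormal columns
r = jnp.triu(r_packed[:4, :])                        # (4, 4), upper-triangular factor

print(jnp.allclose(q @ r, a, atol=1e-5))             # reconstructs a
```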
@@ -1920,6 +1924,146 @@ batching.primitive_batchers[schur_p] = _schur_batching_rule
 ad.primitive_jvps[schur_p] = _schur_jvp_rule


+# hessenberg: Upper Hessenberg reduction
+
+def hessenberg(a: ArrayLike) -> Tuple[Array, Array]:
+  """Reduces a square matrix to upper Hessenberg form.
+
+  Args:
+    a: A floating point or complex square matrix or batch of matrices.
+
+  Returns:
+    A ``(a, taus)`` pair, where the upper triangle and first subdiagonal of ``a``
+    contain the upper Hessenberg matrix, and the elements below the first
+    subdiagonal contain the Householder reflectors. For each Householder
+    reflector ``taus`` contains the scalar factors of the elementary Householder
+    reflectors.
+  """
+  return hessenberg_p.bind(a)
+
+def _hessenberg_abstract_eval(a):
+  if a.ndim < 2:
+    msg = "hessenberg requires a.ndim to be at least 2, got {}."
+    raise TypeError(msg.format(a.ndim))
+  if a.shape[-1] != a.shape[-2]:
+    msg = ("hessenberg requires the last two dimensions of a to be equal "
+           "in size, got a.shape of {}.")
+    raise TypeError(msg.format(a.shape))
+  return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype)]
+
+hessenberg_p = Primitive("hessenberg")
+hessenberg_p.def_impl(partial(xla.apply_primitive, hessenberg_p))
+hessenberg_p.def_abstract_eval(_hessenberg_abstract_eval)
+hessenberg_p.multiple_results = True
+
+def _hessenberg_batching_rule(batched_args, batch_dims):
+  x, = batched_args
+  bd, = batch_dims
+  x = batching.moveaxis(x, bd, 0)
+  return hessenberg(x), 0
+
+batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
+
+def _hessenberg_cpu_mhlo(ctx, a):
+  # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
+  if not hasattr(lapack, "gehrd_mhlo"):
+    raise RuntimeError("Hessenberg reduction on CPU requires jaxlib 0.3.25 or "
+                       "newer")
+  a_aval, = ctx.avals_in
+  batch_dims = a_aval.shape[:-2]
+  a, taus, info = lapack.gehrd_mhlo(a_aval.dtype, a)
+  ok = mlir.compare_mhlo(
+      info, mlir.full_like_aval(0, ShapedArray(batch_dims, np.dtype(np.int32))),
+      "EQ", "SIGNED")
+  return [
+    _broadcasting_select_mhlo(
+        mhlo.BroadcastInDimOp(
+            ir.RankedTensorType.get(batch_dims + (1, 1),
+                                    ir.IntegerType.get_signless(1)),
+            ok, mlir.dense_int_elements(range(len(batch_dims)))).result,
+        a, _nan_like_mhlo(ctx.avals_out[0])),
+    _broadcasting_select_mhlo(
+        mhlo.BroadcastInDimOp(
+            ir.RankedTensorType.get(batch_dims + (1,),
+                                    ir.IntegerType.get_signless(1)),
+            ok, mlir.dense_int_elements(range(len(batch_dims)))).result,
+        taus, _nan_like_mhlo(ctx.avals_out[1])),
+  ]
+
+mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')
+
+
+# tridiagonal: Upper Hessenberg reduction
+
+def tridiagonal(a: ArrayLike, *, lower=True) -> Tuple[Array, Array]:
+  """Reduces a symmetric/Hermitian matrix to tridiagonal form.
+
+  Args:
+    a: A floating point or complex matrix or batch of matrices.
+    lower: Describes which triangle of the input matrices to use.
+      The other triangle is ignored and not accessed.
+
+  Returns:
+    A ``(a, taus)`` pair. If ``lower=True``, the diagonal and first subdiagonal of
+    matrix (or batch of matrices) ``a`` contain the tridiagonal representation,
+    and elements below the first subdiagonal contain the elementary Householder
+    reflectors. If ``lower=False`` the diagonal and first superdiagonal of the
+    matrix contains the tridiagonal representation, and elements above the first
+    superdiagonal contain the elementary Householder reflectors.
+    ``taus`` contains the scalar factors of the elementary Householder
+    reflectors.
+  """
+  arr, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower)
+  nan = arr.dtype.type(jnp.nan)
+  if jnp.issubdtype(arr.dtype, np.complexfloating):
+    nan = nan + arr.dtype.type(jnp.nan * 1j)
+  arr = jnp.where((info == 0)[..., None, None], arr, nan)
+  taus = jnp.where((info == 0)[..., None], taus, nan)
+  return arr, taus
+
+def _tridiagonal_abstract_eval(a, *, lower):
+  if a.ndim < 2:
+    msg = "tridiagonal requires a.ndim to be at least 2, got {}."
+    raise TypeError(msg.format(a.ndim))
+  if a.shape[-1] != a.shape[-2]:
+    msg = ("tridiagonal requires the last two dimensions of a to be equal "
+           "in size, got a.shape of {}.")
+    raise TypeError(msg.format(a.shape))
+  if a.shape[-1] == 0:
+    msg = ("tridiagonal requires the last two dimensions of a to be non-zero, "
+           "got a.shape of {}.")
+    raise TypeError(msg.format(a.shape))
+  return [a, ShapedArray(a.shape[:-2] + (a.shape[-1] - 1,), a.dtype),
+          ShapedArray(a.shape[:-2], np.int32)]
+
+tridiagonal_p = Primitive("tridiagonal")
+tridiagonal_p.def_impl(partial(xla.apply_primitive, tridiagonal_p))
+tridiagonal_p.def_abstract_eval(_tridiagonal_abstract_eval)
+tridiagonal_p.multiple_results = True
+
+def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
+  x, = batched_args
+  bd, = batch_dims
+  x = batching.moveaxis(x, bd, 0)
+  return tridiagonal(x), 0
+
+batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
+
+def _tridiagonal_cpu_mhlo(ctx, a, *, lower):
+  a_aval, = ctx.avals_in
+  # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
+  if not hasattr(lapack, "sytrd_mhlo"):
+    raise RuntimeError("Tridiagonal reduction on CPU requires jaxlib 0.3.25 or "
+                       "newer")
+
+  a, d, e, taus, info = lapack.sytrd_mhlo(a_aval.dtype, a, lower=lower)
+  del d, e
+  return a, taus, info
+
+mlir.register_lowering(tridiagonal_p, _tridiagonal_cpu_mhlo, platform='cpu')
+
+
+
 # Utilities

 def _nan_like_mhlo(aval):
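The new `hessenberg` primitive returns the LAPACK-style packed form described in its docstring. A sketch of unpacking it into H and Q, mirroring what the `jax.scipy.linalg.hessenberg` wrapper in the next hunk does (CPU and jaxlib >= 0.3.25 assumed; shapes are illustrative):

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(2), (5, 5))
packed, taus = jax.lax.linalg.hessenberg(a)

h = jnp.triu(packed, -1)   # keep the upper triangle plus the first subdiagonal
q_small = jax.lax.linalg.householder_product(packed[1:, :-1], taus)
q = jnp.block([[jnp.ones((1, 1)), jnp.zeros((1, 4))],
               [jnp.zeros((4, 1)), q_small]])   # embed; the first basis vector is untouched

print(jnp.allclose(q @ h @ q.T, a, atol=1e-4))  # A = Q H Q^T for real input
```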
@@ -1001,3 +1001,34 @@ def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> Tuple[Arra
     return T, Z

   return lax.fori_loop(1, N, _rsf2scf_iter, (T, Z))
+
+@overload
+def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = False,
+               check_finite: bool = True) -> Array: ...
+
+@overload
+def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False,
+               check_finite: bool = True) -> Tuple[Array, Array]: ...
+
+@_wraps(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc)
+@partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a'))
+def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False,
+               check_finite: bool = True) -> Union[Array, Tuple[Array, Array]]:
+  del overwrite_a, check_finite
+  n = jnp.shape(a)[-1]
+  if n == 0:
+    if calc_q:
+      return jnp.zeros_like(a), jnp.zeros_like(a)
+    else:
+      return jnp.zeros_like(a)
+  a_out, taus = lax_linalg.hessenberg(a)
+  h = jnp.triu(a_out, -1)
+  if calc_q:
+    q = lax_linalg.householder_product(a_out[..., 1:, :-1], taus)
+    batch_dims = a_out.shape[:-2]
+    q = jnp.block([[jnp.ones(batch_dims + (1, 1), dtype=a_out.dtype),
+                    jnp.zeros(batch_dims + (1, n - 1), dtype=a_out.dtype)],
+                   [jnp.zeros(batch_dims + (n - 1, 1), dtype=a_out.dtype), q]])
+    return h, q
+  else:
+    return h
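Because `lax_linalg.hessenberg` and `householder_product` accept leading batch dimensions and the `jnp.block` call above carries them through, this wrapper also works on stacked matrices. A small sketch with illustrative shapes only:

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg  # load the scipy-compatible wrappers

batch = jax.random.normal(jax.random.PRNGKey(3), (3, 4, 4))
h, q = jax.scipy.linalg.hessenberg(batch, calc_q=True)
print(h.shape, q.shape)                        # (3, 4, 4) (3, 4, 4)

recon = q @ h @ jnp.swapaxes(q, -1, -2)        # batched Q H Q^T
print(jnp.allclose(recon, batch, atol=1e-4))   # True on CPU with jaxlib >= 0.3.25
```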
@@ -1253,7 +1253,9 @@ tf_not_yet_impl = [
     "lu_pivots_to_permutation",
     "xla_pmap",
     "geqrf",
-    "orgqr",
+    "householder_product",
+    "hessenberg",
+    "tridiagonal",
     "eigh_jacobi",
 ]

@@ -19,15 +19,21 @@ from jax._src.lax.linalg import (
   eig_p,
   eigh,
   eigh_p,
+  hessenberg,
+  hessenberg_p,
   lu,
   lu_p,
   lu_pivots_to_permutation,
+  householder_product,
+  householder_product_p,
   qr,
   qr_p,
   svd,
   svd_p,
   triangular_solve,
   triangular_solve_p,
+  tridiagonal,
+  tridiagonal_p,
   tridiagonal_solve,
   tridiagonal_solve_p,
   schur,
@@ -22,6 +22,7 @@ from jax._src.scipy.linalg import (
   eigh_tridiagonal as eigh_tridiagonal,
   expm as expm,
   expm_frechet as expm_frechet,
+  hessenberg as hessenberg,
   inv as inv,
   lu as lu,
   lu_factor as lu_factor,
@@ -15,8 +15,8 @@ limitations under the License.

 #include <complex>

-#include "jaxlib/kernel_pybind11_helpers.h"
 #include "jaxlib/cpu/lapack_kernels.h"
+#include "jaxlib/kernel_pybind11_helpers.h"
 #include "include/pybind11/pybind11.h"

 namespace jax {
@@ -127,6 +127,27 @@ void GetLapackKernelsFromScipy() {
   ComplexGees<std::complex<double>>::fn =
       reinterpret_cast<ComplexGees<std::complex<double>>::FnType*>(
          lapack_ptr("zgees"));
+  Gehrd<float>::fn =
+      reinterpret_cast<Gehrd<float>::FnType*>(lapack_ptr("sgehrd"));
+  Gehrd<double>::fn =
+      reinterpret_cast<Gehrd<double>::FnType*>(lapack_ptr("dgehrd"));
+  Gehrd<std::complex<float>>::fn =
+      reinterpret_cast<Gehrd<std::complex<float>>::FnType*>(
+          lapack_ptr("cgehrd"));
+  Gehrd<std::complex<double>>::fn =
+      reinterpret_cast<Gehrd<std::complex<double>>::FnType*>(
+          lapack_ptr("zgehrd"));
+  Sytrd<float>::fn =
+      reinterpret_cast<Sytrd<float>::FnType*>(lapack_ptr("ssytrd"));
+  Sytrd<double>::fn =
+      reinterpret_cast<Sytrd<double>::FnType*>(lapack_ptr("dsytrd"));
+  Sytrd<std::complex<float>>::fn =
+      reinterpret_cast<Sytrd<std::complex<float>>::FnType*>(
+          lapack_ptr("chetrd"));
+  Sytrd<std::complex<double>>::fn =
+      reinterpret_cast<Sytrd<std::complex<double>>::FnType*>(
+          lapack_ptr("zhetrd"));
+
   initialized = true;
 }

@@ -185,6 +206,21 @@ py::dict Registrations() {
       EncapsulateFunction(ComplexGees<std::complex<float>>::Kernel);
   dict["lapack_zgees"] =
       EncapsulateFunction(ComplexGees<std::complex<double>>::Kernel);
+
+  dict["lapack_sgehrd"] = EncapsulateFunction(Gehrd<float>::Kernel);
+  dict["lapack_dgehrd"] = EncapsulateFunction(Gehrd<double>::Kernel);
+  dict["lapack_cgehrd"] =
+      EncapsulateFunction(Gehrd<std::complex<float>>::Kernel);
+  dict["lapack_zgehrd"] =
+      EncapsulateFunction(Gehrd<std::complex<double>>::Kernel);
+
+  dict["lapack_ssytrd"] = EncapsulateFunction(Sytrd<float>::Kernel);
+  dict["lapack_dsytrd"] = EncapsulateFunction(Sytrd<double>::Kernel);
+  dict["lapack_chetrd"] =
+      EncapsulateFunction(Sytrd<std::complex<float>>::Kernel);
+  dict["lapack_zhetrd"] =
+      EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);
+
   return dict;
 }

@@ -211,6 +247,14 @@ PYBIND11_MODULE(_lapack, m) {
   m.def("syevd_iwork_size", &SyevdIworkSize);
   m.def("heevd_work_size", &HeevdWorkSize);
   m.def("heevd_rwork_size", &HeevdRworkSize);
+  m.def("lapack_sgehrd_workspace", &Gehrd<float>::Workspace);
+  m.def("lapack_dgehrd_workspace", &Gehrd<double>::Workspace);
+  m.def("lapack_cgehrd_workspace", &Gehrd<std::complex<float>>::Workspace);
+  m.def("lapack_zgehrd_workspace", &Gehrd<std::complex<double>>::Workspace);
+  m.def("lapack_ssytrd_workspace", &Sytrd<float>::Workspace);
+  m.def("lapack_dsytrd_workspace", &Sytrd<double>::Workspace);
+  m.def("lapack_chetrd_workspace", &Sytrd<std::complex<float>>::Workspace);
+  m.def("lapack_zhetrd_workspace", &Sytrd<std::complex<double>>::Workspace);
 }

 } // namespace
@@ -775,4 +775,111 @@ template struct RealGees<double>;
 template struct ComplexGees<std::complex<float>>;
 template struct ComplexGees<std::complex<double>>;

+template <typename T>
+typename Gehrd<T>::FnType* Gehrd<T>::fn = nullptr;
+
+template <typename T>
+void Gehrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
+  int32_t n = *reinterpret_cast<int32_t*>(data[0]);
+  int32_t ilo = *reinterpret_cast<int32_t*>(data[1]);
+  int32_t ihi = *reinterpret_cast<int32_t*>(data[2]);
+  int32_t lda = *reinterpret_cast<int32_t*>(data[3]);
+  int32_t batch = *reinterpret_cast<int32_t*>(data[4]);
+  int32_t lwork = *reinterpret_cast<int32_t*>(data[5]);
+  T* a = reinterpret_cast<T*>(data[6]);
+
+  void** out = reinterpret_cast<void**>(out_tuple);
+  T* a_out = reinterpret_cast<T*>(out[0]);
+  T* tau = reinterpret_cast<T*>(out[1]);
+  int* info = reinterpret_cast<int*>(out[2]);
+  T* work = reinterpret_cast<T*>(out[3]);
+
+  if (a_out != a) {
+    std::memcpy(a_out, a,
+                static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
+                    static_cast<int64_t>(n) * sizeof(T));
+  }
+
+  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
+
+  for (int i = 0; i < batch; ++i) {
+    fn(&n, &ilo, &ihi, a_out, &lda, tau, work, &lwork, info);
+    a_out += a_plus;
+    tau += n - 1;
+    ++info;
+  }
+}
+
+template <typename T>
+int64_t Gehrd<T>::Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
+                            lapack_int ihi) {
+  T work = 0;
+  lapack_int lwork = -1;
+  lapack_int info = 0;
+  fn(&n, &ilo, &ihi, nullptr, &lda, nullptr, &work, &lwork, &info);
+  return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
+}
+
+template struct Gehrd<float>;
+template struct Gehrd<double>;
+template struct Gehrd<std::complex<float>>;
+template struct Gehrd<std::complex<double>>;
+
+template <typename T>
+typename Sytrd<T>::FnType* Sytrd<T>::fn = nullptr;
+
+template <typename T>
+void Sytrd<T>::Kernel(void* out_tuple, void** data, XlaCustomCallStatus*) {
+  int32_t n = *reinterpret_cast<int32_t*>(data[0]);
+  int32_t lower = *reinterpret_cast<int32_t*>(data[1]);
+  int32_t lda = *reinterpret_cast<int32_t*>(data[2]);
+  int32_t batch = *reinterpret_cast<int32_t*>(data[3]);
+  int32_t lwork = *reinterpret_cast<int32_t*>(data[4]);
+  T* a = reinterpret_cast<T*>(data[5]);
+
+  void** out = reinterpret_cast<void**>(out_tuple);
+  T* a_out = reinterpret_cast<T*>(out[0]);
+  typedef typename real_type<T>::type Real;
+  Real* d = reinterpret_cast<Real*>(out[1]);
+  Real* e = reinterpret_cast<Real*>(out[2]);
+  T* tau = reinterpret_cast<T*>(out[3]);
+  int* info = reinterpret_cast<int*>(out[4]);
+  T* work = reinterpret_cast<T*>(out[5]);
+
+  if (a_out != a) {
+    std::memcpy(a_out, a,
+                static_cast<int64_t>(batch) * static_cast<int64_t>(n) *
+                    static_cast<int64_t>(n) * sizeof(T));
+  }
+
+  char cuplo = lower ? 'L' : 'U';
+
+  int64_t a_plus = static_cast<int64_t>(lda) * static_cast<int64_t>(n);
+
+  for (int i = 0; i < batch; ++i) {
+    fn(&cuplo, &n, a_out, &lda, d, e, tau, work, &lwork, info);
+    a_out += a_plus;
+    d += n;
+    e += n - 1;
+    tau += n - 1;
+    ++info;
+  }
+}
+
+template <typename T>
+int64_t Sytrd<T>::Workspace(lapack_int lda, lapack_int n) {
+  char cuplo = 'L';
+  T work = 0;
+  lapack_int lwork = -1;
+  lapack_int info = 0;
+  fn(&cuplo, &n, nullptr, &lda, nullptr, nullptr, nullptr, &work, &lwork,
+     &info);
+  return info == 0 ? static_cast<int64_t>(std::real(work)) : -1;
+}
+
+template struct Sytrd<float>;
+template struct Sytrd<double>;
+template struct Sytrd<std::complex<float>>;
+template struct Sytrd<std::complex<double>>;
+
 } // namespace jax
@@ -18,6 +18,7 @@ limitations under the License.

 #include <complex>
 #include <cstdint>
+
 #include "tensorflow/compiler/xla/service/custom_call_status.h"

 // Underlying function pointers (e.g., Trsm<double>::Fn) are initialized either
@@ -88,8 +89,8 @@ struct RealGesdd {
   static FnType* fn;
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);

-  static int64_t Workspace(lapack_int m, lapack_int n,
-                           bool job_opt_compute_uv, bool job_opt_full_matrices);
+  static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
+                           bool job_opt_full_matrices);
 };

 lapack_int ComplexGesddRworkSize(int64_t m, int64_t n, int compute_uv);
@@ -104,11 +105,10 @@ struct ComplexGesdd {
   static FnType* fn;
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);

-  static int64_t Workspace(lapack_int m, lapack_int n,
-                           bool job_opt_compute_uv, bool job_opt_full_matrices);
+  static int64_t Workspace(lapack_int m, lapack_int n, bool job_opt_compute_uv,
+                           bool job_opt_full_matrices);
 };

-
 lapack_int SyevdWorkSize(int64_t n);
 lapack_int SyevdIworkSize(int64_t n);

@@ -176,6 +176,45 @@ struct ComplexGees {
   static void Kernel(void* out, void** data, XlaCustomCallStatus*);
 };

+// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form.
+template <typename T>
+struct Gehrd {
+  using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, T* a,
+                      lapack_int* lda, T* tau, T* work, lapack_int* lwork,
+                      lapack_int* info);
+
+  static FnType* fn;
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
+
+  static int64_t Workspace(lapack_int lda, lapack_int n, lapack_int ilo,
+                           lapack_int ihi);
+};
+
+template <typename T>
+struct real_type {
+  typedef T type;
+};
+template <typename T>
+struct real_type<std::complex<T>> {
+  typedef T type;
+};
+
+// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal
+// form.
+template <typename T>
+struct Sytrd {
+  using FnType = void(char* uplo, lapack_int* n, T* a, lapack_int* lda,
+                      typename real_type<T>::type* d,
+                      typename real_type<T>::type* e,
+                      T* tau, T* work,
+                      lapack_int* lwork, lapack_int* info);
+
+  static FnType* fn;
+  static void Kernel(void* out, void** data, XlaCustomCallStatus*);
+
+  static int64_t Workspace(lapack_int lda, lapack_int n);
+};
+
 } // namespace jax

 #endif // JAXLIB_CPU_LAPACK_KERNELS_H_
@@ -66,6 +66,16 @@ jax::RealGees<double>::FnType dgees_;
 jax::ComplexGees<std::complex<float>>::FnType cgees_;
 jax::ComplexGees<std::complex<double>>::FnType zgees_;
+
+jax::Gehrd<float>::FnType sgehrd_;
+jax::Gehrd<double>::FnType dgehrd_;
+jax::Gehrd<std::complex<float>>::FnType cgehrd_;
+jax::Gehrd<std::complex<double>>::FnType zgehrd_;
+
+jax::Sytrd<float>::FnType ssytrd_;
+jax::Sytrd<double>::FnType dsytrd_;
+jax::Sytrd<std::complex<float>>::FnType chetrd_;
+jax::Sytrd<std::complex<double>>::FnType zhetrd_;

 } // extern "C"

 namespace jax {
@@ -107,6 +117,15 @@ static auto init = []() -> int {
   RealGees<double>::fn = dgees_;
   ComplexGees<std::complex<float>>::fn = cgees_;
   ComplexGees<std::complex<double>>::fn = zgees_;
+  Gehrd<float>::fn = sgehrd_;
+  Gehrd<double>::fn = dgehrd_;
+  Gehrd<std::complex<float>>::fn = cgehrd_;
+  Gehrd<std::complex<double>>::fn = zgehrd_;
+  Sytrd<float>::fn = ssytrd_;
+  Sytrd<double>::fn = dsytrd_;
+  Sytrd<std::complex<float>>::fn = chetrd_;
+  Sytrd<std::complex<double>>::fn = zhetrd_;
+
   return 0;
 }();

jaxlib/lapack.py (+114 lines)
@@ -628,3 +628,117 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
     return (out[0], out[3], out[4], out[5])
   else:
     return (out[0], out[3], out[5])
+
+
+# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
+def gehrd_mhlo(dtype, a):
+  _initialize()
+  a_type = ir.RankedTensorType(a.type)
+  dims = a_type.shape
+  assert len(dims) >= 2
+  m, n = dims[-2:]
+  assert m == n, (m, n)
+  batch_dims = tuple(dims[:-2])
+  num_bd = len(batch_dims)
+  b = 1
+  for d in batch_dims:
+    b *= d
+
+  if dtype == np.float32:
+    fn = b"lapack_sgehrd"
+    lwork = _lapack.lapack_sgehrd_workspace(n, n, 1, n)
+  elif dtype == np.float64:
+    fn = b"lapack_dgehrd"
+    lwork = _lapack.lapack_dgehrd_workspace(n, n, 1, n)
+  elif dtype == np.complex64:
+    fn = b"lapack_cgehrd"
+    lwork = _lapack.lapack_cgehrd_workspace(n, n, 1, n)
+  elif dtype == np.complex128:
+    fn = b"lapack_zgehrd"
+    lwork = _lapack.lapack_zgehrd_workspace(n, n, 1, n)
+  else:
+    raise NotImplementedError(f"Unsupported dtype {dtype}")
+
+  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
+  i32_type = ir.IntegerType.get_signless(32)
+  out = custom_call(
+      fn,
+      [
+        a.type,
+        ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
+        ir.RankedTensorType.get(batch_dims, i32_type),
+        ir.RankedTensorType.get([lwork], a_type.element_type),
+      ],
+      [_mhlo_s32(n), _mhlo_s32(1), _mhlo_s32(n), _mhlo_s32(n), _mhlo_s32(b),
+       _mhlo_s32(lwork), a],
+      operand_layouts=[[]] * 6 + [layout],
+      result_layouts=[
+        layout,
+        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
+        tuple(range(num_bd - 1, -1, -1)),
+        [0],
+      ],
+      operand_output_aliases={6: 0},
+  )
+  return out[:3]
+
+
+# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
+def sytrd_mhlo(dtype, a, *, lower):
+  _initialize()
+  a_type = ir.RankedTensorType(a.type)
+  dims = a_type.shape
+  assert len(dims) >= 2
+  m, n = dims[-2:]
+  assert m == n, (m, n)
+  batch_dims = tuple(dims[:-2])
+  num_bd = len(batch_dims)
+  b = 1
+  for d in batch_dims:
+    b *= d
+
+  if dtype == np.float32:
+    fn = b"lapack_ssytrd"
+    lwork = _lapack.lapack_ssytrd_workspace(n, n)
+    diag_type = a_type.element_type
+  elif dtype == np.float64:
+    fn = b"lapack_dsytrd"
+    lwork = _lapack.lapack_dsytrd_workspace(n, n)
+    diag_type = a_type.element_type
+  elif dtype == np.complex64:
+    fn = b"lapack_chetrd"
+    lwork = _lapack.lapack_chetrd_workspace(n, n)
+    diag_type = ir.ComplexType.get(ir.F32Type.get())
+  elif dtype == np.complex128:
+    fn = b"lapack_zhetrd"
+    lwork = _lapack.lapack_zhetrd_workspace(n, n)
+    diag_type = ir.ComplexType.get(ir.F64Type.get())
+  else:
+    raise NotImplementedError(f"Unsupported dtype {dtype}")
+
+  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
+  i32_type = ir.IntegerType.get_signless(32)
+  out = custom_call(
+      fn,
+      [
+        a.type,
+        ir.RankedTensorType.get(batch_dims + (n,), diag_type),
+        ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
+        ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
+        ir.RankedTensorType.get(batch_dims, i32_type),
+        ir.RankedTensorType.get([lwork], a_type.element_type),
+      ],
+      [_mhlo_s32(n), _mhlo_s32(1 if lower else 0), _mhlo_s32(max(1, n)),
+       _mhlo_s32(b), _mhlo_s32(lwork), a],
+      operand_layouts=[[]] * 5 + [layout],
+      result_layouts=[
+        layout,
+        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
+        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
+        (num_bd,) + tuple(range(num_bd - 1, -1, -1)),
+        tuple(range(num_bd - 1, -1, -1)),
+        [0],
+      ],
+      operand_output_aliases={5: 0},
+  )
+  return out[:5]
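`sytrd_mhlo` returns the reduced matrix with T packed into its band, plus the Householder scalars (`d` and `e` are discarded at the JAX level). A sketch that rebuilds T from `lax.linalg.tridiagonal`'s packed output and checks that the spectrum is preserved, since the reduction is a similarity transform (CPU, jaxlib >= 0.3.25 assumed; shapes illustrative):

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(5), (6, 6))
sym = (a + a.T) / 2

packed, taus = jax.lax.linalg.tridiagonal(sym, lower=True)
d = jnp.diagonal(packed)              # main diagonal of T
e = jnp.diagonal(packed, offset=-1)   # first subdiagonal of T (lower=True)
t = jnp.diag(d) + jnp.diag(e, -1) + jnp.diag(e, 1)

# T = Q^T A Q for an orthogonal Q, so the eigenvalues agree.
print(jnp.allclose(jnp.linalg.eigvalsh(t), jnp.linalg.eigvalsh(sym), atol=1e-4))
```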
@@ -15,6 +15,7 @@
 """Tests for the LAPAX linear algebra module."""

 from functools import partial
+import unittest

 import numpy as np
 import scipy
@@ -28,6 +29,7 @@ from jax import jit, grad, jvp, vmap
 from jax import lax
 from jax import numpy as jnp
 from jax import scipy as jsp
+from jax._src.lib import version as jaxlib_version
 from jax._src import test_util as jtu

 from jax.config import config
@@ -1244,6 +1246,96 @@ class ScipyLinalgTest(jtu.JaxTestCase):
     self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1E-5, atol=1E-5)
     self._CompileAndCheck(jsp_func, args_maker)

+  @jtu.sample_product(
+    [dict(shape=shape, k=k)
+     for shape in [(1, 1), (3, 4, 4), (10, 5)]
+     # TODO(phawkins): there are some test failures on GPU for k=0
+     for k in range(1, shape[-1] + 1)],
+    dtype=float_types + complex_types,
+  )
+  def testHouseholderProduct(self, shape, k, dtype):
+
+    @partial(np.vectorize, signature='(m,n),(k)->(m,n)')
+    def reference_fn(a, taus):
+      if dtype == np.float32:
+        q, _, info = scipy.linalg.lapack.sorgqr(a, taus)
+      elif dtype == np.float64:
+        q, _, info = scipy.linalg.lapack.dorgqr(a, taus)
+      elif dtype == np.complex64:
+        q, _, info = scipy.linalg.lapack.cungqr(a, taus)
+      elif dtype == np.complex128:
+        q, _, info = scipy.linalg.lapack.zungqr(a, taus)
+      else:
+        assert False, dtype
+      assert info == 0, info
+      return q
+
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng(shape, dtype), rng(shape[:-2] + (k,), dtype)]
+    tol = {np.float32: 1e-5, np.complex64: 1e-5, np.float64: 1e-12,
+           np.complex128: 1e-12}
+    self._CheckAgainstNumpy(reference_fn, lax.linalg.householder_product,
+                            args_maker, rtol=tol, atol=tol)
+    self._CompileAndCheck(lax.linalg.householder_product, args_maker)
+
+  @jtu.sample_product(
+    shape=[(1, 1), (2, 4, 4), (0, 100, 100), (10, 10)],
+    dtype=float_types + complex_types,
+    calc_q=[False, True],
+  )
+  @jtu.skip_on_devices("gpu", "tpu")
+  @unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
+  def testHessenberg(self, shape, dtype, calc_q):
+    rng = jtu.rand_default(self.rng())
+    jsp_func = partial(jax.scipy.linalg.hessenberg, calc_q=calc_q)
+    if calc_q:
+      sp_func = np.vectorize(partial(scipy.linalg.hessenberg, calc_q=True),
+                             otypes=(dtype, dtype),
+                             signature='(n,n)->(n,n),(n,n)')
+    else:
+      sp_func = np.vectorize(scipy.linalg.hessenberg, signature='(n,n)->(n,n)',
+                             otypes=(dtype,))
+    args_maker = lambda: [rng(shape, dtype)]
+    # scipy.linalg.hessenberg sometimes returns a float Q matrix for complex
+    # inputs
+    self._CheckAgainstNumpy(sp_func, jsp_func, args_maker, rtol=1e-5, atol=1e-5,
+                            check_dtypes=not calc_q)
+    self._CompileAndCheck(jsp_func, args_maker)
+
+  @jtu.sample_product(
+    shape=[(1, 1), (4, 4), (10, 10)],
+    dtype=float_types + complex_types,
+    lower=[False, True],
+  )
+  @unittest.skipIf(jaxlib_version < (0, 3, 25), "Test requires jaxlib 0.3.25")
+  @jtu.skip_on_devices("gpu", "tpu")
+  def testTridiagonal(self, shape, dtype, lower):
+    rng = jtu.rand_default(self.rng())
+    def jax_func(a):
+      return lax.linalg.tridiagonal(a, lower=lower)
+
+    @partial(np.vectorize, otypes=(dtype, dtype),
+             signature='(n,n)->(n,n),(k)')
+    def sp_func(a):
+      if dtype == np.float32:
+        c, d, e, tau, info = scipy.linalg.lapack.ssytrd(a, lower=lower)
+      elif dtype == np.float64:
+        c, d, e, tau, info = scipy.linalg.lapack.dsytrd(a, lower=lower)
+      elif dtype == np.complex64:
+        c, d, e, tau, info = scipy.linalg.lapack.chetrd(a, lower=lower)
+      elif dtype == np.complex128:
+        c, d, e, tau, info = scipy.linalg.lapack.zhetrd(a, lower=lower)
+      else:
+        assert False, dtype
+      del d, e
+      assert info == 0
+      return c, tau
+
+    args_maker = lambda: [rng(shape, dtype)]
+    self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-5, atol=1e-5,
+                            check_dtypes=False)
+
+
   @jtu.sample_product(
     n=[1, 4, 5, 20, 50, 100],
     dtype=float_types + complex_types,