mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
This commit is contained in:
parent d092d6305f
commit 705e241409
@@ -13,6 +13,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changes
  * {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument
    that allows users to opt out of eigenvalue sorting on TPU.
  * Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked
    keyword-only. As a backward-compatibility step passing keyword-only
    arguments positionally yields a warning, but in a future JAX release passing
    keyword-only arguments positionally will fail.
    However, most users should prefer to use {mod}`jax.numpy.linalg` instead.

## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
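To make the caller-facing effect concrete, here is a minimal sketch (not part of the commit, assuming a current jax install): after this change, non-array flags such as `full_matrices` should be passed by keyword when calling `jax.lax.linalg` functions directly.

import warnings
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

x = jnp.eye(3)

# Deprecated spelling: the non-array flag is passed positionally. During the
# compatibility window this still works but emits a DeprecationWarning;
# a future JAX release will raise a TypeError instead.
with warnings.catch_warnings():
  warnings.simplefilter("ignore", DeprecationWarning)
  q, r = lax_linalg.qr(x, False)

# Preferred spelling: non-array arguments are keyword-only.
q, r = lax_linalg.qr(x, full_matrices=False)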
@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import functools
from functools import partial
import warnings

import numpy as np
@@ -58,7 +60,52 @@ xops = xla_client.ops

# traceables

def cholesky(x, symmetrize_input: bool = True):
# TODO(phawkins): remove backward compatibility shim after 2022/08/11.
def _warn_on_positional_kwargs(f):
  """Decorator used for backward compatibility of keyword-only arguments.

  Some functions were changed to mark their keyword arguments as keyword-only.
  This decorator allows existing code to keep working temporarily, while issuing
  a warning if a now keyword-only parameter is passed positionally."""
  sig = inspect.signature(f)
  pos_names = [name for name, p in sig.parameters.items()
               if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
  kwarg_names = [name for name, p in sig.parameters.items()
                 if p.kind == inspect.Parameter.KEYWORD_ONLY]

  # This decorator assumes that all arguments to `f` are either
  # positional-or-keyword or keyword-only.
  assert len(pos_names) + len(kwarg_names) == len(sig.parameters)

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    if len(args) < len(pos_names):
      a = pos_names[len(args)]
      raise TypeError(f"{f.__name__} missing required positional argument: {a}")

    pos_args = args[:len(pos_names)]
    extra_kwargs = args[len(pos_names):]

    if len(extra_kwargs) > len(kwarg_names):
      raise TypeError(f"{f.__name__} takes at most {len(sig.parameters)} "
                      f" arguments but {len(args)} were given.")

    for name, value in zip(kwarg_names, extra_kwargs):
      if name in kwargs:
        raise TypeError(f"{f.__name__} got multiple values for argument: "
                        f"{name}")

      warnings.warn(f"Argument {name} to {f.__name__} is now a keyword-only "
                    "argument. Support for passing it positionally will be "
                    "removed in an upcoming JAX release.",
                    DeprecationWarning)
      kwargs[name] = value
    return f(*pos_args, **kwargs)

  return wrapped

@_warn_on_positional_kwargs
def cholesky(x, *, symmetrize_input: bool = True):
  """Cholesky decomposition.

  Computes the Cholesky decomposition
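The `_warn_on_positional_kwargs` shim above maps any extra positional arguments onto the wrapped function's keyword-only parameters and warns that the old spelling is deprecated. A rough, self-contained sketch of the same pattern on a toy function (`scale_columns` and `scale` are made-up names for illustration, not JAX APIs):

import functools
import inspect
import warnings

def warn_on_positional_kwargs(f):
  # Simplified shim: map extra positional arguments onto the keyword-only
  # parameters of `f`, warning that this usage is deprecated.
  sig = inspect.signature(f)
  pos_names = [n for n, p in sig.parameters.items()
               if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD]
  kwarg_names = [n for n, p in sig.parameters.items()
                 if p.kind == inspect.Parameter.KEYWORD_ONLY]

  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    for name, value in zip(kwarg_names, args[len(pos_names):]):
      warnings.warn(f"Argument {name} to {f.__name__} is now keyword-only.",
                    DeprecationWarning)
      kwargs[name] = value
    return f(*args[:len(pos_names)], **kwargs)
  return wrapped

@warn_on_positional_kwargs
def scale_columns(x, *, scale=1.0):   # hypothetical example function
  return [v * scale for v in x]

scale_columns([1.0, 2.0], 2.0)        # warns, then behaves like scale=2.0
scale_columns([1.0, 2.0], scale=2.0)  # preferred spelling, no warning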
@@ -87,7 +134,8 @@ def cholesky(x, symmetrize_input: bool = True):
    x = symmetrize(x)
  return jnp.tril(cholesky_p.bind(x))

def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
@_warn_on_positional_kwargs
def eig(x, *, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
  """Eigendecomposition of a general matrix.

  Nonsymmetric eigendecomposition is at present only implemented on CPU.
@@ -95,7 +143,8 @@ def eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True):
  return eig_p.bind(x, compute_left_eigenvectors=compute_left_eigenvectors,
                    compute_right_eigenvectors=compute_right_eigenvectors)

def eigh(x, lower: bool = True, symmetrize_input: bool = True,
@_warn_on_positional_kwargs
def eigh(x, *, lower: bool = True, symmetrize_input: bool = True,
         sort_eigenvalues: bool = True, ):
  r"""Eigendecomposition of a Hermitian matrix.

@@ -182,7 +231,8 @@ def lu(x):
  lu, pivots, permutation = lu_p.bind(x)
  return lu, pivots, permutation

def qr(x, full_matrices: bool = True):
@_warn_on_positional_kwargs
def qr(x, *, full_matrices: bool = True):
  """QR decomposition.

  Computes the QR decomposition
@@ -213,7 +263,8 @@ def qr(x, full_matrices: bool = True):
  return q, r

# TODO: Add `max_qdwh_iterations` to the function signature for TPU SVD.
def svd(x, full_matrices=True, compute_uv=True):
@_warn_on_positional_kwargs
def svd(x, *, full_matrices=True, compute_uv=True):
  """Singular value decomposition.

  Returns the singular values if compute_uv is False, otherwise returns a triple
@@ -228,7 +279,8 @@ def svd(x, full_matrices=True, compute_uv=True):
    s, = result
    return s

def triangular_solve(a, b, left_side: bool = False, lower: bool = False,
@_warn_on_positional_kwargs
def triangular_solve(a, b, *, left_side: bool = False, lower: bool = False,
                     transpose_a: bool = False, conjugate_a: bool = False,
                     unit_diagonal: bool = False):
  r"""Triangular solve.
@@ -330,7 +382,7 @@ _cpu_lapack_types = {np.dtype(np.float32), np.dtype(np.float64),

# Cholesky decomposition

def cholesky_jvp_rule(primals, tangents):
def _cholesky_jvp_rule(primals, tangents):
  x, = primals
  sigma_dot, = tangents
  L = jnp.tril(cholesky_p.bind(x))
@@ -349,15 +401,15 @@ def cholesky_jvp_rule(primals, tangents):
                           precision=lax.Precision.HIGHEST)
  return L, L_dot

def cholesky_batching_rule(batched_args, batch_dims):
def _cholesky_batching_rule(batched_args, batch_dims):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
  return cholesky(x), 0

cholesky_p = standard_unop(_float | _complex, 'cholesky')
ad.primitive_jvps[cholesky_p] = cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = cholesky_batching_rule
ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule

def _cholesky_lowering(ctx, x):
  aval, = ctx.avals_out
@@ -635,7 +687,7 @@ triangular_solve_dtype_rule = partial(
    naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
    'triangular_solve')

def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
def triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
  if a.ndim < 2:
    msg = "triangular_solve requires a.ndim to be at least 2, got {}."
    raise TypeError(msg.format(a.ndim))
@@ -657,7 +709,8 @@ def triangular_solve_shape_rule(a, b, left_side=False, **unused_kwargs):
  return b.shape

def triangular_solve_jvp_rule_a(
    g_a, ans, a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal):
    g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  m, n = b.shape[-2:]
  k = 1 if unit_diagonal else 0
  g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k)
@@ -668,8 +721,9 @@ def triangular_solve_jvp_rule_a(
                precision=lax.Precision.HIGHEST)

  def a_inverse(rhs):
    return triangular_solve(a, rhs, left_side, lower, transpose_a, conjugate_a,
                            unit_diagonal)
    return triangular_solve(a, rhs, left_side=left_side, lower=lower,
                            transpose_a=transpose_a, conjugate_a=conjugate_a,
                            unit_diagonal=unit_diagonal)

  # triangular_solve is about the same cost as matrix multplication (~n^2 FLOPs
  # for matrix/vector inputs). Order these operations in whichever order is
@@ -688,7 +742,7 @@ def triangular_solve_jvp_rule_a(
    return dot(ans, a_inverse(g_a)) # X (∂A A^{-1})

def triangular_solve_transpose_rule(
    cotangent, a, b, left_side, lower, transpose_a, conjugate_a,
    cotangent, a, b, *, left_side, lower, transpose_a, conjugate_a,
    unit_diagonal):
  # Triangular solve is nonlinear in its first argument and linear in its second
  # argument, analogous to `div` but swapped.
@@ -696,12 +750,14 @@ def triangular_solve_transpose_rule(
  if type(cotangent) is ad_util.Zero:
    cotangent_b = ad_util.Zero(b.aval)
  else:
    cotangent_b = triangular_solve(a, cotangent, left_side, lower,
                                   not transpose_a, conjugate_a, unit_diagonal)
    cotangent_b = triangular_solve(a, cotangent, left_side=left_side,
                                   lower=lower, transpose_a=not transpose_a,
                                   conjugate_a=conjugate_a,
                                   unit_diagonal=unit_diagonal)
  return [None, cotangent_b]


def triangular_solve_batching_rule(batched_args, batch_dims, left_side,
def triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
                                   lower, transpose_a, conjugate_a,
                                   unit_diagonal):
  x, y = batched_args
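The internal callers above now spell out every non-array flag by keyword. User code that calls `jax.lax.linalg.triangular_solve` directly should do the same; a small sketch (not taken from the commit):

import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[2.0, 0.0],
               [1.0, 3.0]])
b = jnp.array([[4.0],
               [7.0]])

# Solve a @ x = b for a lower-triangular `a`; all non-array options are
# passed by keyword, matching the new keyword-only signature.
x = lax_linalg.triangular_solve(a, b, left_side=True, lower=True,
                                transpose_a=False, conjugate_a=False,
                                unit_diagonal=False)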
@@ -1206,7 +1262,7 @@ def lu_solve(lu, permutation, b, trans=0):

# QR decomposition

def qr_impl(operand, full_matrices):
def _qr_impl(operand, *, full_matrices):
  q, r = xla.apply_primitive(qr_p, operand, full_matrices=full_matrices)
  return q, r
@@ -1219,7 +1275,7 @@ def _qr_translation_rule(ctx, avals_in, avals_out, operand, *, full_matrices):
            _zeros_like_xla(ctx.builder, avals_out[1])]
  return xops.QR(operand, full_matrices)

def qr_abstract_eval(operand, full_matrices):
def _qr_abstract_eval(operand, *, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
@@ -1232,7 +1288,7 @@ def qr_abstract_eval(operand, full_matrices):
    r = operand
  return q, r

def qr_jvp_rule(primals, tangents, full_matrices):
def qr_jvp_rule(primals, tangents, *, full_matrices):
  # See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
  x, = primals
  dx, = tangents
@@ -1252,7 +1308,7 @@ def qr_jvp_rule(primals, tangents, full_matrices):
  dr = jnp.matmul(qt_dx_rinv - do, r)
  return (q, r), (dq, dr)

def qr_batching_rule(batched_args, batch_dims, full_matrices):
def _qr_batching_rule(batched_args, batch_dims, *, full_matrices):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
@@ -1329,11 +1385,11 @@ def _qr_cpu_gpu_lowering(geqrf_impl, orgqr_impl, ctx, operand, *,

qr_p = Primitive('qr')
qr_p.multiple_results = True
qr_p.def_impl(qr_impl)
qr_p.def_abstract_eval(qr_abstract_eval)
qr_p.def_impl(_qr_impl)
qr_p.def_abstract_eval(_qr_abstract_eval)
xla.register_translation(qr_p, _qr_translation_rule)
ad.primitive_jvps[qr_p] = qr_jvp_rule
batching.primitive_batchers[qr_p] = qr_batching_rule
batching.primitive_batchers[qr_p] = _qr_batching_rule

mlir.register_lowering(
    qr_p, partial(_qr_cpu_gpu_lowering, lapack.geqrf_mhlo, lapack.orgqr_mhlo),
@@ -1360,7 +1416,7 @@ if solver_apis is not None:

# Singular value decomposition

def svd_impl(operand, full_matrices, compute_uv):
def _svd_impl(operand, *, full_matrices, compute_uv):
  return xla.apply_primitive(svd_p, operand, full_matrices=full_matrices,
                             compute_uv=compute_uv)
@@ -1376,7 +1432,7 @@ def _eye_like_xla(c, aval):
              xops.Iota(c, iota_shape, len(aval.shape) - 2))
  return xops.ConvertElementType(x, xla.dtype_to_primitive_type(aval.dtype))

def svd_abstract_eval(operand, full_matrices, compute_uv):
def _svd_abstract_eval(operand, *, full_matrices, compute_uv):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to singular value decomposition must have ndims >= 2")
@@ -1395,7 +1451,7 @@ def svd_abstract_eval(operand, full_matrices, compute_uv):
  else:
    raise NotImplementedError

def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
def _svd_jvp_rule(primals, tangents, *, full_matrices, compute_uv):
  A, = primals
  dA, = tangents
  s, U, Vt = svd_p.bind(A, full_matrices=False, compute_uv=True)
@@ -1520,7 +1576,7 @@ def _svd_tpu_lowering_rule(ctx, operand, *, full_matrices, compute_uv):
  return mlir.lower_fun(_svd_tpu, multiple_results=True)(
      ctx, operand, full_matrices=full_matrices, compute_uv=compute_uv)

def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv):
  x, = batched_args
  bd, = batch_dims
  x = batching.moveaxis(x, bd, 0)
@@ -1533,10 +1589,10 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):

svd_p = Primitive('svd')
svd_p.multiple_results = True
svd_p.def_impl(svd_impl)
svd_p.def_abstract_eval(svd_abstract_eval)
ad.primitive_jvps[svd_p] = svd_jvp_rule
batching.primitive_batchers[svd_p] = svd_batching_rule
svd_p.def_impl(_svd_impl)
svd_p.def_abstract_eval(_svd_abstract_eval)
ad.primitive_jvps[svd_p] = _svd_jvp_rule
batching.primitive_batchers[svd_p] = _svd_batching_rule

mlir.register_lowering(
    svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
@@ -1665,7 +1721,8 @@ def tridiagonal_solve(dl, d, du, b):
# Schur Decomposition


def schur(x,
@_warn_on_positional_kwargs
def schur(x, *,
          compute_schur_vectors=True,
          sort_eig_vals=False,
          select_callable=None):
@@ -61,7 +61,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True,
    else:
      return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim-1])

  return lax_linalg.svd(a, full_matrices, compute_uv)
  return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)


@_wraps(np.linalg.matrix_power)
@@ -484,7 +484,7 @@ def qr(a, mode="reduced"):
  else:
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  q, r = lax_linalg.qr(a, full_matrices)
  q, r = lax_linalg.qr(a, full_matrices=full_matrices)
  if mode == "r":
    return r
  return q, r
@@ -72,7 +72,7 @@ def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
def _svd(a, *, full_matrices, compute_uv):
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  return lax_linalg.svd(a, full_matrices, compute_uv)
  return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)

@_wraps(scipy.linalg.svd,
        lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver'))
@@ -189,7 +189,7 @@ def _qr(a, mode, pivoting):
  else:
    raise ValueError("Unsupported QR decomposition mode '{}'".format(mode))
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  q, r = lax_linalg.qr(a, full_matrices)
  q, r = lax_linalg.qr(a, full_matrices=full_matrices)
  if mode == "r":
    return (r,)
  return q, r
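As the changelog notes, most users should prefer the {mod}`jax.numpy.linalg` and {mod}`jax.scipy.linalg` wrappers shown above rather than calling `jax.lax.linalg` directly; their public signatures are unchanged by this commit, and they now forward the non-array flags by keyword internally. A brief sketch of the unaffected user-facing API (not part of the commit):

import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

a = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])

# User-level calls are unaffected by the keyword-only change.
u, s, vt = jnp.linalg.svd(a, full_matrices=False)
q, r = jnp.linalg.qr(a, mode="reduced")
u2, s2, vt2 = jsp_linalg.svd(a, full_matrices=False)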
@@ -249,8 +249,9 @@ class NumpyLinalgTest(jtu.JaxTestCase):
      check_right_eigenvectors(aH, wC, vl)

    a, = args_maker()
    results = lax.linalg.eig(a, compute_left_eigenvectors,
                             compute_right_eigenvectors)
    results = lax.linalg.eig(
        a, compute_left_eigenvectors=compute_left_eigenvectors,
        compute_right_eigenvectors=compute_right_eigenvectors)
    w = results[0]

    if compute_left_eigenvectors: