jax.jit: deprecate non-standard call signature.

This commit is contained in:
Jake VanderPlas 2025-04-11 11:58:19 -07:00
parent 5af5925749
commit ceca6ec1fc
4 changed files with 51 additions and 1 deletions

View File

@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* JAX package extras are now updated to use dash instead of underscore to
align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]`
to install JAX, run `pip install jax[cuda12-local]` instead.
* {func}`jax.jit` now requires `fun` to be passed by position, and additional
arguments to be passed by keyword. Doing otherwise will result in a
DeprecationWarning in v0.6.X, and an error in starting in v0.7.X.
* Deprecations

View File

@ -37,6 +37,7 @@ import weakref
import numpy as np
from contextlib import contextmanager
from jax._src import deprecations
from jax._src import linear_util as lu
from jax._src import stages
from jax._src.tree_util import (
@ -147,8 +148,39 @@ config.debug_infs._add_hooks(_update_debug_special_global,
float0 = dtypes.float0
# TODO(jakevdp): remove this for v0.7.0 (~July 2025)
def _allow_deprecated_jit_signature(f: F) -> F:
"""Temporary decorator for the jit signature deprecation."""
@wraps(f)
def wrapped(*args, **kwargs):
if len(args) == 1 or deprecations.is_accelerated('jax-jit-positional-args'):
# Fast path for typical usage.
return f(*args, **kwargs)
if 'fun' in kwargs:
deprecations.warn(
'jax-jit-positional-args',
('jax.jit: passing fun by keyword is deprecated.'
' Pass it by position to silence this warning.'),
stacklevel=2
)
return f(kwargs.pop('fun'), **kwargs)
if len(args) > 1:
deprecations.warn(
'jax-jit-positional-args',
('jax.jit: passing optional arguments by position is deprecated. '
' Pass them by keyword to silence this warning.'),
stacklevel=2
)
sig = inspect.signature(f)
kwds = dict(unsafe_zip((p.name for p in sig.parameters.values()), args))
return f(kwds.pop('fun'), **kwds, **kwargs)
return f(*args, **kwargs)
return cast(F, wrapped)
@_allow_deprecated_jit_signature
def jit(
fun: Callable,
fun: Callable, /, *,
in_shardings: Any = sharding_impls.UNSPECIFIED,
out_shardings: Any = sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,

View File

@ -134,3 +134,4 @@ register('jax-numpy-quantile-interpolation')
register('jax-numpy-reduction-non-boolean-where')
register('jax-numpy-trimzeros-not-1d-array')
register('jax-scipy-special-sph-harm')
register('jax-jit-positional-args')

View File

@ -52,6 +52,7 @@ from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
from jax._src import deprecations
from jax._src import linear_util as lu
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -10516,6 +10517,19 @@ class CustomTransposeTest(jtu.JaxTestCase):
self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))
def test_jit_signature_deprecation(self):
fun = lambda x: x
if deprecations.is_accelerated('jax-jit-positional-args'):
with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'):
jax.jit(fun=fun)
with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'):
jax.jit(fun, None)
else:
with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'):
jax.jit(fun=fun)
with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'):
jax.jit(fun, None)
def test_cond(self):
def f(x, y):
@custom_transpose(jnp.ones(2))