diff --git a/CHANGELOG.md b/CHANGELOG.md index 810e23b10..3fb64ef09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * JAX package extras are now updated to use dash instead of underscore to align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]` to install JAX, run `pip install jax[cuda12-local]` instead. + * {func}`jax.jit` now requires `fun` to be passed by position, and additional + arguments to be passed by keyword. Doing otherwise will result in a + DeprecationWarning in v0.6.X, and an error in starting in v0.7.X. * Deprecations diff --git a/jax/_src/api.py b/jax/_src/api.py index 23c0a610c..43ab7729a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -37,6 +37,7 @@ import weakref import numpy as np from contextlib import contextmanager +from jax._src import deprecations from jax._src import linear_util as lu from jax._src import stages from jax._src.tree_util import ( @@ -147,8 +148,39 @@ config.debug_infs._add_hooks(_update_debug_special_global, float0 = dtypes.float0 +# TODO(jakevdp): remove this for v0.7.0 (~July 2025) +def _allow_deprecated_jit_signature(f: F) -> F: + """Temporary decorator for the jit signature deprecation.""" + @wraps(f) + def wrapped(*args, **kwargs): + if len(args) == 1 or deprecations.is_accelerated('jax-jit-positional-args'): + # Fast path for typical usage. + return f(*args, **kwargs) + if 'fun' in kwargs: + deprecations.warn( + 'jax-jit-positional-args', + ('jax.jit: passing fun by keyword is deprecated.' + ' Pass it by position to silence this warning.'), + stacklevel=2 + ) + return f(kwargs.pop('fun'), **kwargs) + if len(args) > 1: + deprecations.warn( + 'jax-jit-positional-args', + ('jax.jit: passing optional arguments by position is deprecated. ' + ' Pass them by keyword to silence this warning.'), + stacklevel=2 + ) + sig = inspect.signature(f) + kwds = dict(unsafe_zip((p.name for p in sig.parameters.values()), args)) + return f(kwds.pop('fun'), **kwds, **kwargs) + return f(*args, **kwargs) + return cast(F, wrapped) + + +@_allow_deprecated_jit_signature def jit( - fun: Callable, + fun: Callable, /, *, in_shardings: Any = sharding_impls.UNSPECIFIED, out_shardings: Any = sharding_impls.UNSPECIFIED, static_argnums: int | Sequence[int] | None = None, diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index 6c39c893a..329491b1e 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -134,3 +134,4 @@ register('jax-numpy-quantile-interpolation') register('jax-numpy-reduction-non-boolean-where') register('jax-numpy-trimzeros-not-1d-array') register('jax-scipy-special-sph-harm') +register('jax-jit-positional-args') diff --git a/tests/api_test.py b/tests/api_test.py index 59a6211e3..7590b6e6a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -52,6 +52,7 @@ from jax._src import array from jax._src import config from jax._src import core from jax._src import custom_derivatives +from jax._src import deprecations from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge @@ -10516,6 +10517,19 @@ class CustomTransposeTest(jtu.JaxTestCase): self.assertAllClose(f_(x), g_(x)) self.assertAllClose(f_t(x), g_t(x)) + def test_jit_signature_deprecation(self): + fun = lambda x: x + if deprecations.is_accelerated('jax-jit-positional-args'): + with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'): + jax.jit(fun=fun) + with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'): + jax.jit(fun, None) + else: + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'): + jax.jit(fun=fun) + with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'): + jax.jit(fun, None) + def test_cond(self): def f(x, y): @custom_transpose(jnp.ones(2))