mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
jax.jit: deprecate non-standard call signature.
This commit is contained in:
parent
5af5925749
commit
ceca6ec1fc
@ -40,6 +40,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* JAX package extras are now updated to use dash instead of underscore to
|
||||
align with PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]`
|
||||
to install JAX, run `pip install jax[cuda12-local]` instead.
|
||||
* {func}`jax.jit` now requires `fun` to be passed by position, and additional
|
||||
arguments to be passed by keyword. Doing otherwise will result in a
|
||||
DeprecationWarning in v0.6.X, and an error in starting in v0.7.X.
|
||||
|
||||
* Deprecations
|
||||
|
||||
|
@ -37,6 +37,7 @@ import weakref
|
||||
import numpy as np
|
||||
from contextlib import contextmanager
|
||||
|
||||
from jax._src import deprecations
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import stages
|
||||
from jax._src.tree_util import (
|
||||
@ -147,8 +148,39 @@ config.debug_infs._add_hooks(_update_debug_special_global,
|
||||
float0 = dtypes.float0
|
||||
|
||||
|
||||
# TODO(jakevdp): remove this for v0.7.0 (~July 2025)
|
||||
def _allow_deprecated_jit_signature(f: F) -> F:
|
||||
"""Temporary decorator for the jit signature deprecation."""
|
||||
@wraps(f)
|
||||
def wrapped(*args, **kwargs):
|
||||
if len(args) == 1 or deprecations.is_accelerated('jax-jit-positional-args'):
|
||||
# Fast path for typical usage.
|
||||
return f(*args, **kwargs)
|
||||
if 'fun' in kwargs:
|
||||
deprecations.warn(
|
||||
'jax-jit-positional-args',
|
||||
('jax.jit: passing fun by keyword is deprecated.'
|
||||
' Pass it by position to silence this warning.'),
|
||||
stacklevel=2
|
||||
)
|
||||
return f(kwargs.pop('fun'), **kwargs)
|
||||
if len(args) > 1:
|
||||
deprecations.warn(
|
||||
'jax-jit-positional-args',
|
||||
('jax.jit: passing optional arguments by position is deprecated. '
|
||||
' Pass them by keyword to silence this warning.'),
|
||||
stacklevel=2
|
||||
)
|
||||
sig = inspect.signature(f)
|
||||
kwds = dict(unsafe_zip((p.name for p in sig.parameters.values()), args))
|
||||
return f(kwds.pop('fun'), **kwds, **kwargs)
|
||||
return f(*args, **kwargs)
|
||||
return cast(F, wrapped)
|
||||
|
||||
|
||||
@_allow_deprecated_jit_signature
|
||||
def jit(
|
||||
fun: Callable,
|
||||
fun: Callable, /, *,
|
||||
in_shardings: Any = sharding_impls.UNSPECIFIED,
|
||||
out_shardings: Any = sharding_impls.UNSPECIFIED,
|
||||
static_argnums: int | Sequence[int] | None = None,
|
||||
|
@ -134,3 +134,4 @@ register('jax-numpy-quantile-interpolation')
|
||||
register('jax-numpy-reduction-non-boolean-where')
|
||||
register('jax-numpy-trimzeros-not-1d-array')
|
||||
register('jax-scipy-special-sph-harm')
|
||||
register('jax-jit-positional-args')
|
||||
|
@ -52,6 +52,7 @@ from jax._src import array
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import deprecations
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
@ -10516,6 +10517,19 @@ class CustomTransposeTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f_(x), g_(x))
|
||||
self.assertAllClose(f_t(x), g_t(x))
|
||||
|
||||
def test_jit_signature_deprecation(self):
|
||||
fun = lambda x: x
|
||||
if deprecations.is_accelerated('jax-jit-positional-args'):
|
||||
with self.assertRaisesRegex(TypeError, r'jit\(\) got some positional-only arguments passed as keyword arguments.*'):
|
||||
jax.jit(fun=fun)
|
||||
with self.assertRaisesRegex(TypeError, r'jit\(\) takes 1 positional argument but 2 were given.*'):
|
||||
jax.jit(fun, None)
|
||||
else:
|
||||
with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing fun by keyword is deprecated.*'):
|
||||
jax.jit(fun=fun)
|
||||
with self.assertWarnsRegex(DeprecationWarning, r'jax\.jit: passing optional arguments by position is deprecated.*'):
|
||||
jax.jit(fun, None)
|
||||
|
||||
def test_cond(self):
|
||||
def f(x, y):
|
||||
@custom_transpose(jnp.ones(2))
|
||||
|
Loading…
x
Reference in New Issue
Block a user