1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

jax.lax: deprecate inadvertent exports & internal utilities

This commit is contained in:
Jake VanderPlas 2023-10-06 11:26:03 -07:00
parent 4c37c79270
commit ce6a0c43ad
2 changed files with 82 additions and 10 deletions

@ -15,6 +15,15 @@ Remember to align the itemized text with the first line of an item within a list
If using the recommended `cuda12_pip` installation, NCCL should be installed
automatically.
* Deprecations
* A number of internal utilities and inadvertent exports in {mod}`jax.lax` have
been deprecated, and will be removed in a future release.
* `jax.lax.dtypes`: use `jax.dtypes` instead.
* `jax.lax.itertools`: use `itertools` instead.
* `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`,
`standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are
internal utilities, now deprecated without replacement.
# jax 0.4.17 (Oct 3, 2023)
* New features

@ -88,7 +88,7 @@ from jax._src.lax.lax import (
dot_general as dot_general,
dot_general_p as dot_general_p,
dtype as dtype,
dtypes as dtypes,
dtypes as _deprecated_dtypes,
eq as eq,
eq_p as eq_p,
exp as exp,
@ -116,7 +116,7 @@ from jax._src.lax.lax import (
iota_p as iota_p,
is_finite as is_finite,
is_finite_p as is_finite_p,
itertools as itertools,
itertools as _deprecated_itertools,
le as le,
le_p as le_p,
log as log,
@ -133,8 +133,8 @@ from jax._src.lax.lax import (
min_p as min_p,
mul as mul,
mul_p as mul_p,
naryop as naryop,
naryop_dtype_rule as naryop_dtype_rule,
naryop as _deprecated_naryop,
naryop_dtype_rule as _deprecated_naryop_dtype_rule,
ne as ne,
ne_p as ne_p,
neg as neg,
@ -203,10 +203,10 @@ from jax._src.lax.lax import (
square as square,
squeeze as squeeze,
squeeze_p as squeeze_p,
standard_abstract_eval as standard_abstract_eval,
standard_naryop as standard_naryop,
standard_primitive as standard_primitive,
standard_unop as standard_unop,
standard_abstract_eval as _deprecated_standard_abstract_eval,
standard_naryop as _deprecated_standard_naryop,
standard_primitive as _deprecated_standard_primitive,
standard_unop as _deprecated_standard_unop,
stop_gradient as stop_gradient,
sub as sub,
sub_p as sub_p,
@ -219,8 +219,8 @@ from jax._src.lax.lax import (
top_k_p as top_k_p,
transpose as transpose,
transpose_p as transpose_p,
unop as unop,
unop_dtype_rule as unop_dtype_rule,
unop as _deprecated_unop,
unop_dtype_rule as _deprecated_unop_dtype_rule,
xor_p as xor_p,
zeros_like_array as zeros_like_array,
)
@ -377,3 +377,66 @@ from jax.lax import linalg as linalg
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
from jax._src.dispatch import device_put_p as device_put_p
_deprecations = {
# Added October 6 2023
"dtypes": (
"jax.lax.dtypes is deprecated: import jax.dtypes directly.",
_deprecated_dtypes,
),
"itertools": (
"jax.lax.itertools is deprecated: import itertools directly.",
_deprecated_itertools,
),
"naryop": (
"jax.lax.naryop is an internal API and has been deprecated.",
_deprecated_naryop,
),
"naryop_dtype_rule": (
"jax.lax.naryop_dtype_rule is an internal API and has been deprecated.",
_deprecated_naryop_dtype_rule,
),
"standard_abstract_eval": (
"jax.lax.standard_abstract_eval is an internal API and has been deprecated.",
_deprecated_standard_abstract_eval,
),
"standard_naryop": (
"jax.lax.standard_naryop is an internal API and has been deprecated.",
_deprecated_standard_naryop,
),
"standard_primitive": (
"jax.lax.standard_primitive is an internal API and has been deprecated.",
_deprecated_standard_primitive,
),
"standard_unop": (
"jax.lax.standard_unop is an internal API and has been deprecated.",
_deprecated_standard_unop,
),
"unop": (
"jax.lax.unop is an internal API and has been deprecated.",
_deprecated_unop,
),
"unop_dtype_rule": (
"jax.lax.unop_dtype_rule is an internal API and has been deprecated.",
_deprecated_unop_dtype_rule,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
dtypes = _deprecated_dtypes,
itertools = _deprecated_itertools,
naryop = _deprecated_naryop,
naryop_dtype_rule = _deprecated_naryop_dtype_rule,
standard_abstract_eval = _deprecated_standard_abstract_eval,
standard_naryop = _deprecated_standard_naryop,
standard_primitive = _deprecated_standard_primitive,
standard_unop = _deprecated_standard_unop,
unop = _deprecated_unop,
unop_dtype_rule = _deprecated_unop_dtype_rule,
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing