mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 13:26:06 +00:00
jax.lax: deprecate inadvertent exports & internal utilities
This commit is contained in:
parent
4c37c79270
commit
ce6a0c43ad
@ -15,6 +15,15 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
If using the recommended `cuda12_pip` installation, NCCL should be installed
|
||||
automatically.
|
||||
|
||||
* Deprecations
|
||||
* A number of internal utilities and inadvertent exports in {mod}`jax.lax` have
|
||||
been deprecated, and will be removed in a future release.
|
||||
* `jax.lax.dtypes`: use `jax.dtypes` instead.
|
||||
* `jax.lax.itertools`: use `itertools` instead.
|
||||
* `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`,
|
||||
`standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are
|
||||
internal utilities, now deprecated without replacement.
|
||||
|
||||
# jax 0.4.17 (Oct 3, 2023)
|
||||
|
||||
* New features
|
||||
|
@ -88,7 +88,7 @@ from jax._src.lax.lax import (
|
||||
dot_general as dot_general,
|
||||
dot_general_p as dot_general_p,
|
||||
dtype as dtype,
|
||||
dtypes as dtypes,
|
||||
dtypes as _deprecated_dtypes,
|
||||
eq as eq,
|
||||
eq_p as eq_p,
|
||||
exp as exp,
|
||||
@ -116,7 +116,7 @@ from jax._src.lax.lax import (
|
||||
iota_p as iota_p,
|
||||
is_finite as is_finite,
|
||||
is_finite_p as is_finite_p,
|
||||
itertools as itertools,
|
||||
itertools as _deprecated_itertools,
|
||||
le as le,
|
||||
le_p as le_p,
|
||||
log as log,
|
||||
@ -133,8 +133,8 @@ from jax._src.lax.lax import (
|
||||
min_p as min_p,
|
||||
mul as mul,
|
||||
mul_p as mul_p,
|
||||
naryop as naryop,
|
||||
naryop_dtype_rule as naryop_dtype_rule,
|
||||
naryop as _deprecated_naryop,
|
||||
naryop_dtype_rule as _deprecated_naryop_dtype_rule,
|
||||
ne as ne,
|
||||
ne_p as ne_p,
|
||||
neg as neg,
|
||||
@ -203,10 +203,10 @@ from jax._src.lax.lax import (
|
||||
square as square,
|
||||
squeeze as squeeze,
|
||||
squeeze_p as squeeze_p,
|
||||
standard_abstract_eval as standard_abstract_eval,
|
||||
standard_naryop as standard_naryop,
|
||||
standard_primitive as standard_primitive,
|
||||
standard_unop as standard_unop,
|
||||
standard_abstract_eval as _deprecated_standard_abstract_eval,
|
||||
standard_naryop as _deprecated_standard_naryop,
|
||||
standard_primitive as _deprecated_standard_primitive,
|
||||
standard_unop as _deprecated_standard_unop,
|
||||
stop_gradient as stop_gradient,
|
||||
sub as sub,
|
||||
sub_p as sub_p,
|
||||
@ -219,8 +219,8 @@ from jax._src.lax.lax import (
|
||||
top_k_p as top_k_p,
|
||||
transpose as transpose,
|
||||
transpose_p as transpose_p,
|
||||
unop as unop,
|
||||
unop_dtype_rule as unop_dtype_rule,
|
||||
unop as _deprecated_unop,
|
||||
unop_dtype_rule as _deprecated_unop_dtype_rule,
|
||||
xor_p as xor_p,
|
||||
zeros_like_array as zeros_like_array,
|
||||
)
|
||||
@ -377,3 +377,66 @@ from jax.lax import linalg as linalg
|
||||
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
|
||||
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
|
||||
from jax._src.dispatch import device_put_p as device_put_p
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Added October 6 2023
|
||||
"dtypes": (
|
||||
"jax.lax.dtypes is deprecated: import jax.dtypes directly.",
|
||||
_deprecated_dtypes,
|
||||
),
|
||||
"itertools": (
|
||||
"jax.lax.itertools is deprecated: import itertools directly.",
|
||||
_deprecated_itertools,
|
||||
),
|
||||
"naryop": (
|
||||
"jax.lax.naryop is an internal API and has been deprecated.",
|
||||
_deprecated_naryop,
|
||||
),
|
||||
"naryop_dtype_rule": (
|
||||
"jax.lax.naryop_dtype_rule is an internal API and has been deprecated.",
|
||||
_deprecated_naryop_dtype_rule,
|
||||
),
|
||||
"standard_abstract_eval": (
|
||||
"jax.lax.standard_abstract_eval is an internal API and has been deprecated.",
|
||||
_deprecated_standard_abstract_eval,
|
||||
),
|
||||
"standard_naryop": (
|
||||
"jax.lax.standard_naryop is an internal API and has been deprecated.",
|
||||
_deprecated_standard_naryop,
|
||||
),
|
||||
"standard_primitive": (
|
||||
"jax.lax.standard_primitive is an internal API and has been deprecated.",
|
||||
_deprecated_standard_primitive,
|
||||
),
|
||||
"standard_unop": (
|
||||
"jax.lax.standard_unop is an internal API and has been deprecated.",
|
||||
_deprecated_standard_unop,
|
||||
),
|
||||
"unop": (
|
||||
"jax.lax.unop is an internal API and has been deprecated.",
|
||||
_deprecated_unop,
|
||||
),
|
||||
"unop_dtype_rule": (
|
||||
"jax.lax.unop_dtype_rule is an internal API and has been deprecated.",
|
||||
_deprecated_unop_dtype_rule,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
dtypes = _deprecated_dtypes,
|
||||
itertools = _deprecated_itertools,
|
||||
naryop = _deprecated_naryop,
|
||||
naryop_dtype_rule = _deprecated_naryop_dtype_rule,
|
||||
standard_abstract_eval = _deprecated_standard_abstract_eval,
|
||||
standard_naryop = _deprecated_standard_naryop,
|
||||
standard_primitive = _deprecated_standard_primitive,
|
||||
standard_unop = _deprecated_standard_unop,
|
||||
unop = _deprecated_unop,
|
||||
unop_dtype_rule = _deprecated_unop_dtype_rule,
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
|
Loading…
x
Reference in New Issue
Block a user