Clean up a number of finalized deprecations

This commit is contained in:
Jake VanderPlas 2025-03-25 08:40:43 -07:00
parent 2b86f38585
commit 91a07ea2e8
10 changed files with 2 additions and 123 deletions

View File

@ -220,11 +220,6 @@ _deprecations = {
"or jax.tree_util.tree_map (any JAX version).",
_deprecated_tree_map
),
# Finalized Nov 12 2024; remove after Feb 12 2025
"clear_backends": (
"jax.clear_backends was removed in JAX v0.4.36",
None
),
}
import typing as _typing

View File

@ -1944,8 +1944,7 @@ def isrealobj(x: Any) -> bool:
@export
def reshape(
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(),
a: ArrayLike, shape: DimSize | Shape, order: str = "C", *,
copy: bool | None = None) -> Array:
"""Return a reshaped copy of an array.
@ -1962,8 +1961,6 @@ def reshape(
JAX does not support ``order="A"``.
copy: unused by JAX; JAX always returns a copy, though under JIT the compiler
may optimize such copies away.
newshape: deprecated alias of the ``shape`` argument. Will result in a
:class:`DeprecationWarning` if used.
Returns:
reshaped copy of input array with the specified shape.
@ -2021,14 +2018,6 @@ def reshape(
__tracebackhide__ = True
util.check_arraylike("reshape", a)
# TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40.
if not isinstance(newshape, DeprecatedArg):
raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36."
" Use shape instead.")
if shape is None:
raise TypeError(
"jnp.shape requires passing a `shape` argument, but none was given."
)
try:
# forward to method for ndarrays
return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr]

View File

@ -160,29 +160,6 @@ _deprecations = {
_src_core.lattice_join),
"raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.",
_src_core.raise_to_shaped),
# Finalized 2024-12-11; remove after 2025-3-11
"check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None),
"check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None),
"check_valid_jaxtype": (
("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually"
" raise an error if core.valid_jaxtype() returns False."),
None),
"non_negative_dim": (
"jax.core.non_negative_dim was removed in JAX v0.4.38. Use max_dim(..., 0).", None,
),
# Finalized 2024-09-25; remove after 2024-12-25
"pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None),
"pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None),
"pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None),
"pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None),
"pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None),
"pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None),
"pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None),
"pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None),
"pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None),
"pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None),
"pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None),
"pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None),
}
import typing

View File

@ -38,19 +38,6 @@ _deprecations = {
"jax.interpreters.xla.pytype_aval_mappings is deprecated.",
_src_core.pytype_aval_mappings
),
# Finalized 2024-10-24; remove after 2025-01-24
"xb": (
("jax.interpreters.xla.xb was removed in JAX v0.4.36. "
"Use jax.lib.xla_bridge instead."), None
),
"xc": (
("jax.interpreters.xla.xc was removed in JAX v0.4.36. "
"Use jax.lib.xla_client instead."), None
),
"xe": (
("jax.interpreters.xla.xe was removed in JAX v0.4.36. "
"Use jax.lib.xla_extension instead."), None
),
}
import typing as _typing

View File

@ -27,15 +27,6 @@ _deprecations = {
"jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.",
_deprecated_get_backend
),
# Finalized 2024-12-11; remove after 2025-3-11
"xla_client": (
"jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.",
None
),
"default_backend": (
"jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.",
None
),
}
import typing as _typing

View File

@ -26,27 +26,6 @@ OpSharding = _xc.OpSharding
Traceback = _xc.Traceback
_deprecations = {
# Finalized 2024-12-11; remove after 2025-3-11
"_xla": (
"jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.",
None,
),
"bfloat16": (
"jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
None,
),
# Finalized 2024-12-23; remove after 2024-03-23
"Device": (
"jax.lib.xla_client.Device is deprecated; use jax.Device instead.",
None,
),
"XlaRuntimeError": (
(
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
" jax.errors.JaxRuntimeError."
),
None,
),
# Finalized 2025-03-25; remove after 2025-06-25
"FftType": (
"jax.lib.xla_client.FftType was removed in JAX v0.6.0; use jax.lax.FftType.",
@ -106,12 +85,10 @@ if _typing.TYPE_CHECKING:
ops = _xc.ops
register_custom_call_target = _xc.register_custom_call_target
ArrayImpl = _xc.ArrayImpl
Device = _xc.Device
PrimitiveType = _xc.PrimitiveType
Shape = _xc.Shape
XlaBuilder = _xc.XlaBuilder
XlaComputation = _xc.XlaComputation
XlaRuntimeError = _xc.XlaRuntimeError
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr

View File

@ -506,19 +506,3 @@ from jax._src.numpy.vectorize import vectorize as vectorize
from jax._src.numpy.array_methods import register_jax_array_methods
register_jax_array_methods()
del register_jax_array_methods
_deprecations = {
# Finalized 2024-12-13; remove after 2024-3-13
"round_": (
"jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.",
None
),
}
import typing
if not typing.TYPE_CHECKING:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing

View File

@ -808,8 +808,7 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *,
total_repeat_length: int | None = ...) -> Array: ...
def reshape(
a: ArrayLike, shape: DimSize | Shape = ...,
newshape: DimSize | Shape | None = ..., order: str = ...
a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ...,
) -> Array: ...
def resize(a: ArrayLike, new_shape: Shape) -> Array: ...

View File

@ -34,18 +34,3 @@ from jax._src.mesh import (
AxisType as AxisType,
get_abstract_mesh as get_abstract_mesh,
)
_deprecations = {
# Finalized 2024-10-01; remove after 2025-01-01.
"XLACompatibleSharding": (
(
"jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. "
"Use jax.sharding.Sharding instead."
),
None,
)
}
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

View File

@ -3496,11 +3496,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
def testReshapeDeprecatedArgs(self):
msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36."
with self.assertRaisesRegex(TypeError, msg):
jnp.reshape(jnp.arange(4), newshape=(2, 2))
@jtu.sample_product(
[dict(arg_shape=arg_shape, out_shape=out_shape)
for arg_shape, out_shape in [