mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Clean up a number of finalized deprecations
This commit is contained in:
parent
2b86f38585
commit
91a07ea2e8
@ -220,11 +220,6 @@ _deprecations = {
|
||||
"or jax.tree_util.tree_map (any JAX version).",
|
||||
_deprecated_tree_map
|
||||
),
|
||||
# Finalized Nov 12 2024; remove after Feb 12 2025
|
||||
"clear_backends": (
|
||||
"jax.clear_backends was removed in JAX v0.4.36",
|
||||
None
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
|
@ -1944,8 +1944,7 @@ def isrealobj(x: Any) -> bool:
|
||||
|
||||
@export
|
||||
def reshape(
|
||||
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
|
||||
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(),
|
||||
a: ArrayLike, shape: DimSize | Shape, order: str = "C", *,
|
||||
copy: bool | None = None) -> Array:
|
||||
"""Return a reshaped copy of an array.
|
||||
|
||||
@ -1962,8 +1961,6 @@ def reshape(
|
||||
JAX does not support ``order="A"``.
|
||||
copy: unused by JAX; JAX always returns a copy, though under JIT the compiler
|
||||
may optimize such copies away.
|
||||
newshape: deprecated alias of the ``shape`` argument. Will result in a
|
||||
:class:`DeprecationWarning` if used.
|
||||
|
||||
Returns:
|
||||
reshaped copy of input array with the specified shape.
|
||||
@ -2021,14 +2018,6 @@ def reshape(
|
||||
__tracebackhide__ = True
|
||||
util.check_arraylike("reshape", a)
|
||||
|
||||
# TODO(jakevdp): finalized 2024-12-2; remove argument after JAX v0.4.40.
|
||||
if not isinstance(newshape, DeprecatedArg):
|
||||
raise TypeError("The newshape argument to jnp.reshape was removed in JAX v0.4.36."
|
||||
" Use shape instead.")
|
||||
if shape is None:
|
||||
raise TypeError(
|
||||
"jnp.shape requires passing a `shape` argument, but none was given."
|
||||
)
|
||||
try:
|
||||
# forward to method for ndarrays
|
||||
return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr]
|
||||
|
23
jax/core.py
23
jax/core.py
@ -160,29 +160,6 @@ _deprecations = {
|
||||
_src_core.lattice_join),
|
||||
"raise_to_shaped": ("jax.core.raise_to_shaped is deprecated. It is a no-op as of JAX v0.4.36.",
|
||||
_src_core.raise_to_shaped),
|
||||
# Finalized 2024-12-11; remove after 2025-3-11
|
||||
"check_eqn": ("jax.core.check_eqn was removed in JAX v0.4.38.", None),
|
||||
"check_type": ("jax.core.check_type was removed in JAX v0.4.38.", None),
|
||||
"check_valid_jaxtype": (
|
||||
("jax.core.check_valid_jaxtype was removed in JAX v0.4.38. Instead, you can manually"
|
||||
" raise an error if core.valid_jaxtype() returns False."),
|
||||
None),
|
||||
"non_negative_dim": (
|
||||
"jax.core.non_negative_dim was removed in JAX v0.4.38. Use max_dim(..., 0).", None,
|
||||
),
|
||||
# Finalized 2024-09-25; remove after 2024-12-25
|
||||
"pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None),
|
||||
"pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None),
|
||||
"pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None),
|
||||
"pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None),
|
||||
"pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None),
|
||||
"pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None),
|
||||
"pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None),
|
||||
"pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None),
|
||||
"pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None),
|
||||
"pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None),
|
||||
"pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None),
|
||||
"pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None),
|
||||
}
|
||||
|
||||
import typing
|
||||
|
@ -38,19 +38,6 @@ _deprecations = {
|
||||
"jax.interpreters.xla.pytype_aval_mappings is deprecated.",
|
||||
_src_core.pytype_aval_mappings
|
||||
),
|
||||
# Finalized 2024-10-24; remove after 2025-01-24
|
||||
"xb": (
|
||||
("jax.interpreters.xla.xb was removed in JAX v0.4.36. "
|
||||
"Use jax.lib.xla_bridge instead."), None
|
||||
),
|
||||
"xc": (
|
||||
("jax.interpreters.xla.xc was removed in JAX v0.4.36. "
|
||||
"Use jax.lib.xla_client instead."), None
|
||||
),
|
||||
"xe": (
|
||||
("jax.interpreters.xla.xe was removed in JAX v0.4.36. "
|
||||
"Use jax.lib.xla_extension instead."), None
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
|
@ -27,15 +27,6 @@ _deprecations = {
|
||||
"jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.",
|
||||
_deprecated_get_backend
|
||||
),
|
||||
# Finalized 2024-12-11; remove after 2025-3-11
|
||||
"xla_client": (
|
||||
"jax.lib.xla_bridge.xla_client was removed in JAX v0.4.38; use jax.lib.xla_client directly.",
|
||||
None
|
||||
),
|
||||
"default_backend": (
|
||||
"jax.lib.xla_bridge.default_backend was removed in JAX v0.4.38; use jax.default_backend.",
|
||||
None
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
|
@ -26,27 +26,6 @@ OpSharding = _xc.OpSharding
|
||||
Traceback = _xc.Traceback
|
||||
|
||||
_deprecations = {
|
||||
# Finalized 2024-12-11; remove after 2025-3-11
|
||||
"_xla": (
|
||||
"jax.lib.xla_client._xla was removed in JAX v0.4.38; use jax.lib.xla_extension.",
|
||||
None,
|
||||
),
|
||||
"bfloat16": (
|
||||
"jax.lib.xla_client.bfloat16 was removed in JAX v0.4.38; use ml_dtypes.bfloat16.",
|
||||
None,
|
||||
),
|
||||
# Finalized 2024-12-23; remove after 2024-03-23
|
||||
"Device": (
|
||||
"jax.lib.xla_client.Device is deprecated; use jax.Device instead.",
|
||||
None,
|
||||
),
|
||||
"XlaRuntimeError": (
|
||||
(
|
||||
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
|
||||
" jax.errors.JaxRuntimeError."
|
||||
),
|
||||
None,
|
||||
),
|
||||
# Finalized 2025-03-25; remove after 2025-06-25
|
||||
"FftType": (
|
||||
"jax.lib.xla_client.FftType was removed in JAX v0.6.0; use jax.lax.FftType.",
|
||||
@ -106,12 +85,10 @@ if _typing.TYPE_CHECKING:
|
||||
ops = _xc.ops
|
||||
register_custom_call_target = _xc.register_custom_call_target
|
||||
ArrayImpl = _xc.ArrayImpl
|
||||
Device = _xc.Device
|
||||
PrimitiveType = _xc.PrimitiveType
|
||||
Shape = _xc.Shape
|
||||
XlaBuilder = _xc.XlaBuilder
|
||||
XlaComputation = _xc.XlaComputation
|
||||
XlaRuntimeError = _xc.XlaRuntimeError
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
|
||||
|
@ -506,19 +506,3 @@ from jax._src.numpy.vectorize import vectorize as vectorize
|
||||
from jax._src.numpy.array_methods import register_jax_array_methods
|
||||
register_jax_array_methods()
|
||||
del register_jax_array_methods
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Finalized 2024-12-13; remove after 2024-3-13
|
||||
"round_": (
|
||||
"jnp.round_ was deprecated in JAX 0.4.38; use jnp.round instead.",
|
||||
None
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if not typing.TYPE_CHECKING:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
|
@ -808,8 +808,7 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
|
||||
def repeat(a: ArrayLike, repeats: ArrayLike, axis: int | None = ..., *,
|
||||
total_repeat_length: int | None = ...) -> Array: ...
|
||||
def reshape(
|
||||
a: ArrayLike, shape: DimSize | Shape = ...,
|
||||
newshape: DimSize | Shape | None = ..., order: str = ...
|
||||
a: ArrayLike, shape: DimSize | Shape, order: str = ..., *, copy: bool | None = ...,
|
||||
) -> Array: ...
|
||||
|
||||
def resize(a: ArrayLike, new_shape: Shape) -> Array: ...
|
||||
|
@ -34,18 +34,3 @@ from jax._src.mesh import (
|
||||
AxisType as AxisType,
|
||||
get_abstract_mesh as get_abstract_mesh,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Finalized 2024-10-01; remove after 2025-01-01.
|
||||
"XLACompatibleSharding": (
|
||||
(
|
||||
"jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. "
|
||||
"Use jax.sharding.Sharding instead."
|
||||
),
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
|
@ -3496,11 +3496,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
def testReshapeDeprecatedArgs(self):
|
||||
msg = "The newshape argument to jnp.reshape was removed in JAX v0.4.36."
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
jnp.reshape(jnp.arange(4), newshape=(2, 2))
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(arg_shape=arg_shape, out_shape=out_shape)
|
||||
for arg_shape, out_shape in [
|
||||
|
Loading…
x
Reference in New Issue
Block a user