Merge pull request #25007 from jakevdp:deps

PiperOrigin-RevId: 698413340
This commit is contained in:
jax authors 2024-11-20 09:13:05 -08:00
commit 800add2a03
3 changed files with 35 additions and 3 deletions

View File

@ -65,6 +65,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
result in an indexing overflow for batch sizes close to int32 max. See
{jax-issue}`#24843` for more details.
* Deprecations
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
use `jax.Array` instead.
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
instead.
## jax 0.4.35 (Oct 22, 2024)
* Breaking Changes

View File

@ -18,7 +18,6 @@ from jax._src.lib import xla_client as _xc
get_topology_for_devices = _xc.get_topology_for_devices
heap_profile = _xc.heap_profile
mlir_api_version = _xc.mlir_api_version
ArrayImpl = _xc.ArrayImpl
Client = _xc.Client
CompileOptions = _xc.CompileOptions
DeviceAssignment = _xc.DeviceAssignment
@ -95,6 +94,11 @@ _deprecations = {
"XlaComputation is deprecated; use StableHLO instead.",
_xc.XlaComputation,
),
# Added Nov 20 2024
"ArrayImpl": (
"jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.",
_xc.ArrayImpl,
),
}
import typing as _typing
@ -106,6 +110,7 @@ if _typing.TYPE_CHECKING:
ops = _xc.ops
register_custom_call_target = _xc.register_custom_call_target
shape_from_pyval = _xc.shape_from_pyval
ArrayImpl = _xc.ArrayImpl
Device = _xc.Device
FftType = _FftType
PaddingType = _xc.PaddingType

View File

@ -24,7 +24,6 @@ mlir = _xe.mlir
pmap_lib = _xe.pmap_lib
profiler = _xe.profiler
pytree = _xe.pytree
ArrayImpl = _xe.ArrayImpl
Device = _xe.Device
DistributedRuntimeClient = _xe.DistributedRuntimeClient
HloModule = _xe.HloModule
@ -33,6 +32,28 @@ OpSharding = _xe.OpSharding
PjitFunctionCache = _xe.PjitFunctionCache
PjitFunction = _xe.PjitFunction
PmapFunction = _xe.PmapFunction
XlaRuntimeError = _xe.XlaRuntimeError
_deprecations = {
# Added Nov 20 2024
"ArrayImpl": (
"jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.",
_xe.ArrayImpl,
),
"XlaRuntimeError": (
"jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.",
_xe.XlaRuntimeError,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
ArrayImpl = _xe.ArrayImpl
XlaRuntimeError = _xe.XlaRuntimeError
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
del _xe