mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #25007 from jakevdp:deps
PiperOrigin-RevId: 698413340
This commit is contained in:
commit
800add2a03
@ -65,6 +65,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
result in an indexing overflow for batch sizes close to int32 max. See
|
||||
{jax-issue}`#24843` for more details.
|
||||
|
||||
* Deprecations
|
||||
* `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated;
|
||||
use `jax.Array` instead.
|
||||
* `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError`
|
||||
instead.
|
||||
|
||||
## jax 0.4.35 (Oct 22, 2024)
|
||||
|
||||
* Breaking Changes
|
||||
|
@ -18,7 +18,6 @@ from jax._src.lib import xla_client as _xc
|
||||
get_topology_for_devices = _xc.get_topology_for_devices
|
||||
heap_profile = _xc.heap_profile
|
||||
mlir_api_version = _xc.mlir_api_version
|
||||
ArrayImpl = _xc.ArrayImpl
|
||||
Client = _xc.Client
|
||||
CompileOptions = _xc.CompileOptions
|
||||
DeviceAssignment = _xc.DeviceAssignment
|
||||
@ -95,6 +94,11 @@ _deprecations = {
|
||||
"XlaComputation is deprecated; use StableHLO instead.",
|
||||
_xc.XlaComputation,
|
||||
),
|
||||
# Added Nov 20 2024
|
||||
"ArrayImpl": (
|
||||
"jax.lib.xla_client.ArrayImpl is deprecated; use jax.Array instead.",
|
||||
_xc.ArrayImpl,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
@ -106,6 +110,7 @@ if _typing.TYPE_CHECKING:
|
||||
ops = _xc.ops
|
||||
register_custom_call_target = _xc.register_custom_call_target
|
||||
shape_from_pyval = _xc.shape_from_pyval
|
||||
ArrayImpl = _xc.ArrayImpl
|
||||
Device = _xc.Device
|
||||
FftType = _FftType
|
||||
PaddingType = _xc.PaddingType
|
||||
|
@ -24,7 +24,6 @@ mlir = _xe.mlir
|
||||
pmap_lib = _xe.pmap_lib
|
||||
profiler = _xe.profiler
|
||||
pytree = _xe.pytree
|
||||
ArrayImpl = _xe.ArrayImpl
|
||||
Device = _xe.Device
|
||||
DistributedRuntimeClient = _xe.DistributedRuntimeClient
|
||||
HloModule = _xe.HloModule
|
||||
@ -33,6 +32,28 @@ OpSharding = _xe.OpSharding
|
||||
PjitFunctionCache = _xe.PjitFunctionCache
|
||||
PjitFunction = _xe.PjitFunction
|
||||
PmapFunction = _xe.PmapFunction
|
||||
XlaRuntimeError = _xe.XlaRuntimeError
|
||||
|
||||
_deprecations = {
|
||||
# Added Nov 20 2024
|
||||
"ArrayImpl": (
|
||||
"jax.lib.xla_extension.ArrayImpl is deprecated; use jax.Array instead.",
|
||||
_xe.ArrayImpl,
|
||||
),
|
||||
"XlaRuntimeError": (
|
||||
"jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.",
|
||||
_xe.XlaRuntimeError,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
|
||||
if _typing.TYPE_CHECKING:
|
||||
ArrayImpl = _xe.ArrayImpl
|
||||
XlaRuntimeError = _xe.XlaRuntimeError
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
del _xe
|
||||
|
Loading…
x
Reference in New Issue
Block a user