This is true for mock devices or user specific devices and jax.devices() too.
Fix the tests so that the mock devices are hashable.
PiperOrigin-RevId: 561103167
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.
PiperOrigin-RevId: 560828993
XLA-compatible `Sharding` implementations keep a `DeviceList` object as
`_internal_device_list`. This is used for finding the default memory kind more
quickly in C++, and enables caching of the default memory kind between multiple
`NamedSharding` objects that shares the same `Mesh`. Also it uses an
addressable device within `DeviceList`, which will be required for supporting
multiple device types with different default memory kinds.
PiperOrigin-RevId: 556969789
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
This work is an effort to reduce cyclic dependencies in JAX internals.
Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.
PiperOrigin-RevId: 515667671