12 Commits

Author SHA1 Message Date
Yash Katariya
6072d5993e Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices or user specific devices and jax.devices() too.

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Yash Katariya
a37e2159b3 Don't drop out of C++ fast path if mesh pointers are not equal.
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.

PiperOrigin-RevId: 560828993
2023-08-28 15:04:05 -07:00
Yash Katariya
242c2c1b52 Use _internal_device_list in __hash__ and __eq__ of Shardings and Mesh to speed them up.
PiperOrigin-RevId: 557665385
2023-08-16 18:41:58 -07:00
Hyeontaek Lim
423c8d8d4f [JAX] Use DeviceList in JAX Sharding implementations
XLA-compatible `Sharding` implementations keep a `DeviceList` object as
`_internal_device_list`. This is used for finding the default memory kind more
quickly in C++, and enables caching of the default memory kind between multiple
`NamedSharding` objects that shares the same `Mesh`. Also it uses an
addressable device within `DeviceList`, which will be required for supporting
multiple device types with different default memory kinds.

PiperOrigin-RevId: 556969789
2023-08-14 18:11:23 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Yash Katariya
fdbad53b15 Make _device_assignment a Tuple[Device] so that we don't convert a list to a tuple and vice-versa everywhere
PiperOrigin-RevId: 524002310
2023-04-13 08:03:27 -07:00
Yash Katariya
a3ce08cf1d Override addressable_devices for NamedSharding since the mesh can be the same throughout the program.
PiperOrigin-RevId: 522677209
2023-04-07 13:54:37 -07:00
Peter Hawkins
0f368e4428 Cache __repr__ and device_ids properties on Mesh.
PiperOrigin-RevId: 522653188
2023-04-07 12:12:14 -07:00
Yash Katariya
8838039287 Override is_fully_addressable() for NamedSharding.
The intent of this change is to speed up is_fully_addressable() when computing it repeatedly over the same mesh.

PiperOrigin-RevId: 522500766
2023-04-06 19:46:29 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00