6 Commits

Author SHA1 Message Date
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Yash Katariya
fdbad53b15 Make _device_assignment a Tuple[Device] so that we don't convert a list to a tuple and vice-versa everywhere
PiperOrigin-RevId: 524002310
2023-04-13 08:03:27 -07:00
Yash Katariya
a3ce08cf1d Override addressable_devices for NamedSharding since the mesh can be the same throughout the program.
PiperOrigin-RevId: 522677209
2023-04-07 13:54:37 -07:00
Peter Hawkins
0f368e4428 Cache __repr__ and device_ids properties on Mesh.
PiperOrigin-RevId: 522653188
2023-04-07 12:12:14 -07:00
Yash Katariya
8838039287 Override is_fully_addressable() for NamedSharding.
The intent of this change is to speed up is_fully_addressable() when computing it repeatedly over the same mesh.

PiperOrigin-RevId: 522500766
2023-04-06 19:46:29 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00