Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
This work is an effort to reduce cyclic dependencies in JAX internals.
Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.
PiperOrigin-RevId: 515667671