rocm_jax/jax/interpreters
Yash Katariya ca1f58e37b Add a new `jax.spmd_mode` config to prevent unintentional hangs and incorrect results when users pass jax.Arrays that span multiple processes (i.e. are not fully addressable) to `jit` or to jnp operations (which are jitted by default).
Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager to operate on non-fully-addressable jax.Arrays.

Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array, since this is new behavior for `jit` (previously `jit` only worked on single-device arrays).
* Over time (via docs), and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without the config. This can happen when we merge the frontends of `jit` and `pjit`.
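
A minimal sketch of how the context manager described above might be used. The mode string `"allow_all"` and the helper `spmd_mode_ctx` are assumptions for illustration and are not taken from this commit message. On a single process every array is fully addressable, so the context has no visible effect here; it only matters for arrays that span multiple processes.

```python
import contextlib

import jax
import jax.numpy as jnp


def spmd_mode_ctx(mode):
    """Hypothetical helper: return jax.spmd_mode(mode) if this JAX build has it.

    Falls back to a no-op context so the sketch still runs on builds
    that do not expose the config.
    """
    ctx = getattr(jax, "spmd_mode", None)
    return ctx(mode) if ctx is not None else contextlib.nullcontext()


x = jnp.arange(8.0)

# Single-process arrays are fully addressable; multi-process arrays are not.
fully_addressable = x.is_fully_addressable

# "allow_all" is an assumed mode name meant to cover both implicitly jitted
# jnp operations and explicitly jitted functions.
with spmd_mode_ctx("allow_all"):
    y = jnp.sin(x) + 1.0           # implicitly jitted jnp operation
    z = jax.jit(lambda a: a * 2)(x)  # explicitly jitted function
```

In a multi-process job, the same `with` block would wrap the calls that receive non-fully-addressable arrays.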

PiperOrigin-RevId: 485075693
2022-10-31 09:51:42 -07:00