mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 14:46:06 +00:00

Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array. Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since its a new behavior for `jit` (previously jit only worked on single device arrays). * Overtime (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`. PiperOrigin-RevId: 485075693