mirror of
https://github.com/ROCm/jax.git
synced 2025-04-25 01:06:05 +00:00

Add the `jit(pmap)` codepath to `lower_sharding_computation` so that the `lower_xla_callable` codepath can be deleted once `jax.Array` is enabled by default. This helps clean up the codebase and removes tech debt.

* `Array`s that have a `PmapSharding` and come through the `jit` path are round-tripped through the host (exactly like SDAs).
* In the other cases, i.e. when `num_replicas > 1`, `lower_sharding_computation` defaults to the `_execute_replicated` path in dispatch.py. This is exactly the same as what happens in `lower_xla_callable`.

PiperOrigin-RevId: 471033420
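
The two user-facing cases this change touches can be sketched with the public API only. This is a minimal illustration, not part of the change itself: the function `double` and the variable names are made up for the example, and it assumes the pmapped output carries a `PmapSharding` as described above (other JAX versions may differ). Nothing here calls the internal lowering functions directly.

```python
import jax
import jax.numpy as jnp

# Use however many local devices are available (1 on a plain CPU install).
n = jax.local_device_count()
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)

def double(v):
    # Hypothetical per-device computation, for illustration only.
    return v * 2.0

# Case 1: jit wrapped around pmap, i.e. the `jit(pmap)` codepath.
jit_of_pmap = jax.jit(jax.pmap(double))
doubled = jit_of_pmap(x)

# Case 2: an array produced by pmap is passed into a jitted function,
# i.e. a pmap-sharded input arriving through the `jit` path.
pmapped = jax.pmap(double)(x)
summed = jax.jit(jnp.sum)(pmapped)
```

Both calls should return the same values as the unjitted, unmapped computation; only the internal lowering and dispatch path differs.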