rocm_jax/jax/interpreters
Yash Katariya 83d7e3f085 Remove the fallback to lower_xla_callable that existed for jit(pmap) cases when Array was enabled, and add minimal support to lower_sharding_computation.
The `jit(pmap)` codepath is added to `lower_sharding_computation` so that the `lower_xla_callable` codepath can be deleted when `jax.Array` is enabled by default. This will help clean up the codebase and get rid of tech debt.

* Round trip through host for `Array`s that have `PmapSharding` and come through the `jit` path (exactly like SDAs).

* For other cases, i.e. when `num_replicas > 1`, default to the `_execute_replicated` path in `dispatch.py` from `lower_sharding_computation`. This is exactly what happens in `lower_xla_callable`.
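The two cases above can be modelled very roughly as follows. This is a toy sketch under stated assumptions, not JAX's implementation: `ToyArray`, `prepare_jit_args` and `choose_execute_path` are hypothetical names, the "round trip through host" is modelled as a plain list copy, and `"default"` is a placeholder for whatever path `lower_sharding_computation` otherwise takes.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyArray:
    """Hypothetical stand-in for a jax.Array: some data plus a sharding label."""
    data: List[float]
    sharding_kind: str  # a plain label such as "PmapSharding"


def prepare_jit_args(args: Sequence[ToyArray]) -> List[List[float]]:
    """Hypothetical helper modelling the first bullet.

    Inputs labelled PmapSharding are copied out (standing in for a transfer
    to host) so the jit computation can lay them out afresh; all other
    inputs are passed through unchanged.
    """
    prepared = []
    for arg in args:
        if arg.sharding_kind == "PmapSharding":
            prepared.append(list(arg.data))  # models the host round trip
        else:
            prepared.append(arg.data)  # reused as-is, no copy
    return prepared


def choose_execute_path(num_replicas: int) -> str:
    """Hypothetical helper modelling the second bullet."""
    if num_replicas > 1:
        # Replicated computations fall back to the _execute_replicated path.
        return "_execute_replicated"
    # Placeholder label for the non-replicated path (assumption).
    return "default"


if __name__ == "__main__":
    pmapped = ToyArray([1.0, 2.0], "PmapSharding")
    print(prepare_jit_args([pmapped]))          # [[1.0, 2.0]]
    print(choose_execute_path(num_replicas=2))  # _execute_replicated
```

The sketch only captures the branching: PmapSharding inputs are never reused in place by the jit path, and any computation with more than one replica takes the same replicated execution route that `lower_xla_callable` used.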

PiperOrigin-RevId: 471033420
2022-08-30 10:46:23 -07:00