mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 20:06:05 +00:00

This is a follow-up to #25640, which enabled lowering with `AbstractMesh`. This change required adding `num_devices` to `lowering.compiler_args` because, in the presence of an `AbstractMesh`, the `device_assignment` does not accurately reflect the number of devices.
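The reasoning can be illustrated with a small toy model. An abstract mesh carries only axis sizes and no concrete devices, so a device count derived from `len(device_assignment)` can disagree with the count implied by the mesh shape. Passing the count separately avoids that mismatch. Everything below is hypothetical: `ToyConcreteMesh`, `ToyAbstractMesh`, and `toy_compiler_args` are illustrative stand-ins, not JAX's real classes or its actual `compiler_args` structure. The example also assumes, purely for illustration, that lowering with an abstract mesh uses a single placeholder device in the assignment.

```python
from dataclasses import dataclass
from math import prod

# Hypothetical stand-ins for illustration only; NOT JAX's real mesh classes.
@dataclass(frozen=True)
class ToyConcreteMesh:
    axis_sizes: tuple[int, ...]
    devices: tuple[str, ...]  # concrete device handles, one per mesh slot

@dataclass(frozen=True)
class ToyAbstractMesh:
    axis_sizes: tuple[int, ...]  # mesh shape only; no devices attached

def toy_compiler_args(mesh, device_assignment):
    """Build a hypothetical compiler-args dict.

    num_devices is taken from the mesh shape rather than from
    len(device_assignment), because with an abstract mesh the
    assignment may be a placeholder of the wrong length.
    """
    num_devices = prod(mesh.axis_sizes)
    return {
        "num_devices": num_devices,
        "device_assignment": tuple(device_assignment),
    }

# Abstract 2x4 mesh lowered against an assumed one-device placeholder.
abstract_args = toy_compiler_args(ToyAbstractMesh((2, 4)), ["cpu:0"])
print(abstract_args["num_devices"], len(abstract_args["device_assignment"]))
```

With a concrete mesh the two counts agree; with the abstract mesh the mesh shape reports 8 devices while the placeholder assignment has length 1, which is why an explicit `num_devices` is needed.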