mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 22:06:06 +00:00

The serial_pmap transformation was a placeholder and is now replaced by soft_pmap. The papply tests that used serial_pmap now use soft_pmap, which means they can run on parallel hardware when available.

The papply transform had some unused features (e.g. in_axes, out_axes) that won't be needed by parallelize, so those are removed. It is also now only needed for testing, since parallelize (which essentially composes a soft_pmap with a papply) is likely to be the primary user-facing API.

This commit adds the parallelize transformation and some tests for it, including exhaustive transpose tests.

Misc changes:
* simplified the transpose papply rule and made it lazy (so that it doesn't need to perform communication)
* fixed misc bugs encountered along the way
* cherry-picked a few lines from frostig@'s branch, namely the fixed broadcasting_papply rule and the plumbing of the `size` argument to papply rules
* removed the psplit and psplit_like primitives, replacing them with calls to all_to_all where needed
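To illustrate what "composing a soft_pmap with a papply" means, here is a minimal sketch in present-day JAX. It is an analogue, not the code from this commit: soft_pmap and papply are not used. Instead, `jax.vmap` with an `axis_name` stands in for soft_pmap, and the per-shard function is written by hand the way a papply rule would rewrite `jnp.sum` (local reduction followed by a `psum` across the mapped axis). The helper `parallelize_sketch` and the axis name `"i"` are hypothetical names chosen for illustration.

```python
# Sketch only: jax.vmap(..., axis_name="i") plays the role of soft_pmap, and
# total_per_shard is a hand-written stand-in for what papply would produce.
import jax
import jax.numpy as jnp
from jax import lax


def total(x):
    # Unsharded reference computation.
    return jnp.sum(x)


def total_per_shard(x_shard):
    # Hypothetical "papply'd" version of `total`: reduce the local block,
    # then combine the partial sums across the mapped axis named "i".
    return lax.psum(jnp.sum(x_shard), axis_name="i")


def parallelize_sketch(f_per_shard, x, num_shards):
    # Hypothetical helper: split the leading dimension into `num_shards`
    # equal blocks and map the per-shard function over them. Every shard
    # ends up holding the combined result.
    shards = x.reshape(num_shards, -1)
    return jax.vmap(f_per_shard, axis_name="i")(shards)


x = jnp.arange(8.0)
out = parallelize_sketch(total_per_shard, x, num_shards=4)
print(out)  # each of the 4 entries equals total(x), i.e. 28.0
```

Because the collective is expressed against a named axis, the same per-shard function could in principle be mapped with a device-parallel transform instead of `vmap`, which mirrors the motivation for having the tests run on parallel hardware when available.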