mirror of
https://github.com/ROCm/jax.git
synced 2025-04-27 16:06:05 +00:00

* Adds support for `jit` of `pmap` and `pmap` of `pmap`.
* Also adds an optional `tap_with_device` argument to `id_print` and `id_tap`, so that the tap function is invoked with a `device` keyword argument.
* Adds multiple tests involving `pmap`.

Issue: #5134
Fixes: #5169
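The following minimal pure-Python sketch illustrates the `tap_with_device` calling convention. It does not use JAX. It assumes only what the commit message states: with `tap_with_device=True`, the tap function receives an extra `device` keyword argument. `invoke_tap`, `FakeDevice`, and the empty transforms tuple are hypothetical placeholders and are not part of the real `id_tap` API. The loop over two fake devices stands in for a `pmap` in which each device taps its own shard.

```python
from typing import Any, Callable, Sequence


class FakeDevice:
    """Hypothetical stand-in for a JAX device object."""

    def __init__(self, device_id: int) -> None:
        self.id = device_id

    def __repr__(self) -> str:
        return f"FakeDevice(id={self.id})"


def invoke_tap(tap_func: Callable[..., Any], arg: Any, transforms: Sequence[Any],
               device: FakeDevice, tap_with_device: bool = False) -> Any:
    """Hypothetical dispatcher modeling the calling convention in the commit.

    With tap_with_device=True the device is passed as a keyword argument;
    otherwise the tap function sees only (arg, transforms).
    """
    if tap_with_device:
        return tap_func(arg, transforms, device=device)
    return tap_func(arg, transforms)


# Record every tapped value together with the id of the device it came from.
tapped: list = []


def record_tap(arg, transforms, *, device=None):
    tapped.append((arg, device.id if device is not None else None))


# Simulate a pmap over two devices: each device taps its own shard.
devices = [FakeDevice(0), FakeDevice(1)]
shards = [10, 20]
for dev, shard in zip(devices, shards):
    invoke_tap(record_tap, shard, (), dev, tap_with_device=True)

print(tapped)  # [(10, 0), (20, 1)]


# Without tap_with_device, the device keyword is never passed, so a
# two-argument tap function still works.
def legacy_tap(arg, transforms):
    tapped.append((arg, "no-device"))


invoke_tap(legacy_tap, 30, (), devices[0])
```

A tap function intended for `tap_with_device=True` should accept `device` as a keyword. Giving it a default of `None`, as `record_tap` does, allows the same function to work when the flag is off.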