rocm_jax/jax/interpreters
George Necula 20be478a6e [host_callback] Add support for pmap and for passing the device to tap
* Adds support for jit of pmap and pmap of pmap.
* Also adds an optional `tap_with_device` argument to `id_print` and
  `id_tap`; when it is set, the tap function is invoked with a `device` keyword argument.
* Adds multiple tests involving pmap.
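The calling convention behind `tap_with_device` can be illustrated with a small pure-Python sketch. It does not use JAX. `invoke_tap` is a hypothetical helper that only mimics how a tap function receives its arguments, and the device names are placeholder strings. The sketch assumes a tap function that takes `(arg, transforms)`, plus a `device` keyword when the flag is set, and it treats the tap as an identity on its argument.

```python
from typing import Any, Callable, List, Sequence, Tuple


def invoke_tap(tap_func: Callable[..., Any], arg: Any, transforms: Sequence[Any],
               device: str, tap_with_device: bool = False) -> Any:
    """Hypothetical helper: call tap_func the way the commit message describes."""
    if tap_with_device:
        # Pass the device that produced the value as a keyword argument.
        tap_func(arg, transforms, device=device)
    else:
        tap_func(arg, transforms)
    # The tap is an identity: the value flows through unchanged.
    return arg


# A tap function that records which (placeholder) device produced each value.
seen: List[Tuple[str, Any]] = []


def record(arg, transforms, *, device):
    seen.append((device, arg))


# Simulate a pmap over two devices, each tapping its own shard.
shards = {"dev0": [0, 1], "dev1": [2, 3]}
outputs = [invoke_tap(record, x, [], dev, tap_with_device=True)
           for dev, x in shards.items()]
```

Passing the device as a keyword keeps the tap function's signature unchanged for callers that do not set the flag.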

Issue: #5134
Fixes: #5169
2020-12-15 10:46:23 +02:00