mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 05:46:06 +00:00

Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:

* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; s...

PiperOrigin-RevId: 427576107
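A minimal sketch of how the context manager might be used, assuming a working JAX install. The `"disallow"` level name is taken from the JAX documentation, since the list of guard levels above is cut off. The variable names are illustrative only. Whether the implicit transfer at the end is actually rejected can depend on the backend and the JAX version, so the sketch records the outcome instead of assuming it.

```python
import jax
import numpy as np

# Illustrative host-side data; the name is hypothetical.
host_data = np.arange(4, dtype=np.float32)

# Assumption: "disallow" is the level that rejects implicit transfers
# while still permitting explicit ones (per the JAX documentation).
with jax.transfer_guard("disallow"):
    # Explicit host-to-device transfer.
    on_device = jax.device_put(host_data)
    # Explicit device-to-host transfer.
    back_on_host = jax.device_get(on_device)

    # Mixing a Python scalar into a computation may trigger an implicit
    # host-to-device transfer. Record whether the guard rejected it
    # rather than relying on a particular outcome.
    try:
        _ = on_device + 1
        implicit_blocked = False
    except Exception:
        implicit_blocked = True
```

Keeping the guard scoped to a `with` block limits the check to the code under inspection; the `--jax_transfer_guard` flag described above is the process-wide alternative.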