Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

Adds a `--jax_transfer_guard` flag and a `jax.transfer_guard()` context manager that allow logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-grained control over the transfer guard level of individual transfer directions. The corresponding flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:

```
import jax
import jax.numpy as jnp

x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
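
The explicit/implicit distinction and the direction-suffixed context managers described above can be exercised roughly as follows. This is a minimal sketch, not part of the original change: the helper names `explicit_round_trip` and `implicit_fetch_is_blocked` are hypothetical, and whether an implicit fetch is actually rejected may depend on the backend, so the second helper only reports what happened rather than assuming an error.

```python
import jax
import jax.numpy as jnp
import numpy as np


def explicit_round_trip(value):
    """Move a host value to the device and back using only explicit transfers.

    Hypothetical helper for illustration. Under the "disallow" level,
    explicit transfers (device_put / device_get) are silently allowed.
    """
    with jax.transfer_guard("disallow"):
        on_device = jax.device_put(value)  # explicit host-to-device
        return jax.device_get(on_device)   # explicit device-to-host


def implicit_fetch_is_blocked(array):
    """Report whether an implicit device-to-host fetch was rejected.

    Hypothetical helper for illustration. Only the device-to-host
    direction is guarded here; other directions keep their current level.
    """
    try:
        with jax.transfer_guard_device_to_host("disallow"):
            np.asarray(array)  # implicit device-to-host fetch
    except Exception:
        # Assumption: a rejected transfer surfaces as a Python exception.
        return True
    return False


result = explicit_round_trip(np.arange(4, dtype=np.float32))
blocked = implicit_fetch_is_blocked(jnp.ones(2))
print(result, blocked)
```

Wrapping a whole training or evaluation step in `jax.transfer_guard("log")` first, and switching to "disallow" once the logged transfers have been made explicit, is one way to adopt the guard incrementally.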