Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 428590081
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using
`jax_remat_opt_barrier` config flag.
named_call does not specify donated_invars, this change handles this missing
param case.
For future reference: we might want to add a call_param_updater registry to define
how call params need to get updated wrt checkify, like eg. partial_eval/ad does.
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; s...
PiperOrigin-RevId: 427576107
pp_eqn_compact() is used for one purpose only: creating metadata to put
on HLO. In that case, we don't need such carefully-formatted strings,
and speed is more important.
Gave a 6% speedup on a researcher's model.
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 427562278