This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code.
PiperOrigin-RevId: 446391558
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.
(Also there are some unrelated removals of dead code in maps.py.)
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:
make device_array.copy() return a device array
PiperOrigin-RevId: 438308145
Both `Lowered` and `Compiled` carry information about input arguments
for which the underlying computation was lowered (namely avals,
donation bits, and the input pytree structure today). This change
rearranges some internals so that all of this information is held
together in a single pytree of structs. Doing so simplifies the fields
of both stage classes and helps ensure the input argument properties
are consistent with one another (e.g. now they must share a consistent
pytree structure by definition).
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.
Co-authored-by: Matthew Johnson <mattjj@google.com>
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 428590081