Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.
In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.
This change does not yet remove any exported names.
Issue https://github.com/google/jax/issues/11951
PiperOrigin-RevId: 469480816
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.
PiperOrigin-RevId: 462239974
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
There's no point building extensions that cannot be used, and they sometimes cause problems (e.g., TPU support on Windows builds).
Fixes https://github.com/google/jax/issues/10687
PiperOrigin-RevId: 449287997
The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.
This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.
PiperOrigin-RevId: 446761784
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.
[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.
PiperOrigin-RevId: 446737518
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.
PiperOrigin-RevId: 443438043
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.
PiperOrigin-RevId: 442547482
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 428590081
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; s...
PiperOrigin-RevId: 427576107
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 427562278