Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 428590081
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; s...
PiperOrigin-RevId: 427576107
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 427562278
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
This means that jax and its dependencies (e.g. jaxlib) must be
manually installed before running the tests. This is useful for
testing an existing jax install, e.g. a later version of jaxlib, GPU
jaxlib, etc.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.
This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.
When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.
Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.
PiperOrigin-RevId: 394705605
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.
The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.
There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.
PiperOrigin-RevId: 394460343
PEP-561 does not specify whether subpackages of a non-stub-only-package
could use the -stubs suffix. setuptools seems to allow that, yet mypy fails
to resolve the subpackage with a -stubs suffix.
This commit makes jaxlib.xla_extension a ~normal package with a toplevel
__init__.pyi.
Previously, the libtpu-nightly wheels were included in the same index
file as the jaxlib wheels (jax_releases.html). This caused issues
because it would cause `pip install jax[tpu] -f jaxlib_releases.html`
to install a cuda jaxlib, instead of the regular CPU/TPU jaxlib from
pypi.
Instead, we create a separate index file for the libtpu-nightly
wheels, so `pip install jax[tpu] -f libtpu_releases.html` still uses
the jaxlib from pypi.
This also renames generate_release_index.py to generate_release_indexes.py.