Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.
Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.
PiperOrigin-RevId: 457460347
It turns out that the support for C++17 is partial in 10.12, and in particular absl::optional and std::optional are not the same thing under 10.12. Increment to 10.14 which is the lowest version that builds successfully with absl::optional == std::optional.
See: 89cdaed655/absl/base/config.h (L528)
Strictly speaking, we could allow 10.13, but not without updating ABSL in the TF repository to incorporate c86347d4ce which fixes the version detection test to permit 10.13 as well.
This commit bumps the minimum macOS version target to 10.12, from 10.9
previously. The reason is that a newer LLVM version, which the XLA
compiler depended on, used `std::shared_mutex`, which is only available
starting at macOS 10.12.
The fix here consists of setting the minimum version target flag that
bazel consumes for the build to 10.12.
Otherwise, `copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cpu_feature_guard.so"))`
results an `AttributeError: 'NoneType' object has no attribute 'endswith'`
It seems that *.so file is not symlinked into sandbox anymore due to
pybind11 rules update.
Refactoring only, no functional changes intended.
This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.
Fixes https://github.com/google/jax/issues/10474
PiperOrigin-RevId: 447076192
Previously we depended on various .so files directly so they were pulled into the jaxlib wheel build, but it seems to work to add the libraries in question to //jaxlib and depend on that in the usual way.
It appears if a py_library() is used as a data-dependency of another rule, Bazel includes any transitive C++ extension deps, and that's what we want.
PiperOrigin-RevId: 446802592
The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.
This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.
PiperOrigin-RevId: 446761784
See: https://github.com/tensorflow/tensorflow/pull/55613
For a CUDA build at head with the default compute capabilities, reduces wheel size from 141MB to 112MB.
Don't redundantly specify default compute capabilities in .bazelrc and in
build.py.
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.
PiperOrigin-RevId: 443438043
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 428590081
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; s...
PiperOrigin-RevId: 427576107
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.
The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).
The transfer guard can take an action based on its guard level:
* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.
The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:
* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.
Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)
print(x) # No error
with jax.transfer_guard("disallow"):
print(x) # No error; x is already fetched
print(jax.device_get(y)) # No error
print(z) # Error!
```
PiperOrigin-RevId: 427562278
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
This means that jax and its dependencies (e.g. jaxlib) must be
manually installed before running the tests. This is useful for
testing an existing jax install, e.g. a later version of jaxlib, GPU
jaxlib, etc.