For example, tree_map(..., None, [2, 3]) did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.
PiperOrigin-RevId: 674019460
This changes:
* The performance of array[0] (use array[0:1] instead).
* The shape of jax_array.addressable_shards or jax_array.addressable_data(0) of arrays that come from pmap.
PiperOrigin-RevId: 673564995
A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity
PiperOrigin-RevId: 669244669
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel will download CUDA, CUDNN and NCCL redistributions in the cache, and use them during build and test phases.
[Default location of CUNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)
[Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)
[Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)
2) To include hermetic CUDA rules in your project, add the following in the WORKSPACE of the downstream project dependent on XLA.
Note: use `@local_tsl` instead of `@tsl` in Tensorflow project.
```
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
"cuda_json_init_repository",
)
cuda_json_init_repository()
load(
"@cuda_redist_json//:distributions.bzl",
"CUDA_REDISTRIBUTIONS",
"CUDNN_REDISTRIBUTIONS",
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
"cuda_redist_init_repositories",
"cudnn_redist_init_repository",
)
cuda_redist_init_repositories(
cuda_redistributions = CUDA_REDISTRIBUTIONS,
)
cudnn_redist_init_repository(
cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
)
load(
"@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
"cuda_configure",
)
cuda_configure(name = "local_config_cuda")
load(
"@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
"nccl_redist_init_repository",
)
nccl_redist_init_repository()
load(
"@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
"nccl_configure",
)
nccl_configure(name = "local_config_nccl")
```
PiperOrigin-RevId: 662981325
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.
PiperOrigin-RevId: 662142636
* Add the source location information for the index map function to
`BlockMapping`.
* Removed the `compute_index` wrapper around the index_map, so that
we can get the location information for the index_map, not the wrapper.
* Added source location to the errors related to index map functions.
* Added an error if the index map returns something other than integer
scalars.
* Construct BlockSpec origins for arguments using JAX helper functions
to get argument names
* Removed redundant API error tests from tpu_pallas_test.py
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.
Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.
PiperOrigin-RevId: 657225917
The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.
Fixes https://github.com/google/jax/issues/3589
Fixes https://github.com/google/jax/issues/15429
PiperOrigin-RevId: 653562611
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.
PiperOrigin-RevId: 652749891
https://github.com/google/jax/pull/22211 forbade custom lowering rules from returning singleton tuples of ir.Value, but this appears to break downstream users, notably Transformer Engine. Instead, allow lowering rules to return singleton tuples and unwrap them if needed, but warn if this behavior is seen.
PiperOrigin-RevId: 650345051
So, instead of
pl.BlockSpec(lambda i, j: ..., (42, 24))
``pl.BlockSpec`` now expects
pl.BlockSpec((42, 24), lambda i, j: ...)
I will update Pallas tests in a follow up.
PiperOrigin-RevId: 648486321
- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`
These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.
PiperOrigin-RevId: 647671122
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.
PiperOrigin-RevId: 647328834
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
- For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case