660 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
bcbc0962bb Add the FFI functions and tutorial to the changelog.
Although we soft launched the FFI with v0.4.31, it would be nice to
include an update in the changelog to help with visibility.
2024-09-06 12:30:28 -04:00
Peter Hawkins
9c86fdec02 Make optimization_barrier a public lax API. 2024-09-06 00:18:57 +00:00
Sergei Lebedev
1289640f09 Deprecated calling `jax.dlpack.from_dlpack` with a DLPack tensor
PiperOrigin-RevId: 670723176
2024-09-03 15:16:02 -07:00
Jake VanderPlas
f2ffe7f8f2 Deprecate jax.numpy.round_
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Jake VanderPlas
a3d6cf007e First pass at ufunc interfaces for several jax.numpy functions 2024-08-30 11:53:02 -07:00
Sergei Lebedev
02bb884357 `jax.tree_util.register_dataclass now validates data_fields and meta_fields`
A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity

PiperOrigin-RevId: 669244669
2024-08-30 02:01:50 -07:00
rajasekharporeddy
ced012f5ed Update jnp.fabs to emulate the behavior of np.fabs for complex inputs 2024-08-28 20:16:09 +05:30
Bryan Massoth
b38f985b01 Add a callout that LibTPU now supports profiling of SparseCore for TPUv5p chips which will be viewable in Tensorboard Profiler's TraceViewer tool.
PiperOrigin-RevId: 667708094
2024-08-26 14:04:43 -07:00
jax authors
acf4b32452 Merge pull request #23060 from jakevdp:core-deps
PiperOrigin-RevId: 662988768
2024-08-14 11:21:00 -07:00
jax authors
599c13aa09 Introduce hermetic CUDA in Google ML projects.
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel will download CUDA, CUDNN and NCCL redistributions in the cache, and use them during build and test phases.

    [Default location of CUNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)

    [Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)

    [Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)

2) To include hermetic CUDA rules in your project, add the following in the WORKSPACE of the downstream project dependent on XLA.

   Note: use `@local_tsl` instead of `@tsl` in Tensorflow project.

   ```
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
      "cuda_json_init_repository",
   )

   cuda_json_init_repository()

   load(
      "@cuda_redist_json//:distributions.bzl",
      "CUDA_REDISTRIBUTIONS",
      "CUDNN_REDISTRIBUTIONS",
   )
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
      "cuda_redist_init_repositories",
      "cudnn_redist_init_repository",
   )

   cuda_redist_init_repositories(
      cuda_redistributions = CUDA_REDISTRIBUTIONS,
   )

   cudnn_redist_init_repository(
      cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
   )

   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
      "cuda_configure",
   )

   cuda_configure(name = "local_config_cuda")

   load(
      "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
      "nccl_redist_init_repository",
   )

   nccl_redist_init_repository()

   load(
      "@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
      "nccl_configure",
   )

   nccl_configure(name = "local_config_nccl")
   ```

PiperOrigin-RevId: 662981325
2024-08-14 10:58:43 -07:00
Jake VanderPlas
bd9698ec6d Deprecate several internal utilities in jax.core 2024-08-14 10:06:13 -07:00
Dan Foreman-Mackey
60bf5b7727 Add a jax.process_indices function.
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.

PiperOrigin-RevId: 662142636
2024-08-12 10:30:41 -07:00
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Yue Sheng
f255fb700a Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False), one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 659741822
2024-08-05 17:48:17 -07:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
Jake VanderPlas
3551fcc077 Deprecate several APIs in jax.lib.xla_bridge
PiperOrigin-RevId: 658274719
2024-07-31 23:00:35 -07:00
Jake VanderPlas
8bcd288621 Raise ValueError for complex inputs to jnp.clip and jnp.hypot.
Such inputs were deprecated in JAX v0.4.27, and have been raising a DeprecationWarning for the last several releases.

PiperOrigin-RevId: 657717875
2024-07-30 13:49:37 -07:00
jax authors
256956ad58 Merge pull request #22704 from gnecula:pallas_better_errors
PiperOrigin-RevId: 657571604
2024-07-30 06:39:44 -07:00
Peter Hawkins
c1cd7f9e2d Drop support for mhlo in JAX's public API.
PiperOrigin-RevId: 657551590
2024-07-30 05:29:52 -07:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
Yash Katariya
2106a25977 Finish jax and jaxlib v0.4.31 release
PiperOrigin-RevId: 657388782
2024-07-29 17:57:37 -07:00
Peter Hawkins
d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
Peter Hawkins
fd23b8733d Bump minimum SciPy version to 1.10.
SciPy 1.9.0 was released July 29, 2022, which is 24 months ago

PiperOrigin-RevId: 657215038
2024-07-29 08:50:18 -07:00
Jake VanderPlas
a17c8d945b Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
Peter Hawkins
47e6da3332 Don't mask out zero elements on the diagonal of the matrix when inverting triangular matrices.
The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes https://github.com/google/jax/issues/3589
Fixes https://github.com/google/jax/issues/15429

PiperOrigin-RevId: 653562611
2024-07-18 04:09:44 -07:00
George Necula
d34a6e9ce2 [jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.

PiperOrigin-RevId: 652749891
2024-07-16 02:05:43 -07:00
George Necula
9cd94019b4 [pallas] Added a CHANGELOG for Pallas
The CHANGELOG is populated with the changes since June 10th, when
JAX 0.4.29 was released.
2024-07-12 00:05:31 +03:00
Gleb Pobudzey
46103f6ff3 Updated the repr of GPU devices to be more consistent with TPUs/CPUs.
For example, `cuda(id=0)` will now be `CudaDevice(id=0)`

PiperOrigin-RevId: 651393690
2024-07-11 06:54:20 -07:00
Peter Hawkins
262a4f482c Deprecate support for custom lowering rules that return tuple-wrapped ir.Values.
https://github.com/google/jax/pull/22211 forbade custom lowering rules from returning singleton tuples of ir.Value, but this appears to break downstream users, notably Transformer Engine. Instead, allow lowering rules to return singleton tuples and unwrap them if needed, but warn if this behavior is seen.

PiperOrigin-RevId: 650345051
2024-07-08 12:54:44 -07:00
Sergei Lebedev
a2a5068e5e Changed `pl.BlockSpec to accept block_shape before index_map`
So, instead of

    pl.BlockSpec(lambda i, j: ..., (42, 24))

``pl.BlockSpec`` now expects

    pl.BlockSpec((42, 24), lambda i, j: ...)

I will update Pallas tests in a follow up.

PiperOrigin-RevId: 648486321
2024-07-01 14:26:08 -07:00
Jake VanderPlas
fbcb157ad3 Finalize deprecation of several previously-deprecated jax.core functions:
- `jax.core.canonicalize_shape`
- `jax.core.dimension_as_value`
- `jax.core.definitely_equal`
- `jax.core.symbolic_equal_dim`

These have been raising deprecation warnings since JAX v0.4.24, released Feb 6 2024.

PiperOrigin-RevId: 647671122
2024-06-28 07:28:28 -07:00
jax authors
00528b9858 libdevice.10.bc is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
2024-06-27 08:35:59 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
vfdev-5
70b4823348 Updated jnp.ceil/floor/trunc to preserve int dtypes
Description:
- Updated jnp.ceil/floor/trunc to preserve int dtypes
- Updated tests
  - For integral dtypes but we can't yet today compare types vs numpy as numpy 2.0.0rc2 is not yet array api compliant in this case
2024-06-25 20:26:53 +02:00
Peter Hawkins
7f24837eef Update minimum NumPy version to v1.24. 2024-06-21 15:17:17 -07:00
Peter Hawkins
d7a22d3720 [JAX] Teach jit fast path how to handle negative static_argnums correctly.
PiperOrigin-RevId: 645172085
2024-06-20 15:18:25 -07:00
Yash Katariya
045ea944c8 Finish jax and jaxlib 0.4.30 release
PiperOrigin-RevId: 644402635
2024-06-18 08:56:39 -07:00
Yash Katariya
e6f26ff256 Deprecate jax.xla_computation. Use JAX AOT APIs to get the equivalent of jax.xla_computation functionality.
PiperOrigin-RevId: 644107276
2024-06-17 13:02:35 -07:00
Jake VanderPlas
9d5932a190 Deprecate passing of arrays in place of dtypes. 2024-06-14 05:40:04 -07:00
Jake VanderPlas
a92fa547a0 Re-land https://github.com/google/jax/pull/21847
Reverts 0bcc81ceb33e3065110e3dd56ca215dbb62f0a7b

PiperOrigin-RevId: 643202512
2024-06-13 19:53:53 -07:00
jax authors
8f5f8df112 Merge pull request #21863 from jakevdp:dep-tracer-hash
PiperOrigin-RevId: 643147305
2024-06-13 15:52:36 -07:00
jax authors
0bcc81ceb3 Reverts 5aedafc214cf930f5b196b1eb130fd7ec866bc5e
PiperOrigin-RevId: 643131144
2024-06-13 14:58:54 -07:00
jax authors
5aedafc214 Merge pull request #21847 from gnecula:export_deprecate
PiperOrigin-RevId: 643099957
2024-06-13 13:17:29 -07:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Yash Katariya
023bc7856b Add registration handler for TPU v5e in mesh_utils.
PiperOrigin-RevId: 643092629
2024-06-13 12:52:33 -07:00
George Necula
7af03a8fd1 [export] Deprecate jax.experimental.export
And announce jax.export.

While turning on the DeprecationWarning I discovered a couple
of tests that needed adjustment.
2024-06-13 21:46:18 +03:00
Jake VanderPlas
f63b94574a Deprecate internal pretty-printing APIs, jax.core.pp_* 2024-06-13 09:44:56 -07:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00