688 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
f55141ef0e Fix listing of vectorized deprecation in changelog.
As noted in https://github.com/jax-ml/jax/pull/23881, that change didn't
actually make it in in time for the v0.4.34 release so I've moved it to
the v0.4.35 section.
2024-10-10 15:40:01 -04:00
Peter Hawkins
aa3254d723 Deprecate jax.lib.xla_client.PaddingType.
This type is unused by JAX, so there is no replacement.

(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)

PiperOrigin-RevId: 684451556
2024-10-10 08:22:20 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Yuxuan Jiang
757a77ede0
Fix wrong date in changelog 2024-10-06 23:16:30 +08:00
George Necula
db89c245ac [host_callback] Remove most of the jax.experimental.host_callback module
These APIs have been deprecated since March 2024 and they are subsumed by the new JAX external callbacks.
See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 682830525
2024-10-06 01:10:34 -07:00
Jake VanderPlas
45f0e9ad68 Simplify definition of jnp.isscalar
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.

PiperOrigin-RevId: 682656411
2024-10-05 07:12:20 -07:00
Peter Hawkins
b0b7a60e63 Merge branch 'release/0.4.34' 2024-10-04 10:56:18 -04:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
Jake VanderPlas
49ad220e57 Finalize deprecation of XLACompatibleSharding
PiperOrigin-RevId: 681156145
2024-10-01 14:02:34 -07:00
George Necula
2228115cf4 [host_callback] Flip the JAX_HOST_CALLBACK_LEGACY flag to False
`jax.experimental.host_callback` has been deprecated since March 2024
 (JAX version 0.4.26). Now we set the default value of the `--jax_host_callback_legacy` configuration value to `True`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API.

If this breaks your code, for a very limited time, you can set the `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs.

See https://github.com/google/jax/issues/20385 for a discussion.

PiperOrigin-RevId: 681004255
2024-10-01 07:07:29 -07:00
carlosgmartin
65a58d622c Edit implementation of jax.numpy.ldexp to get correct gradient. 2024-09-30 18:27:39 -04:00
Peter Hawkins
0e082f978b Deprecate jax.lib.xla_client.Device.
jax.Device is a longstanding public name for this class.

PiperOrigin-RevId: 679197718
2024-09-26 10:17:04 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Jake VanderPlas
e05c37c667 Finalize deprecation of pretty-printing utils in jax.core.pp_*
PiperOrigin-RevId: 678775782
2024-09-25 11:20:35 -07:00
Peter Hawkins
111f13e279 Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
PiperOrigin-RevId: 678748794
2024-09-25 10:14:45 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
2024-09-24 14:46:44 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Peter Hawkins
6a3736a1d7 Add a note to the changelog about the new CPU thunks backend, enabled in 0.4.32. 2024-09-19 15:38:52 -04:00
Peter Hawkins
bef36c431d Add Python 3.13 wheels to changelog. 2024-09-18 18:57:03 +00:00
rajasekharporeddy
2714469397 Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros 2024-09-18 17:06:28 +05:30
Peter Hawkins
ae0e403c60 Merge release/0.4.33 into main and update version numbers. 2024-09-16 18:46:24 +00:00
Peter Hawkins
80e1c94de6 Prepare for v0.4.33 release.
This release is branched off the v0.4.32 release, with two changes:
a) a fixed libtpu pin, and
b) a patch to revert an F64 tanh issue on CPU.
2024-09-16 13:30:35 +00:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Peter Hawkins
dffac29e63 Reverts 255c30303d32e7473262b2e35348175c87e4348f
PiperOrigin-RevId: 674083626
2024-09-12 18:14:25 -07:00
Peter Hawkins
255c30303d Fix a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

PiperOrigin-RevId: 674019460
2024-09-12 14:49:18 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
Parker Schuh
bf2237a102 Flip jax_pmap_no_rank_reduction by default to True.
This changes:
* The performance of array[0] (use array[0:1] instead).
* The shape of jax_array.addressable_shards or jax_array.addressable_data(0) of arrays that come from pmap.

PiperOrigin-RevId: 673564995
2024-09-11 15:41:47 -07:00
Peter Hawkins
3e81ae530d Update version numbers after v0.4.32 release. 2024-09-11 16:18:56 -04:00
Dan Foreman-Mackey
bcbc0962bb Add the FFI functions and tutorial to the changelog.
Although we soft launched the FFI with v0.4.31, it would be nice to
include an update in the changelog to help with visibility.
2024-09-06 12:30:28 -04:00
Peter Hawkins
9c86fdec02 Make optimization_barrier a public lax API. 2024-09-06 00:18:57 +00:00
Sergei Lebedev
1289640f09 Deprecated calling `jax.dlpack.from_dlpack` with a DLPack tensor
PiperOrigin-RevId: 670723176
2024-09-03 15:16:02 -07:00
Jake VanderPlas
f2ffe7f8f2 Deprecate jax.numpy.round_
NumPy removed np.round in version 2.0; jax.numpy.round is drop-in
replacement.
2024-09-03 06:52:07 -07:00
Jake VanderPlas
a3d6cf007e First pass at ufunc interfaces for several jax.numpy functions 2024-08-30 11:53:02 -07:00
Sergei Lebedev
02bb884357 `jax.tree_util.register_dataclass now validates data_fields and meta_fields`
A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity

PiperOrigin-RevId: 669244669
2024-08-30 02:01:50 -07:00
rajasekharporeddy
ced012f5ed Update jnp.fabs to emulate the behavior of np.fabs for complex inputs 2024-08-28 20:16:09 +05:30
Bryan Massoth
b38f985b01 Add a callout that LibTPU now supports profiling of SparseCore for TPUv5p chips which will be viewable in Tensorboard Profiler's TraceViewer tool.
PiperOrigin-RevId: 667708094
2024-08-26 14:04:43 -07:00
jax authors
acf4b32452 Merge pull request #23060 from jakevdp:core-deps
PiperOrigin-RevId: 662988768
2024-08-14 11:21:00 -07:00
jax authors
599c13aa09 Introduce hermetic CUDA in Google ML projects.
1) Hermetic CUDA rules allow building wheels with GPU support on a machine without GPUs, as well as running Bazel GPU tests on a machine with only GPUs and NVIDIA driver installed. When `--config=cuda` is provided in Bazel options, Bazel will download CUDA, CUDNN and NCCL redistributions in the cache, and use them during build and test phases.

    [Default location of CUNN redistributions](https://developer.download.nvidia.com/compute/cudnn/redist/)

    [Default location of CUDA redistributions](https://developer.download.nvidia.com/compute/cuda/redist/)

    [Default location of NCCL redistributions](https://pypi.org/project/nvidia-nccl-cu12/#history)

2) To include hermetic CUDA rules in your project, add the following in the WORKSPACE of the downstream project dependent on XLA.

   Note: use `@local_tsl` instead of `@tsl` in Tensorflow project.

   ```
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
      "cuda_json_init_repository",
   )

   cuda_json_init_repository()

   load(
      "@cuda_redist_json//:distributions.bzl",
      "CUDA_REDISTRIBUTIONS",
      "CUDNN_REDISTRIBUTIONS",
   )
   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
      "cuda_redist_init_repositories",
      "cudnn_redist_init_repository",
   )

   cuda_redist_init_repositories(
      cuda_redistributions = CUDA_REDISTRIBUTIONS,
   )

   cudnn_redist_init_repository(
      cudnn_redistributions = CUDNN_REDISTRIBUTIONS,
   )

   load(
      "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
      "cuda_configure",
   )

   cuda_configure(name = "local_config_cuda")

   load(
      "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
      "nccl_redist_init_repository",
   )

   nccl_redist_init_repository()

   load(
      "@tsl//third_party/nccl/hermetic:nccl_configure.bzl",
      "nccl_configure",
   )

   nccl_configure(name = "local_config_nccl")
   ```

PiperOrigin-RevId: 662981325
2024-08-14 10:58:43 -07:00
Jake VanderPlas
bd9698ec6d Deprecate several internal utilities in jax.core 2024-08-14 10:06:13 -07:00
Dan Foreman-Mackey
60bf5b7727 Add a jax.process_indices function.
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.

PiperOrigin-RevId: 662142636
2024-08-12 10:30:41 -07:00
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Yue Sheng
f255fb700a Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False), one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 659741822
2024-08-05 17:48:17 -07:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
Jake VanderPlas
3551fcc077 Deprecate several APIs in jax.lib.xla_bridge
PiperOrigin-RevId: 658274719
2024-07-31 23:00:35 -07:00
Jake VanderPlas
8bcd288621 Raise ValueError for complex inputs to jnp.clip and jnp.hypot.
Such inputs were deprecated in JAX v0.4.27, and have been raising a DeprecationWarning for the last several releases.

PiperOrigin-RevId: 657717875
2024-07-30 13:49:37 -07:00
jax authors
256956ad58 Merge pull request #22704 from gnecula:pallas_better_errors
PiperOrigin-RevId: 657571604
2024-07-30 06:39:44 -07:00
Peter Hawkins
c1cd7f9e2d Drop support for mhlo in JAX's public API.
PiperOrigin-RevId: 657551590
2024-07-30 05:29:52 -07:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
Yash Katariya
2106a25977 Finish jax and jaxlib v0.4.31 release
PiperOrigin-RevId: 657388782
2024-07-29 17:57:37 -07:00
Peter Hawkins
d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00