181 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
011fc88c03 Update versions and changelog for jax 0.4.14 release 2023-07-27 16:22:53 -07:00
Skye Wanderman-Milne
0b24b2ba6a Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.14 release 2023-07-27 13:35:04 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
bfa113ba60 Remove references to Python 3.8.
Remove the old build scripts/Dockerfile, since they are unused and broken.

PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Skye Wanderman-Milne
20095ab9da Update version numbers and changelog post 0.4.13 release 2023-06-22 17:54:57 -07:00
Peter Hawkins
487b640acf Jax 0.4.13 release. 2023-06-22 14:59:36 -07:00
Peter Hawkins
0ded163027 Fix cuda12 pip install.
The wheel is now called cudnn89.

Fixes #16362
2023-06-12 15:38:16 -04:00
Skye Wanderman-Milne
4b80103077 Update versions and changelog post 0.4.12 release 2023-06-08 16:53:22 -07:00
Skye Wanderman-Milne
8223e452f2 Update WORKSPACE and setup.py for jax + jaxlib 0.4.12 release 2023-06-08 13:58:09 -07:00
Yash Katariya
4c48611fba Finish jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536931532
2023-05-31 23:49:32 -07:00
Yash Katariya
48ad9a6f3e Start jax and jaxlib 0.4.11 release
PiperOrigin-RevId: 536860076
2023-05-31 16:48:52 -07:00
Skye Wanderman-Milne
968237080f Add importlib_metadata to project requirements.
This is necessary to ensure we can correctly detect PJRT plugins via
entry_points without compatibility errors.

Prior to this change, there was conditional logic to handle if
importlib_metadata wasn't installed at all. However, it doesn't handle
the case where importlib_metadata is installed by not high enough
version to support Python 3.10 compat. This change gets rid of that
logic and just ensures the right version is installed.

All of this logic can be removed if/when jax requires Python version
>= 3.10

This also removes an unnecessary `requests` dep for the [tpu] install.
2023-05-31 21:03:12 +00:00
Peter Hawkins
69cf67f252 Bump the minimum CUDNN version for CUDA 12 wheels to 8.9. 2023-05-26 10:04:34 -04:00
Peter Hawkins
2b7790290b Bump minimum CUDNN version in pip installation to 8.8.
There are known wrong output bugs observed in JAX for earlier versions, in particular related to RNNs.
2023-05-25 14:46:39 -04:00
Skye Wanderman-Milne
533a7c05f1 Update versions and changelog post 0.4.10 release 2023-05-11 18:16:02 -07:00
Skye Wanderman-Milne
82bbeef519 Update setup.py, WORKSPACE, and CHANGELOG for jax/jaxlib 0.4.10 release 2023-05-11 14:46:06 -07:00
Skye Wanderman-Milne
b02b043e7f Update versions and changelog for 0.4.9 release 2023-05-09 17:06:59 -07:00
Skye Wanderman-Milne
5bcd9dcc46 Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release, take 2 2023-05-09 14:49:54 -07:00
Skye Wanderman-Milne
5e9364abc6 Revert setup.py changes.
This reverts the setup.py changes from
f28b20175f307d5a56502446a9706480126a5bd4. We actually need to fix some
more issues before releasing 0.4.9, so fix the install at HEAD in the
meantime.
2023-05-08 09:58:51 -07:00
Skye Wanderman-Milne
f28b20175f Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.9 release 2023-05-04 14:38:46 -07:00
Jake VanderPlas
59e6ed213e Use ml_dtypes definition for jnp.finfo 2023-05-04 10:40:44 -07:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Peter Hawkins
75d0f6522d Add cupti pip dependency, needed for GPU profiling.
Issue https://github.com/google/jax/issues/15384

PiperOrigin-RevId: 521841461
2023-04-04 12:55:36 -07:00
Peter Hawkins
705b5cc000 Add version constraints to CUDA pip wheel dependencies.
Fixes https://github.com/google/jax/issues/15267
2023-03-28 21:55:32 -04:00
Yash Katariya
670fba3a91 Finish jax and jaxlib 0.4.7 release
PiperOrigin-RevId: 519839723
2023-03-27 15:06:38 -07:00
Yash Katariya
e9cac5eb47 Prepare for jax and jaxlib 0.4.7 release
PiperOrigin-RevId: 519785176
2023-03-27 11:45:22 -07:00
Peter Hawkins
40fb646e35 Fix duplicate definition of 'cuda' extra in setup.py.
PiperOrigin-RevId: 519750659
2023-03-27 09:52:37 -07:00
Peter Hawkins
b62f114524 Add support for using pip-installed CUDA wheels.
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
8bb90b5fbe [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
PiperOrigin-RevId: 518830467
2023-03-23 05:13:39 -07:00
Skye Wanderman-Milne
6560bf8c36 Update versions and changelog for jax + jaxlib 0.4.6 release 2023-03-09 14:50:42 -08:00
Skye Wanderman-Milne
1aa08fd4a0 Update WORKSPACE and setup.py for jax/jaxlib 0.4.6 release 2023-03-09 10:44:58 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Yash Katariya
941722f7db Finish jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510234171
2023-02-16 13:54:56 -08:00
Yash Katariya
58e46b48e6 Prepare for jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510152471
2023-02-16 08:37:15 -08:00
Skye Wanderman-Milne
21f12183bf Post 0.4.3 release updates 2023-02-08 10:08:59 -08:00
Skye Wanderman-Milne
8ab158574d Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release 2023-02-07 15:45:28 -08:00
Skye Wanderman-Milne
b6a8aa6394 Update versions for jaxlib 0.4.2 release.
I also screwed up the CHANGELOG before (I shouldn't have added a
date), so I'm fixing the dates now.
2023-01-26 01:12:06 +00:00
Skye Wanderman-Milne
c4ad27c363 Update libtpu version for jaxlib 0.4.2 release (again) 2023-01-25 01:34:16 +00:00
Skye Wanderman-Milne
3f4bd5f449 Updates for jax + jaxlib 0.4.2 release 2023-01-20 19:04:46 +00:00
Jake VanderPlas
9d100ae9f4 Explicitly set utf-8 encoding in setup.py 2023-01-05 09:41:18 -08:00
Yash Katariya
835d0c979a Finish jax and jaxlib 0.4.2 release
PiperOrigin-RevId: 495068000
2022-12-13 10:51:13 -08:00
Yash Katariya
c4d590b1b6 Update values for release 0.4.1
PiperOrigin-RevId: 494889744
2022-12-12 19:04:38 -08:00
Yash Katariya
0bdb7ec042 Finish jax and jaxlib release 0.4.0
PiperOrigin-RevId: 494833878
2022-12-12 14:43:35 -08:00
Yash Katariya
0118f8d568 Prepare for jax and jaxlib 0.4.0 release
PiperOrigin-RevId: 493733609
2022-12-07 16:02:24 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Yash Katariya
25d1a0b4c6 Add cudnn 86 (for cuda 11.8) so that I can release cuda 11.8 nightlies.
PiperOrigin-RevId: 493086060
2022-12-05 12:50:09 -08:00