106 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
26fbeb6e2a Update WORKSPACE and libtpu version for jaxlib 0.3.15, take 3 2022-07-22 11:41:39 -07:00
Skye Wanderman-Milne
186a4f83e3
Update libtpu version for 0.3.15 release 2022-07-19 17:11:43 -07:00
Skye Wanderman-Milne
9149c38e1e Update WORKSPACE and setup.py in preparation for 0.3.15 jax/jaxlib release 2022-07-14 10:12:59 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Sharad Vikram
1daea700f2 Bump JAX/Jaxlib versions 2022-06-28 14:36:47 -07:00
Peter Hawkins
1e29b7b762 Update CHANGELOG.md and setup.py for 0.3.14 release. 2022-06-27 09:38:41 -04:00
Peter Hawkins
a560a29e12 Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00
Sharad Vikram
9bd1bd67e0 Update versions for jax/jaxlib release 2022-06-21 12:57:28 -07:00
Yash Katariya
e0ff842c2a Use epath from etils package. This CL also makes epath a required dep for JAX.
This is being used in the following ways in this CL:

* To dump IR, you can now pass paths with `gs://` or `cns` and the HLO can be dumped to those paths.
* Removing the TF dep from gda serialization.

PiperOrigin-RevId: 452117007
2022-05-31 12:47:35 -07:00
Jeppe Klitgaard
a11f15e3ec feat: officially support Python 3.10 2022-05-07 13:43:12 +01:00
Yash Katariya
d2e3d4278d Updates values after jax and jaxlib 0.3.10 release
PiperOrigin-RevId: 446623299
2022-05-04 21:17:37 -07:00
Yash Katariya
38ce6d027b Update TF commit for release
PiperOrigin-RevId: 446555288
2022-05-04 14:42:50 -07:00
Yash Katariya
ff1a3c40ba jax and jaxlib release
PiperOrigin-RevId: 446295827
2022-05-03 14:52:40 -07:00
Peter Hawkins
38ea5a6bc0 Copybara import of the project:
--
391dea76bc8fe264cf26ec93d42147f87847894d by Peter Hawkins <phawkins@google.com>:

Update version numbers after jax/jaxlib 0.3.7 release.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/10324 from hawkinsp:jaxlib 391dea76bc8fe264cf26ec93d42147f87847894d
PiperOrigin-RevId: 442311051
2022-04-16 22:37:09 -07:00
Peter Hawkins
52a97f2e06 Jax 0.3.7 and jaxlib 0.3.7 release. 2022-04-15 12:02:05 -04:00
Yash Katariya
5fd78eaf02
Bump the libtpu version to prepare for JAX release 2022-04-12 11:41:07 -07:00
Peter Hawkins
4dc69034b0 Update version numbers after jax/jaxlib release. 2022-04-07 16:40:19 -04:00
Peter Hawkins
7f751c5523 Update libtpu version for jax 0.3.5 release. 2022-04-07 16:14:13 -04:00
Peter Hawkins
96ba290faf Jax 0.3.5 and jaxlib 0.3.5 release. 2022-04-06 23:56:41 +00:00
Skye Wanderman-Milne
d7087abce6 Bump jax and jaxlib versions for 0.3.2 release
Also add CPU pjit to changelog
2022-03-16 14:31:00 -07:00
Skye Wanderman-Milne
f9775a2ced Update CHANGELOG and setup.py for jax + jaxlib 0.3.2 releases 2022-03-16 10:17:42 -07:00
Yash Katariya
2162868ed9 Update values after release
PiperOrigin-RevId: 427910510
2022-02-10 20:32:53 -08:00
Yash Katariya
1ad3551ec9 Release jax and jaxlib 0.3.0 as per the new release process.
PiperOrigin-RevId: 427809845
2022-02-10 11:59:13 -08:00
Yash Katariya
d82bcc2a0c Add jax[_ci] option to account for the new release process.
PiperOrigin-RevId: 427802081
2022-02-10 11:29:51 -08:00
Skye Wanderman-Milne
d096d9a758 Update pinned libtpu-nightly version for jaxlib 0.1.76 2022-01-28 17:29:04 -08:00
Peter Hawkins
6791446bb1 Update development jaxlib version to 0.1.77, update jaxlib version in setup.py to 0.1.76.
Changelog entry for jaxlib 0.1.77 was already added in a previous PR.

PiperOrigin-RevId: 424872047
2022-01-28 08:10:58 -08:00
Peter Hawkins
04369a3588 Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00
Yash Katariya
1b5630eed6 Update jaxlib version number to 0.1.76
PiperOrigin-RevId: 415050863
2021-12-08 11:14:12 -08:00
Peter Hawkins
7902ddaca2 Update jaxlib versions. 2021-11-17 11:46:41 -05:00
Yash Katariya
ee752b32f7 Use cuda11_cudnn82 instead of cuda=11,cudnn=82 because the latter one is a syntax error
PiperOrigin-RevId: 404240654
2021-10-19 06:24:53 -07:00
Yash Katariya
4d8bce1b85 Add a default cuda installation path and more explicit installation paths for CUDA jaxlib.
```
# Installs Cuda 11 with Cudnn 8.2
$ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

PiperOrigin-RevId: 404134291
2021-10-18 19:56:22 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Yash Katariya
66a4a9ff3f Remove 10.2 cuda support
PiperOrigin-RevId: 402707900
2021-10-12 18:44:07 -07:00
Skye Wanderman-Milne
0072c32546 Update CHANGELOG and verson numbers for jaxlib 0.1.72 release 2021-10-12 17:37:29 -07:00
yashkatariya
be824a792e Update files after new jaxlib release 0.1.71 2021-09-01 10:43:20 -07:00
Jake VanderPlas
062f8d2261 Specify scipy in setup.py install_requires 2021-08-20 10:16:17 -07:00
Yash Katariya
bf967d88d8 Upgrade versions after jaxlib release
PiperOrigin-RevId: 389753047
2021-08-09 16:37:44 -07:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Skye Wanderman-Milne
b2fd6a772b Changes to make jax[tpu] work better in a docker container.
1. In cloud_tpu_init.py, check whether we're on a Cloud TPU VM by
   looking for the libtpu Python package, instead of /lib/libtpu.so
   (which isn't necessarily present in a docker container). JAX now
   relies on the libtpu package instead of the system libtpu.so, so
   this makes more sense either way. This means we'll try/catch an
   ImportError in all non-TPU environments when importing jax, which
   hopefully isn't noticeably slow.

2. Add requests as a jax[tpu] dependency, since it's needed by
   cloud_tpu_init.py. This comes pre-installed on Cloud TPU VMs, but
   may not be installed in docker containers, virtualenvs, etc.

I manually tested by creating the following Dockerfile on a Cloud TPU VM:
```
FROM ubuntu:18.04
RUN apt update && apt install git python3-pip -y
RUN git clone https://github.com/skye/jax && cd jax && git checkout tpu_docker
WORKDIR jax
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
CMD ["python3", "-c", "import jax; print(jax.device_count())"]
```

And then running the following commands:
```
$ sudo docker build -t jax-test .
$ sudo docker run --privileged jax-test
8
```

Note the `--privileged` flags is necessary to let the container access
the TPU devices in /dev.
2021-07-12 17:42:46 -07:00
Qiao Zhang
82e74959fe Update changelog for jaxlib-0.1.69. 2021-07-12 12:06:41 -07:00
Jake VanderPlas
5badea53ae setup.py: add Python 3.9 classifier 2021-07-09 12:42:39 -07:00
Jake VanderPlas
6be875f1f3 setup.py: change jax[cpu] to target the current jaxlib version 2021-07-02 09:49:49 -07:00
Skye Wanderman-Milne
55276d15e4 Fix pip install jax[tpu]
* Updates jax_releases.html index to include libtpu wheels
* Change [tpu] extras to specify `libtpu-nightly` instead of wheel URL

The full install command will now be:
`pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html`
(similar to the cuda install commands)

I've already pushed an updated jax_releases.html to the jax-releases GCS bucket.
2021-06-23 14:13:15 -07:00
Skye Wanderman-Milne
2460f91561 Revert "Improve support for pip install jax[cuda111]"
This reverts commit dfcaf0feb81db0d2594b9a78154e8679b0366533.
2021-06-23 14:08:51 -07:00
Qiao Zhang
a22841b6bb Bump jaxlib ver to 0.1.68. 2021-06-23 12:37:56 -07:00
Jake VanderPlas
dfcaf0feb8 Improve support for pip install jax[cuda111] 2021-06-23 11:42:04 -07:00
Skye Wanderman-Milne
a12a229546 Add pip install jax[tpu] configuration.
This can be used on Cloud TPU VMs to automatically install compatible
versions of jax, jaxlib, and libtpu (the low-level library JAX uses to
access the TPU on Cloud TPU VMs).

The new install command requires a new jax release (`>=0.2.15`) and
jaxlib release (`>=0.1.68`) to work, since it requires both
cdfbd9dde1
and
ce2bc24996
to pick up the pip-installed libtpu. I'll update the README and Cloud
TPU VM documentation once these releases are out.
2021-06-23 02:26:27 +00:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
mariosasko
55b421ff36 Specify zip_safe for mypy 2021-06-10 16:06:11 +02:00