77 Commits

Author SHA1 Message Date
Yash Katariya
ee752b32f7 Use cuda11_cudnn82 instead of cuda=11,cudnn=82 because the latter one is a syntax error
PiperOrigin-RevId: 404240654
2021-10-19 06:24:53 -07:00
Yash Katariya
4d8bce1b85 Add a default cuda installation path and more explicit installation paths for CUDA jaxlib.
```
# Installs Cuda 11 with Cudnn 8.2
$ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

PiperOrigin-RevId: 404134291
2021-10-18 19:56:22 -07:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Yash Katariya
66a4a9ff3f Remove 10.2 cuda support
PiperOrigin-RevId: 402707900
2021-10-12 18:44:07 -07:00
Skye Wanderman-Milne
0072c32546 Update CHANGELOG and verson numbers for jaxlib 0.1.72 release 2021-10-12 17:37:29 -07:00
yashkatariya
be824a792e Update files after new jaxlib release 0.1.71 2021-09-01 10:43:20 -07:00
Jake VanderPlas
062f8d2261 Specify scipy in setup.py install_requires 2021-08-20 10:16:17 -07:00
Yash Katariya
bf967d88d8 Upgrade versions after jaxlib release
PiperOrigin-RevId: 389753047
2021-08-09 16:37:44 -07:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Skye Wanderman-Milne
b2fd6a772b Changes to make jax[tpu] work better in a docker container.
1. In cloud_tpu_init.py, check whether we're on a Cloud TPU VM by
   looking for the libtpu Python package, instead of /lib/libtpu.so
   (which isn't necessarily present in a docker container). JAX now
   relies on the libtpu package instead of the system libtpu.so, so
   this makes more sense either way. This means we'll try/catch an
   ImportError in all non-TPU environments when importing jax, which
   hopefully isn't noticeably slow.

2. Add requests as a jax[tpu] dependency, since it's needed by
   cloud_tpu_init.py. This comes pre-installed on Cloud TPU VMs, but
   may not be installed in docker containers, virtualenvs, etc.

I manually tested by creating the following Dockerfile on a Cloud TPU VM:
```
FROM ubuntu:18.04
RUN apt update && apt install git python3-pip -y
RUN git clone https://github.com/skye/jax && cd jax && git checkout tpu_docker
WORKDIR jax
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install .[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
CMD ["python3", "-c", "import jax; print(jax.device_count())"]
```

And then running the following commands:
```
$ sudo docker build -t jax-test .
$ sudo docker run --privileged jax-test
8
```

Note the `--privileged` flags is necessary to let the container access
the TPU devices in /dev.
2021-07-12 17:42:46 -07:00
Qiao Zhang
82e74959fe Update changelog for jaxlib-0.1.69. 2021-07-12 12:06:41 -07:00
Jake VanderPlas
5badea53ae setup.py: add Python 3.9 classifier 2021-07-09 12:42:39 -07:00
Jake VanderPlas
6be875f1f3 setup.py: change jax[cpu] to target the current jaxlib version 2021-07-02 09:49:49 -07:00
Skye Wanderman-Milne
55276d15e4 Fix pip install jax[tpu]
* Updates jax_releases.html index to include libtpu wheels
* Change [tpu] extras to specify `libtpu-nightly` instead of wheel URL

The full install command will now be:
`pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_releases.html`
(similar to the cuda install commands)

I've already pushed an updated jax_releases.html to the jax-releases GCS bucket.
2021-06-23 14:13:15 -07:00
Skye Wanderman-Milne
2460f91561 Revert "Improve support for pip install jax[cuda111]"
This reverts commit dfcaf0feb81db0d2594b9a78154e8679b0366533.
2021-06-23 14:08:51 -07:00
Qiao Zhang
a22841b6bb Bump jaxlib ver to 0.1.68. 2021-06-23 12:37:56 -07:00
Jake VanderPlas
dfcaf0feb8 Improve support for pip install jax[cuda111] 2021-06-23 11:42:04 -07:00
Skye Wanderman-Milne
a12a229546 Add pip install jax[tpu] configuration.
This can be used on Cloud TPU VMs to automatically install compatible
versions of jax, jaxlib, and libtpu (the low-level library JAX uses to
access the TPU on Cloud TPU VMs).

The new install command requires a new jax release (`>=0.2.15`) and
jaxlib release (`>=0.1.68`) to work, since it requires both
cdfbd9dde1
and
ce2bc24996
to pick up the pip-installed libtpu. I'll update the README and Cloud
TPU VM documentation once these releases are out.
2021-06-23 02:26:27 +00:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
mariosasko
55b421ff36 Specify zip_safe for mypy 2021-06-10 16:06:11 +02:00
Skye Wanderman-Milne
63dbb99a66 Update README, etc. for jaxlib 0.1.67 release 2021-05-17 17:48:46 -07:00
Qiao Zhang
528d5bbb11 Update README etc for jaxlib 0.1.66 release. 2021-05-11 16:49:32 -07:00
Jake VanderPlas
bdcce8f8f4 setup.py: make cuda extras specifier more consistent 2021-04-26 09:22:52 -07:00
Skye Wanderman-Milne
f8f373466c Update README, etc. for jaxlib 0.1.65 release 2021-04-07 17:51:20 -07:00
Jake VanderPlas
f9a4162551 Specify minimum jaxlib version in a single location 2021-03-22 16:14:41 -07:00
Skye Wanderman-Milne
0cbe2c1c05 Update README, etc. for jaxlib 0.1.64 release 2021-03-18 16:11:40 -07:00
Skye Wanderman-Milne
757247b791 Update README, etc. for jaxlib 0.1.63 release 2021-03-17 10:14:52 -07:00
Jake VanderPlas
d6408a4e6a Add extras_require to setup.py 2021-03-16 13:23:46 -07:00
George Necula
f105517ea2 Fixed mypy type errors for numpy 1.20
Revert also previous changes that pinned numpy to 1.19.

One of the changes in numpy 1.20 is to add more type annotations.
However, this sometimes make mypy give errors. A common example is
numpy.take, which with the new type annotation does not appear to
mypy as indexable.

Another change is that np.int and np.bool are deprecated. One
should use np.bool_ or np.int_, or the built-ins bool and int.
2021-02-05 10:40:47 +02:00
George Necula
a145e3d414 Pin numpy to max version 1.19, to avoid errors with 1.20
Will fix the numpy errors separately.
2021-01-31 15:18:54 +02:00
Peter Hawkins
328ddfca9f Add py.typed to setup.py for PEP 561 compliance. 2021-01-14 15:08:53 -05:00
Jake VanderPlas
5959fa9ccf Run main test suite under Python 3.8 2020-12-30 13:20:13 -08:00
John Aslanides
6029f02be8 Add setup classifiers fields (including supported Python versions) for PyPI.
This is following https://packaging.python.org/tutorials/packaging-projects/.
2020-12-30 11:29:34 +00:00
Srijan Saurav
40e20242db
Fix code quality issues (#4302)
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
2020-09-17 09:21:18 -07:00
Peter Hawkins
f036f5ddb0
Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation war… (#3543)
* Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation warnings.

* Pin a newer tf-nightly to fix jax2tf tests for NumPy 1.19.0
2020-06-24 15:19:00 -04:00
Matthew Johnson
3fff837907
pin numpy version in setup.py to avoid warnings (#3509) 2020-06-22 08:12:41 -07:00
Peter Hawkins
cbdf9a5a43
Drop support for Python 3.5. (#2445) 2020-03-18 10:54:28 -04:00
Peter Hawkins
681ba37f7e
Drop fastcache dependency, which isn't necessary on Python 3. (#1995)
Drop protobuf and six dependencies from travis configuration.
2020-01-14 10:08:23 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Peter Hawkins
4fc765241f
Drop protobuf dependency from jax package. It appears unused. (#1700) 2019-11-15 14:55:26 -05:00
Skye Wanderman-Milne
90093b7824 Remove version restriction from opt_einsum.
See https://github.com/dgasmith/opt_einsum/issues/98.
2019-08-23 14:43:52 -07:00
Skye Wanderman-Milne
921096e32e Require opt_einsum version to be less than 3.0.0.
opt_einsum 3.0.0 adds a jax backend, which raises an exception on import.
2019-08-19 19:28:07 -07:00
Peter Hawkins
8e66d29c45 Suppress flake8 warning from __version__ logic. 2019-08-04 12:12:53 -04:00
Peter Hawkins
08013954a4 Use fastcache for LRU caches in JAX.
fastcache is both a faster cache implementation and is also thread-safe.
2019-07-22 17:24:10 -04:00
Matthew Johnson
299977eeef exclude examples dir from setup.py find_packages
fixes #582
2019-04-06 14:16:57 -07:00
Peter Hawkins
f939ac078d Update XLA.
Updates XLA to 00afc7bb81.

The new XLA release removes the use of protocol buffers from the XLA client. Fixes #349.
Add backward compatibility shims to jaxlib to allow older jax releases to still work on an up to date jaxlib.

The new XLA release also incorporates a fix that avoids a host-device copy for every iteration of a `lax.fori_loop()` on GPU. Fixes #402.

Add a new jaxlib.__version__ field, change jax/jaxlib compatibility logic to check for it.
2019-02-26 06:07:44 -08:00
Matthew Johnson
9a9c304644 add version attribute
following idea 3 here:
https://packaging.python.org/guides/single-sourcing-package-version/
2019-02-13 20:04:38 -08:00
Matthew Johnson
9cd28d12a1 bump version for pypi 2019-02-13 14:55:23 -08:00
Matthew Johnson
a75d1c6e08 bump version for pypi 2019-02-05 19:06:29 -08:00