Update installation instructions.

Prefer installing via a wheel, because it's simpler to explain.
Describe exactly which packages are needed to build on Ubuntu 18.04.
This commit is contained in:
Peter Hawkins 2019-08-12 10:40:28 -04:00
parent 76a93eaa09
commit b6f0ec7a22

106
README.md
View File

@ -96,57 +96,15 @@ And for a deeper dive into JAX:
## Installation
JAX is written in pure Python, but it depends on XLA, which needs to be compiled
and installed as the `jaxlib` package. Use the following instructions to build
JAX from source or install a binary package with pip.
and installed as the `jaxlib` package. Use the following instructions to
install a binary package with `pip`, or to build JAX from source.
We support installing or building `jaxlib` on Linux and macOS platforms, but not
Windows. We're not currently working on Windows support, but contributions are
welcome (see [#438](https://github.com/google/jax/issues/438)).
### Building JAX from source
First, obtain the JAX source code, and make sure `scipy` is installed.
```bash
git clone https://github.com/google/jax
cd jax
pip install scipy
```
If you are building on a Mac, make sure XCode and the XCode command line tools
are installed.
To build XLA with CUDA support, you can run
```bash
python build/build.py --enable_cuda
pip install -e build # install jaxlib (includes XLA)
pip install -e . # install jax (pure Python)
```
See `python build/build.py --help` for configuration options, including ways to
specify the paths to CUDA and CUDNN, which you must have installed. The build
also depends on NumPy, and a compiler toolchain corresponding to that of
Ubuntu 16.04 or newer.
To build XLA without CUDA GPU support (CPU only), drop the `--enable_cuda`:
```bash
python build/build.py
pip install -e build # install jaxlib (includes XLA)
pip install -e . # install jax
```
To upgrade to the latest version from GitHub, just run `git pull` from the JAX
repository root, and rebuild by running `build.py` if necessary. You shouldn't have
to reinstall because `pip install -e` sets up symbolic links from site-packages
into the repository.
### pip installation
Installing XLA with prebuilt binaries via `pip` is still experimental,
especially with GPU support. Let us know on [the issue
tracker](https://github.com/google/jax/issues) if you run into any errors.
To install a CPU-only version, which might be useful for doing local
development on a laptop, you can run
@ -160,7 +118,7 @@ cloud VM), you can run
```bash
# install jaxlib
PYTHON_VERSION=cp27 # alternatives: cp27, cp35, cp36, cp37
PYTHON_VERSION=cp37 # alternatives: cp27, cp35, cp36, cp37
CUDA_VERSION=cuda92 # alternatives: cuda90, cuda92, cuda100
PLATFORM=linux_x86_64 # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
@ -180,8 +138,64 @@ grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different pa
```
The Python version must match your Python interpreter. There are prebuilt wheels
for Python 2.7, 3.6, and 3.7; for anything else, you must build from source.
for Python 2.7, 3.5, 3.6, and 3.7; for anything else, you must build from
source.
Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.
### Building JAX from source
First, obtain the JAX source code.
```bash
git clone https://github.com/google/jax
cd jax
```
You must also install some prerequisites:
* a C++ compiler (g++ or clang)
* Numpy
* Scipy
* Cython
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:
```
sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3
```
If you are building on a Mac, make sure XCode and the XCode command line tools
are installed.
You can also install the necessary Python dependencies using `pip`:
```
pip install numpy scipy cython
```
To build `jaxlib` with CUDA support, you can run
```bash
python build/build.py --enable_cuda
pip install -e build # installs jaxlib (includes XLA)
pip install -e . # installs jax (pure Python)
```
See `python build/build.py --help` for configuration options, including ways to
specify the paths to CUDA and CUDNN, which you must have installed. The build
also depends on NumPy, and a compiler toolchain corresponding to that of
Ubuntu 16.04 or newer.
To build `jaxlib` without CUDA GPU support (CPU only), drop the `--enable_cuda`:
```bash
python build/build.py
pip install -e build # installs jaxlib (includes XLA)
pip install -e . # installs jax
```
To upgrade to the latest version from GitHub, just run `git pull` from the JAX
repository root, and rebuild by running `build.py` if necessary. You shouldn't have
to reinstall because `pip install -e` sets up symbolic links from site-packages
into the repository.
## Running the tests