mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Update installation instructions.
Prefer installing via a wheel, because it's simpler to explain. Describe exactly which packages are needed to build on Ubuntu 18.04.
This commit is contained in:
parent
76a93eaa09
commit
b6f0ec7a22
106
README.md
106
README.md
@ -96,57 +96,15 @@ And for a deeper dive into JAX:
|
||||
|
||||
## Installation
|
||||
JAX is written in pure Python, but it depends on XLA, which needs to be compiled
|
||||
and installed as the `jaxlib` package. Use the following instructions to build
|
||||
JAX from source or install a binary package with pip.
|
||||
and installed as the `jaxlib` package. Use the following instructions to
|
||||
install a binary package with `pip`, or to build JAX from source.
|
||||
|
||||
We support installing or building `jaxlib` on Linux and macOS platforms, but not
|
||||
Windows. We're not currently working on Windows support, but contributions are
|
||||
welcome (see [#438](https://github.com/google/jax/issues/438)).
|
||||
|
||||
### Building JAX from source
|
||||
First, obtain the JAX source code, and make sure `scipy` is installed.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/google/jax
|
||||
cd jax
|
||||
pip install scipy
|
||||
```
|
||||
|
||||
If you are building on a Mac, make sure XCode and the XCode command line tools
|
||||
are installed.
|
||||
|
||||
To build XLA with CUDA support, you can run
|
||||
|
||||
```bash
|
||||
python build/build.py --enable_cuda
|
||||
pip install -e build # install jaxlib (includes XLA)
|
||||
pip install -e . # install jax (pure Python)
|
||||
```
|
||||
|
||||
See `python build/build.py --help` for configuration options, including ways to
|
||||
specify the paths to CUDA and CUDNN, which you must have installed. The build
|
||||
also depends on NumPy, and a compiler toolchain corresponding to that of
|
||||
Ubuntu 16.04 or newer.
|
||||
|
||||
To build XLA without CUDA GPU support (CPU only), drop the `--enable_cuda`:
|
||||
|
||||
```bash
|
||||
python build/build.py
|
||||
pip install -e build # install jaxlib (includes XLA)
|
||||
pip install -e . # install jax
|
||||
```
|
||||
|
||||
To upgrade to the latest version from GitHub, just run `git pull` from the JAX
|
||||
repository root, and rebuild by running `build.py` if necessary. You shouldn't have
|
||||
to reinstall because `pip install -e` sets up symbolic links from site-packages
|
||||
into the repository.
|
||||
|
||||
### pip installation
|
||||
|
||||
Installing XLA with prebuilt binaries via `pip` is still experimental,
|
||||
especially with GPU support. Let us know on [the issue
|
||||
tracker](https://github.com/google/jax/issues) if you run into any errors.
|
||||
|
||||
To install a CPU-only version, which might be useful for doing local
|
||||
development on a laptop, you can run
|
||||
|
||||
@ -160,7 +118,7 @@ cloud VM), you can run
|
||||
|
||||
```bash
|
||||
# install jaxlib
|
||||
PYTHON_VERSION=cp27 # alternatives: cp27, cp35, cp36, cp37
|
||||
PYTHON_VERSION=cp37 # alternatives: cp27, cp35, cp36, cp37
|
||||
CUDA_VERSION=cuda92 # alternatives: cuda90, cuda92, cuda100
|
||||
PLATFORM=linux_x86_64 # alternatives: linux_x86_64
|
||||
BASE_URL='https://storage.googleapis.com/jax-releases'
|
||||
@ -180,8 +138,64 @@ grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different pa
|
||||
```
|
||||
|
||||
The Python version must match your Python interpreter. There are prebuilt wheels
|
||||
for Python 2.7, 3.6, and 3.7; for anything else, you must build from source.
|
||||
for Python 2.7, 3.5, 3.6, and 3.7; for anything else, you must build from
|
||||
source.
|
||||
|
||||
Please let us know on [the issue tracker](https://github.com/google/jax/issues)
|
||||
if you run into any errors or problems with the prebuilt wheels.
|
||||
|
||||
### Building JAX from source
|
||||
|
||||
First, obtain the JAX source code.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/google/jax
|
||||
cd jax
|
||||
```
|
||||
|
||||
You must also install some prerequisites:
|
||||
* a C++ compiler (g++ or clang)
|
||||
* Numpy
|
||||
* Scipy
|
||||
* Cython
|
||||
|
||||
On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:
|
||||
```
|
||||
sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3
|
||||
```
|
||||
If you are building on a Mac, make sure XCode and the XCode command line tools
|
||||
are installed.
|
||||
|
||||
You can also install the necessary Python dependencies using `pip`:
|
||||
```
|
||||
pip install numpy scipy cython
|
||||
```
|
||||
|
||||
To build `jaxlib` with CUDA support, you can run
|
||||
|
||||
```bash
|
||||
python build/build.py --enable_cuda
|
||||
pip install -e build # installs jaxlib (includes XLA)
|
||||
pip install -e . # installs jax (pure Python)
|
||||
```
|
||||
|
||||
See `python build/build.py --help` for configuration options, including ways to
|
||||
specify the paths to CUDA and CUDNN, which you must have installed. The build
|
||||
also depends on NumPy, and a compiler toolchain corresponding to that of
|
||||
Ubuntu 16.04 or newer.
|
||||
|
||||
To build `jaxlib` without CUDA GPU support (CPU only), drop the `--enable_cuda`:
|
||||
|
||||
```bash
|
||||
python build/build.py
|
||||
pip install -e build # installs jaxlib (includes XLA)
|
||||
pip install -e . # installs jax
|
||||
```
|
||||
|
||||
To upgrade to the latest version from GitHub, just run `git pull` from the JAX
|
||||
repository root, and rebuild by running `build.py` if necessary. You shouldn't have
|
||||
to reinstall because `pip install -e` sets up symbolic links from site-packages
|
||||
into the repository.
|
||||
|
||||
## Running the tests
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user