mention cloud tpus in readme

Matthew Johnson 2019-12-14 08:34:01 -08:00 committed by Matthew Johnson
parent 764f007f9a
commit 5c800367d1
2 changed files with 11 additions and 5 deletions


@@ -29,7 +29,9 @@ executed. But JAX also lets you just-in-time compile your own Python functions
into XLA-optimized kernels using a one-function API,
[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
composed arbitrarily, so you can express sophisticated algorithms and get
-maximal performance without leaving Python.
+maximal performance without leaving Python. You can even program multiple GPUs
+or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
+differentiate through the whole thing.
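(As a rough illustration of that composability, not part of this diff: the sketch below composes `grad` and `jit` on a made-up `loss` function.)

```python
import jax.numpy as jnp
from jax import grad, jit

# A hypothetical squared-error loss for a linear model (illustrative only).
def loss(w, x, y):
    return jnp.sum((jnp.dot(x, w) - y) ** 2)

# Compose transformations: differentiate, then compile the gradient with XLA.
fast_grad = jit(grad(loss))

w = jnp.zeros(3)
x = jnp.ones((5, 3))
y = jnp.ones(5)
print(fast_grad(w, x, y))  # gradient with respect to w, from a compiled kernel
```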
Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations). Both
@@ -37,7 +39,7 @@ Dig a little deeper, and you'll see that JAX is really an extensible system for
are instances of such transformations. Others are
[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
-parallel programming, with more to come.
+parallel programming of multiple accelerators, with more to come.
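(Illustrative aside, not part of this commit: a minimal `vmap` example with a made-up `predict` function.)

```python
import jax.numpy as jnp
from jax import vmap

# A hypothetical single-example model.
def predict(w, x):
    return jnp.tanh(jnp.dot(w, x))

w = jnp.ones((4, 3))
xs = jnp.ones((10, 3))  # a batch of 10 examples

# Vectorize over the batch axis of xs while reusing the same w for every example.
batched_predict = vmap(predict, in_axes=(None, 0))
print(batched_predict(w, xs).shape)  # (10, 4)
```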
This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
@@ -72,11 +74,15 @@ perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grads
* [Reference documentation](#reference-documentation)
## Quickstart: Colab in the Cloud
-Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
+Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
+Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)
-And for a deeper dive into JAX:
+**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
+Colabs](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
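(Aside, not part of this change: once a Cloud TPU or other accelerator runtime is attached, you can check which devices JAX sees.)

```python
import jax

# Lists TPU cores on a Cloud TPU runtime, otherwise whatever backend is available.
print(jax.devices())
print(jax.device_count())
```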
+For a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of


@@ -23,7 +23,7 @@ Solve the wave equation with `pmap`, and make cool movies! The spatial domain is
![](https://raw.githubusercontent.com/google/jax/master/cloud_tpu_colabs/images/wave_movie.gif)
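(Loose sketch of the SPMD style that colab uses, not its actual code; it assumes at least one local device and a made-up `normalize` function.)

```python
import jax
import jax.numpy as jnp
from jax import pmap

n = jax.local_device_count()  # e.g. 8 cores on a Cloud TPU

# Each device gets one row; axis_name lets the devices communicate collectively.
def normalize(row):
    total = jax.lax.psum(jnp.sum(row), axis_name='devices')
    return row / total

rows = jnp.arange(float(n * 4)).reshape((n, 4))
print(pmap(normalize, axis_name='devices')(rows))  # rows scaled by the global sum
```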
### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/NeurIPS_2019_JAX_demo.ipynb)
-An overview of JAX presented at the Program Transformations for ML workshop at NeurIPS 2019. Covers basic numpy usage, grad, jit, vmap, and pmap.
+An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`.
## Performance notes