mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
mention cloud tpus in readme
This commit is contained in:
parent
764f007f9a
commit
5c800367d1
14
README.md
14
README.md
@ -29,7 +29,9 @@ executed. But JAX also lets you just-in-time compile your own Python functions
|
||||
into XLA-optimized kernels using a one-function API,
|
||||
[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
|
||||
composed arbitrarily, so you can express sophisticated algorithms and get
|
||||
maximal performance without leaving Python.
|
||||
maximal performance without leaving Python. You can even program multiple GPUs
|
||||
or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
|
||||
differentiate through the whole thing.
|
||||
|
||||
Dig a little deeper, and you'll see that JAX is really an extensible system for
|
||||
[composable function transformations](#transformations). Both
|
||||
@ -37,7 +39,7 @@ Dig a little deeper, and you'll see that JAX is really an extensible system for
|
||||
are instances of such transformations. Others are
|
||||
[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
|
||||
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
|
||||
parallel programming, with more to come.
|
||||
parallel programming of multiple accelerators, with more to come.
|
||||
|
||||
This is a research project, not an official Google product. Expect bugs and
|
||||
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
|
||||
@ -72,11 +74,15 @@ perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grad
|
||||
* [Reference documentation](#reference-documentation)
|
||||
|
||||
## Quickstart: Colab in the Cloud
|
||||
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
|
||||
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
|
||||
Here are some starter notebooks:
|
||||
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
|
||||
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)
|
||||
|
||||
And for a deeper dive into JAX:
|
||||
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
|
||||
Colabs](https://github.com/google/jax/tree/master/cloud_tpu_colabs).
|
||||
|
||||
For a deeper dive into JAX:
|
||||
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
|
||||
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
|
||||
- See the [full list of
|
||||
|
@ -23,7 +23,7 @@ Solve the wave equation with `pmap`, and make cool movies! The spatial domain is
|
||||

|
||||
|
||||
### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/master/cloud_tpu_colabs/NeurIPS_2019_JAX_demo.ipynb)
|
||||
An overview of JAX presented at the Program Transformations for ML workshop at NeurIPS 2019. Covers basic numpy usage, grad, jit, vmap, and pmap.
|
||||
An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`.
|
||||
|
||||
## Performance notes
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user