mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
update readme, link gotchas notebook
This commit is contained in:
parent
9711202c18
commit
a848d0c0fe
26
README.md
26
README.md
@ -33,7 +33,8 @@ are instances of such transformations. Another is [`vmap`](#auto-vectorization-w
|
||||
for automatic vectorization, with more to come.
|
||||
|
||||
This is a research project, not an official Google product. Expect bugs and
|
||||
sharp edges. Please help by trying it out, [reporting
|
||||
[sharp edges](https://github.com/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb).
|
||||
Please help by trying it out, [reporting
|
||||
bugs](https://github.com/google/jax/issues), and letting us know what you
|
||||
think!
|
||||
|
||||
@ -78,13 +79,16 @@ open](https://github.com/google/jax) by a growing number of
|
||||
* [Current gotchas](#current-gotchas)
|
||||
|
||||
## Quickstart: Colab in the Cloud
|
||||
Jump right in using a notebook in your browser, connected to a Google Cloud GPU:
|
||||
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
|
||||
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://colab.research.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb)
|
||||
- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/notebooks/neural_network_and_data_loading.ipynb)
|
||||
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/notebooks/neural_network_with_tfds_data.ipynb)
|
||||
- [The Autodiff Cookbook: easy and powerful automatic differentiation in JAX](https://colab.research.google.com/github/google/jax/blob/master/notebooks/autodiff_cookbook.ipynb)
|
||||
- [MAML Tutorial with JAX](https://colab.research.google.com/github/google/jax/blob/autodiff-cookbook/notebooks/maml.ipynb).
|
||||
|
||||
And for a deeper dive into JAX:
|
||||
- [Common gotchas and sharp edges](https://github.com/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb)
|
||||
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://colab.research.google.com/github/google/jax/blob/master/notebooks/autodiff_cookbook.ipynb)
|
||||
- [Directly using XLA in Python](https://colab.research.google.com/github/google/jax/blob/autodiff-cookbook/notebooks/XLA_in_Python.ipynb)
|
||||
- [MAML Tutorial with JAX](https://colab.research.google.com/github/google/jax/blob/autodiff-cookbook/notebooks/maml.ipynb).
|
||||
|
||||
## Installation
|
||||
JAX is written in pure Python, but it depends on XLA, which needs to be
|
||||
@ -708,15 +712,17 @@ code to compile and end-to-end optimize much bigger functions.
|
||||
|
||||
## Current gotchas
|
||||
|
||||
Some things we don't handle that might surprise NumPy users:
|
||||
1. No in-place mutation syntax. JAX requires functional code. You can use
|
||||
lax.dynamic\_update\_slice for slice updating that, under `@jit`, will be
|
||||
optimized to in-place buffer updating.
|
||||
2. PRNGs can be awkward, and linearity is not checked with a warning.
|
||||
For a survey of current gotchas, with examples and explanations, we highly recommend reading the [Gotchas Notebook](https://github.com/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb).
|
||||
|
||||
Two stand-out gotchas that might surprise NumPy users:
|
||||
1. In-place mutation of arrays isn't supported. Generally JAX requires functional code.
|
||||
2. PRNGs can be awkward, and non-reuse (linearity) is not checked.
|
||||
|
||||
See [the notebook](https://github.com/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb) for much more information.
|
||||
|
||||
## Contributors
|
||||
|
||||
So far, JAX includes lots of help and contributions from
|
||||
So far, JAX includes lots of help and [contributions](https://github.com/google/jax/graphs/contributors). In addition to the code contributions reflected on GitHub, JAX has benefitted substantially from the advice of
|
||||
[Jamie Townsend](https://github.com/j-towns),
|
||||
[Peter Hawkins](https://github.com/hawkinsp),
|
||||
[Jonathan Ragan-Kelley](https://people.eecs.berkeley.edu/~jrk/),
|
||||
|
Loading…
x
Reference in New Issue
Block a user