update neural network libraries text

Co-authored-by: George Necula <necula@google.com>
This commit is contained in:
Matthew Johnson 2020-10-30 08:42:04 -07:00
parent a4ce2813c8
commit a7d1963bc8

View File

@ -447,9 +447,21 @@ source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source
## Neural network libraries
There are lots of great deep learning libraries built on top of JAX!
Multiple Google research groups develop and share libraries for training neural
networks in JAX. If you want a fully featured library for neural network
training with examples and how-to guides, try
[Flax](https://github.com/google/flax). Another option is
[Trax](https://github.com/google/trax), a combinator-based framework focused on
ease-of-use and end-to-end single-command examples, especially for sequence
models and reinforcement learning. Finally,
[Objax](https://github.com/google/objax) is a minimalist object-oriented
framework with a PyTorch-like interface.
If you want a batteries-included library for neural network training, try [Flax](https://github.com/google/flax) by a Google Brain team in Amsterdam. Another good option is DeepMind's [Haiku](https://github.com/deepmind/dm-haiku), a JAX version of Sonnet focused solely on neural network layers. Finally, [Trax](https://github.com/google/trax) by a Brain team in California is a configuration-driven framework focused on sequence model research and reinforcement learning as a successor to Tensor2Tensor.
DeepMind has open-sourced an ecosystem of libraries around JAX including
[Haiku](https://github.com/deepmind/dm-haiku) for neural network modules,
[Optax](https://github.com/deepmind/optax) for gradient processing and
optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and
[chex](https://github.com/deepmind/chex) for reliable code and testing.
## Citing JAX
@ -460,7 +472,7 @@ To cite this repository:
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and Skye Wanderman-Milne},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.1.55},
version = {0.2.5},
year = {2018},
}
```