readme: JAX is about composable transformations

Matthew Johnson 2018-12-12 19:01:32 -08:00 committed by GitHub
parent 45444f02ff
commit 3307553dbe


@@ -13,8 +13,8 @@ JAX can automatically differentiate native
Python and NumPy functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of derivatives of
derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
as well as forward-mode differentiation, and the two can be composed arbitrarily
to any order.
via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.
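For example, a minimal sketch (not part of this commit) of taking derivatives of derivatives by composing `grad` with itself:

```python
import jax.numpy as jnp
from jax import grad

def tanh(x):  # an ordinary Python + NumPy-style function
    return jnp.tanh(x)

# Reverse-mode derivatives compose to any order.
print(grad(tanh)(1.0))              # first derivative at x = 1.0
print(grad(grad(tanh))(1.0))        # second derivative
print(grad(grad(grad(tanh)))(1.0))  # third derivative
```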
What's new is that JAX uses
[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md)
@@ -26,6 +26,12 @@ into XLA-optimized kernels using a one-function API,
composed arbitrarily, so you can express sophisticated algorithms and get
maximal performance without leaving Python.
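As an illustrative sketch of the one-function `jit` API (the function and array shapes here are assumptions, not taken from this commit):

```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    # Plain NumPy-style element-wise ops; under jit they are fused
    # and compiled into a single XLA kernel.
    return x * x + x * 2.0

fast_f = jit(slow_f)        # compile slow_f with XLA
x = jnp.ones((5000, 5000))
fast_f(x)                   # compiled on first call, cached afterwards
```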
Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable transformations of functions](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Another is [`vmap`](#auto-vectorization-with-vmap)
for automatic vectorization, with more to come.
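A small sketch of that composability (the `predict` function, weights, and shapes below are hypothetical, for illustration only):

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(w, x):
    # A toy per-example function.
    return jnp.tanh(jnp.dot(w, x))

# vmap maps predict over a batch automatically, and the result can still
# be jit-compiled and differentiated, since the transformations compose.
batched_predict = jit(vmap(predict, in_axes=(None, 0)))

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)
print(batched_predict(w, xs))                            # per-example outputs, shape (4,)
print(grad(lambda w: batched_predict(w, xs).sum())(w))   # gradient w.r.t. w
```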
This is a research project, not an official Google product. Expect bugs and
sharp edges. Please help by trying it out, [reporting
bugs](https://github.com/google/jax/issues), and letting us know what you