JAX reference documentation
===============================

Composable transformations of Python+NumPy programs: differentiate, vectorize,
JIT to GPU/TPU, and more.

For an introduction to JAX, start at the
`JAX GitHub page <https://github.com/google/jax>`_.
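
A minimal sketch of how these transformations compose, using a made-up toy
function (see the quickstart and autodiff cookbook notebooks below for the
real introduction):

.. code-block:: python

   import jax
   import jax.numpy as np  # NumPy-compatible array API

   def loss(w, x):
       # Toy scalar loss, purely for illustration.
       return np.sum(np.tanh(w * x))

   grad_loss = jax.grad(loss)                        # differentiate w.r.t. w
   batched = jax.vmap(grad_loss, in_axes=(None, 0))  # vectorize over a batch of x
   fast = jax.jit(batched)                           # JIT-compile for CPU/GPU/TPU

   print(fast(0.5, np.arange(4.0)))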

.. toctree::
   :maxdepth: 1
   :caption: Tutorials

   notebooks/quickstart
   notebooks/autodiff_cookbook
   Training a Simple Neural Network, with PyTorch Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb>

.. toctree::
   :maxdepth: 1
   :caption: Advanced JAX Tutorials

   notebooks/Common_Gotchas_in_JAX
   notebooks/XLA_in_Python
   notebooks/JAX_pytrees
   notebooks/How_JAX_primitives_work
   notebooks/Writing_custom_interpreters_in_Jax
   Training a Simple Neural Network, with Tensorflow Datasets Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb>
   notebooks/maml
   notebooks/score_matching
   notebooks/vmapped_log_probs

.. toctree::
   :maxdepth: 1
   :caption: Notes

   CHANGELOG
   jaxpr
   async_dispatch
   concurrency
   gpu_memory_allocation
   profiling
   rank_promotion_warning
   type_promotion

.. toctree::
   :maxdepth: 2
   :caption: Developer documentation

   developer
   jax_internal_api

.. toctree::
   :maxdepth: 3
   :caption: API documentation

   jax

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`