JAX reference documentation
===========================

Composable transformations of Python+NumPy programs: differentiate, vectorize,
JIT to GPU/TPU, and more.

For an introduction to JAX, start at the
`JAX GitHub page <https://github.com/google/jax>`_.
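
The transformations named above are ordinary higher-order Python functions
and compose freely. A minimal sketch using only the public ``jax.grad``,
``jax.jit``, and ``jax.vmap`` APIs:

.. code-block:: python

    # Minimal sketch: composing JAX's core transformations.
    import jax
    import jax.numpy as jnp

    def loss(x):
        return jnp.sum(x ** 2)  # simple scalar-valued function

    grad_loss = jax.grad(loss)           # differentiate
    fast_grad = jax.jit(jax.grad(loss))  # differentiate, then JIT-compile
    per_example = jax.vmap(jax.grad(lambda x: x ** 2))  # vectorize over a batch

    x = jnp.arange(3.0)
    print(grad_loss(x))  # gradient of sum(x ** 2) is 2 * x -> [0. 2. 4.]
    print(fast_grad(x))  # same values, compiled via XLA

Because each transformation returns a plain function, nestings such as
``jax.jit(jax.vmap(...))`` work the same way; the quickstart notebook below
walks through each of these in detail.
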

.. toctree::
   :maxdepth: 1
   :caption: Tutorials

   notebooks/quickstart
   notebooks/thinking_in_jax
   notebooks/autodiff_cookbook
   notebooks/vmapped_log_probs
   notebooks/neural_network_with_tfds_data

.. toctree::
   :maxdepth: 1
   :caption: Advanced JAX Tutorials

   notebooks/Common_Gotchas_in_JAX
   notebooks/convolutions
   notebooks/Custom_derivative_rules_for_Python_code
   notebooks/How_JAX_primitives_work
   notebooks/Writing_custom_interpreters_in_Jax
   notebooks/Neural_Network_and_Data_Loading
   notebooks/XLA_in_Python
   notebooks/maml
   notebooks/score_matching

.. toctree::
   :maxdepth: 1
   :caption: Notes

   CHANGELOG
   faq
   jaxpr
   async_dispatch
   concurrency
   gpu_memory_allocation
   profiling
   device_memory_profiling
   pytrees
   rank_promotion_warning
   type_promotion
   custom_vjp_update

.. toctree::
   :maxdepth: 2
   :caption: Developer documentation

   developer
   jax_internal_api

.. toctree::
   :maxdepth: 3
   :caption: API documentation

   jax


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`