rocm_jax/docs/index.rst
commit d958f3007d
Author: Peter Hawkins
Date:   2019-12-05 10:57:23 -05:00

Change JAX type promotion to prefer inexact types. (#1815)

Change the JAX type promotion table to prefer inexact types during type promotion.

NumPy's type promotion rules tend to promote aggressively to float64, which is not accelerator-friendly: not all accelerators (e.g., TPUs) support 64-bit floating-point types at all, and even on accelerators that do (e.g., GPUs), promotion to a 64-bit type comes with a significant performance cost.

This change makes JAX type promotion between inexact and exact types closer to PyTorch's promotion semantics, which are a better fit for modern accelerators. For example:

```python
import numpy as onp
from jax import numpy as np

In [1]: onp.promote_types(onp.float32, onp.int32)
Out[1]: dtype('float64')

In [2]: onp.promote_types(onp.float16, onp.int64)
Out[2]: dtype('float64')

In [3]: np.promote_types(onp.float32, onp.int32)
Out[3]: dtype('float32')

In [4]: np.promote_types(onp.float16, onp.int64)
Out[4]: dtype('float16')
```
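
The same preference carries over to ordinary arithmetic. A minimal sketch of what the new table implies (an illustrative snippet, not taken from the change itself):

```python
import numpy as onp
from jax import numpy as np

x = np.ones(3, dtype=onp.float16)
y = np.ones(3, dtype=onp.int64)  # silently held as int32 unless x64 mode is enabled

# Under the new table the inexact operand wins: float16 + integer -> float16,
# where NumPy's rules would have widened the result to float64.
print((x + y).dtype)  # float16
```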

This change is in preparation for enabling x64 mode by default on all platforms.
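
Until that happens, 64-bit values remain opt-in. A minimal sketch of enabling them today via the existing `jax_enable_x64` flag:

```python
# Opt in to 64-bit types before doing any JAX work (x64 mode is off by default).
from jax.config import config
config.update("jax_enable_x64", True)

from jax import numpy as np
print(np.ones(3).dtype)  # float64 with x64 mode enabled; float32 otherwise
```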


JAX reference documentation
===========================

Composable transformations of Python+NumPy programs: differentiate, vectorize,
JIT to GPU/TPU, and more.

For an introduction to JAX, start at the
`JAX GitHub page <https://github.com/google/jax>`_.
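
As a quick taste of those transformations, here is a minimal sketch (the
quickstart notebook below is the full introduction):

.. code-block:: python

   import jax
   from jax import numpy as np

   def f(x):
       return np.sin(x) ** 2

   df = jax.grad(f)         # differentiate
   batched = jax.vmap(df)   # vectorize over a leading batch axis
   fast = jax.jit(batched)  # JIT-compile for CPU/GPU/TPU
   print(fast(np.arange(3.0)))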
.. toctree::
   :maxdepth: 1
   :caption: Tutorials

   notebooks/quickstart
   notebooks/autodiff_cookbook
   Training a Simple Neural Network, with PyTorch Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb>

.. toctree::
   :maxdepth: 1
   :caption: Advanced JAX Tutorials

   notebooks/Common_Gotchas_in_JAX
   notebooks/XLA_in_Python
   notebooks/JAX_pytrees
   notebooks/How_JAX_primitives_work
   notebooks/Writing_custom_interpreters_in_Jax
   Training a Simple Neural Network, with Tensorflow Datasets Data Loading <https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb>
   notebooks/maml
   notebooks/score_matching
   notebooks/vmapped_log_probs

.. toctree::
   :maxdepth: 1
   :caption: Notes

   async_dispatch
   concurrency
   gpu_memory_allocation
   profiling
   rank_promotion_warning
   type_promotion

.. toctree::
   :maxdepth: 2
   :caption: Developer documentation

   developer

.. toctree::
   :maxdepth: 3
   :caption: API documentation

   jax

Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`