Merge pull request #11799 from hawkinsp:jep

PiperOrigin-RevId: 466178951
This commit is contained in:
jax authors 2022-08-08 16:16:18 -07:00
commit 38ab3d88ae
15 changed files with 58 additions and 21 deletions

View File

@ -117,7 +117,7 @@ exclude_patterns = [
# Ignore markdown source for notebooks; myst-nb builds from the ipynb
# These are kept in sync using the jupytext pre-commit hook.
'notebooks/*.md',
'design_notes/type_promotion.md',
'jep/9407-type-promotion.md',
# TODO: revert to jax-101/*.md once 08-pjit has a notebook
'jax-101/01-jax-basics.md',
'jax-101/02-jitting.md',
@ -202,7 +202,7 @@ nb_execution_excludepatterns = [
# Strange error apparently due to asynchronous cell execution
'notebooks/thinking_in_jax.*',
# Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
'design_notes/type_promotion.*',
'jep/9407-type-promotion.*',
# TODO(jakevdp): enable execution on the following if possible:
'jax-101/*',
'notebooks/xmap_tutorial.*',

View File

@ -1,12 +0,0 @@
Design Notes
============
.. toctree::
:maxdepth: 1
custom_derivatives
jax_versioning
omnistaging
prng
type_promotion
sequencing_effects

View File

@ -74,7 +74,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
developer
jax_internal_api
autodidax
design_notes/index
jep/index
.. toctree::
:maxdepth: 3

View File

@ -8,7 +8,7 @@
"source": [
"# Design of Type Promotion Semantics for JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/design_notes/type_promotion.ipynb)\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
"\n",
"*Jake VanderPlas, December 2021*\n",
"\n",

View File

@ -16,7 +16,7 @@ kernelspec:
# Design of Type Promotion Semantics for JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/design_notes/type_promotion.ipynb)
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
*Jake VanderPlas, December 2021*

49
docs/jep/index.rst Normal file
View File

@ -0,0 +1,49 @@
JAX Enhancement Proposals (JEPs)
================================
Most changes can be discussed with simple issues/discussions and pull requests.
Some changes though are a bit larger in scope or require more discussion, and
these should be implemented as JEP. This allows for writing longer documents
that can be discussed in a pull request themselves.
The structure of JEPs is kept as lightweight as possible to start and might
be extended later on.
When you should use a JEP
-------------------------
- When your change requires a design doc. We prefer collecting the designs as
JEPs for better discoverability and further reference.
- When your change requires extensive discussion. It's fine to have relatively
short discussions on issues or pull requests, but when the discussion gets
longer this becomes unpractical for later digestion. JEPs allow to update the
main document with a summary of the discussion and these updates can be
discussed themselves in the pull request adding the JEP.
How to start a JEP
------------------
First, create an issue with the `JEP label`_. All pull requests that relate to
the JEP (i.e. adding the JEP itself as well as any implementing pull requests)
should be linked to this issue.
Then create a pull request that adds a file named
`%d-{short-title}.md` - with the number being the issue number.
.. _JEP label: https://github.com/google/jax/labels/JEP
.. toctree::
:maxdepth: 1
263: JAX PRNG Design <263-prng>
2026: Custom JVP/VJP rules for JAX-transformable functions <2026-custom-derivatives>
4410: Omnistaging <4410-omnistaging>
9407: Design of Type Promotion Semantics for JAX <9407-type-promotion>
9419: Jax and Jaxlib versioning <9419-jax-versioning>
10657: Sequencing side-effects in JAX <10657-sequencing-effects>

View File

@ -993,7 +993,7 @@
"id": "COjzGBpO4tzL"
},
"source": [
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by two unsigned-int32s that we call a __key__:"
]

View File

@ -504,7 +504,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
+++ {"id": "COjzGBpO4tzL"}
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by two unsigned-int32s that we call a __key__:

View File

@ -4,7 +4,7 @@ Type promotion semantics
========================
This document describes JAX's type promotion rulesi.e., the result of :func:`jax.numpy.promote_types` for each pair of types.
For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX <https://jax.readthedocs.io/en/latest/design_notes/type_promotion.html>`_.
For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX <https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html>`_.
JAX's type promotion behavior is determined via the following type promotion lattice:

View File

@ -66,7 +66,7 @@ Design and Context
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
See `docs/design_notes/prng.md <https://github.com/google/jax/blob/main/docs/design_notes/prng.md>`_
See `docs/jep/263-prng.md <https://github.com/google/jax/blob/main/docs/jep/263-prng.md>`_
for more details.
To summarize, among other requirements, the JAX PRNG aims to: