mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11799 from hawkinsp:jep
PiperOrigin-RevId: 466178951
This commit is contained in:
commit
38ab3d88ae
@ -117,7 +117,7 @@ exclude_patterns = [
|
||||
# Ignore markdown source for notebooks; myst-nb builds from the ipynb
|
||||
# These are kept in sync using the jupytext pre-commit hook.
|
||||
'notebooks/*.md',
|
||||
'design_notes/type_promotion.md',
|
||||
'jep/9407-type-promotion.md',
|
||||
# TODO: revert to jax-101/*.md once 08-pjit has a notebook
|
||||
'jax-101/01-jax-basics.md',
|
||||
'jax-101/02-jitting.md',
|
||||
@ -202,7 +202,7 @@ nb_execution_excludepatterns = [
|
||||
# Strange error apparently due to asynchronous cell execution
|
||||
'notebooks/thinking_in_jax.*',
|
||||
# Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
|
||||
'design_notes/type_promotion.*',
|
||||
'jep/9407-type-promotion.*',
|
||||
# TODO(jakevdp): enable execution on the following if possible:
|
||||
'jax-101/*',
|
||||
'notebooks/xmap_tutorial.*',
|
||||
|
@ -1,12 +0,0 @@
|
||||
Design Notes
|
||||
============
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
custom_derivatives
|
||||
jax_versioning
|
||||
omnistaging
|
||||
prng
|
||||
type_promotion
|
||||
sequencing_effects
|
@ -74,7 +74,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.
|
||||
developer
|
||||
jax_internal_api
|
||||
autodidax
|
||||
design_notes/index
|
||||
jep/index
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
@ -8,7 +8,7 @@
|
||||
"source": [
|
||||
"# Design of Type Promotion Semantics for JAX\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/design_notes/type_promotion.ipynb)\n",
|
||||
"[](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
|
||||
"\n",
|
||||
"*Jake VanderPlas, December 2021*\n",
|
||||
"\n",
|
@ -16,7 +16,7 @@ kernelspec:
|
||||
|
||||
# Design of Type Promotion Semantics for JAX
|
||||
|
||||
[](https://colab.research.google.com/github/google/jax/blob/main/docs/design_notes/type_promotion.ipynb)
|
||||
[](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
|
||||
|
||||
*Jake VanderPlas, December 2021*
|
||||
|
49
docs/jep/index.rst
Normal file
49
docs/jep/index.rst
Normal file
@ -0,0 +1,49 @@
|
||||
JAX Enhancement Proposals (JEPs)
|
||||
================================
|
||||
|
||||
Most changes can be discussed with simple issues/discussions and pull requests.
|
||||
|
||||
Some changes though are a bit larger in scope or require more discussion, and
|
||||
these should be implemented as JEP. This allows for writing longer documents
|
||||
that can be discussed in a pull request themselves.
|
||||
|
||||
The structure of JEPs is kept as lightweight as possible to start and might
|
||||
be extended later on.
|
||||
|
||||
When you should use a JEP
|
||||
-------------------------
|
||||
|
||||
- When your change requires a design doc. We prefer collecting the designs as
|
||||
JEPs for better discoverability and further reference.
|
||||
|
||||
- When your change requires extensive discussion. It's fine to have relatively
|
||||
short discussions on issues or pull requests, but when the discussion gets
|
||||
longer this becomes unpractical for later digestion. JEPs allow to update the
|
||||
main document with a summary of the discussion and these updates can be
|
||||
discussed themselves in the pull request adding the JEP.
|
||||
|
||||
How to start a JEP
|
||||
------------------
|
||||
|
||||
First, create an issue with the `JEP label`_. All pull requests that relate to
|
||||
the JEP (i.e. adding the JEP itself as well as any implementing pull requests)
|
||||
should be linked to this issue.
|
||||
|
||||
Then create a pull request that adds a file named
|
||||
`%d-{short-title}.md` - with the number being the issue number.
|
||||
|
||||
.. _JEP label: https://github.com/google/jax/labels/JEP
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
263: JAX PRNG Design <263-prng>
|
||||
2026: Custom JVP/VJP rules for JAX-transformable functions <2026-custom-derivatives>
|
||||
4410: Omnistaging <4410-omnistaging>
|
||||
9407: Design of Type Promotion Semantics for JAX <9407-type-promotion>
|
||||
9419: Jax and Jaxlib versioning <9419-jax-versioning>
|
||||
10657: Sequencing side-effects in JAX <10657-sequencing-effects>
|
||||
|
||||
|
||||
|
||||
|
@ -993,7 +993,7 @@
|
||||
"id": "COjzGBpO4tzL"
|
||||
},
|
||||
"source": [
|
||||
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
|
||||
"JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
|
||||
"\n",
|
||||
"The random state is described by two unsigned-int32s that we call a __key__:"
|
||||
]
|
||||
|
@ -504,7 +504,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
|
||||
|
||||
+++ {"id": "COjzGBpO4tzL"}
|
||||
|
||||
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
|
||||
JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
|
||||
|
||||
The random state is described by two unsigned-int32s that we call a __key__:
|
||||
|
||||
|
@ -4,7 +4,7 @@ Type promotion semantics
|
||||
========================
|
||||
|
||||
This document describes JAX's type promotion rules–i.e., the result of :func:`jax.numpy.promote_types` for each pair of types.
|
||||
For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX <https://jax.readthedocs.io/en/latest/design_notes/type_promotion.html>`_.
|
||||
For some background on the considerations that went into the design of what is described below, see `Design of Type Promotion Semantics for JAX <https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html>`_.
|
||||
|
||||
JAX's type promotion behavior is determined via the following type promotion lattice:
|
||||
|
||||
|
@ -66,7 +66,7 @@ Design and Context
|
||||
**TLDR**: JAX PRNG = `Threefry counter PRNG <http://www.thesalmons.org/john/random123/papers/random123sc11.pdf>`_
|
||||
+ a functional array-oriented `splitting model <https://dl.acm.org/citation.cfm?id=2503784>`_
|
||||
|
||||
See `docs/design_notes/prng.md <https://github.com/google/jax/blob/main/docs/design_notes/prng.md>`_
|
||||
See `docs/jep/263-prng.md <https://github.com/google/jax/blob/main/docs/jep/263-prng.md>`_
|
||||
for more details.
|
||||
|
||||
To summarize, among other requirements, the JAX PRNG aims to:
|
||||
|
Loading…
x
Reference in New Issue
Block a user