JAX Enhancement Proposals (JEPs)
================================

Most changes can be discussed with simple issues/discussions and pull requests.
Some changes, though, are larger in scope or require more discussion, and
these should be implemented as a JEP. This allows for writing longer documents
that can themselves be discussed in a pull request.

The structure of JEPs is kept as lightweight as possible to start and might
be extended later on.

When you should use a JEP
-------------------------

- When your change requires a design doc. We prefer collecting the designs as
  JEPs for better discoverability and further reference.
- When your change requires extensive discussion. It's fine to have relatively
  short discussions on issues or pull requests, but when the discussion gets
  longer it becomes impractical to digest later. JEPs allow the main document
  to be updated with a summary of the discussion, and these updates can
  themselves be discussed in the pull request adding the JEP.

How to start a JEP
------------------

First, create an issue with the `JEP label`_. All pull requests that relate to
the JEP (i.e. adding the JEP itself as well as any implementing pull requests)
should be linked to this issue.

Then create a pull request that adds a file named
``%d-{short-title}.md``, where the number ``%d`` is the issue number.

.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP
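
As a worked example of the naming convention, the sketch below builds a JEP
file name from an issue number and a short title, and formats a matching
toctree line in the style of the entries listed on this page. Both
``jep_filename`` and ``toctree_entry`` are hypothetical helpers written for
illustration, not part of JAX's tooling, and the lowercase, hyphen-separated
short-title format is an assumption drawn from the existing entries.

.. code-block:: python

    import re

    # Hypothetical helpers for illustration only; not part of JAX's tooling.


    def jep_filename(issue_number: int, short_title: str) -> str:
        """Return a JEP file name of the form ``%d-{short-title}.md``."""
        # Assumption: short titles are lowercase, hyphen-separated words,
        # as in the existing ``17111-shmap-transpose`` entry.
        if not re.fullmatch(r"[a-z0-9]+(-[a-z0-9]+)*", short_title):
            raise ValueError(f"unexpected short title: {short_title!r}")
        # %d is replaced by the issue number, as in the naming pattern.
        return "%d-%s.md" % (issue_number, short_title)


    def toctree_entry(issue_number: int, title: str, short_title: str) -> str:
        """Return a toctree line in the style of the entries on this page."""
        # The toctree target is the file name without its ``.md`` suffix.
        target = jep_filename(issue_number, short_title)[: -len(".md")]
        return f"{issue_number}: {title} <{target}>"


    print(jep_filename(17111, "shmap-transpose"))  # 17111-shmap-transpose.md
    print(toctree_entry(263, "JAX PRNG Design", "prng"))  # 263: JAX PRNG Design <263-prng>

For issue 17111 with the short title ``shmap-transpose``, this yields
``17111-shmap-transpose.md``, which corresponds to the existing
``17111-shmap-transpose`` target in the toctree.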

.. toctree::
   :maxdepth: 1

   263: JAX PRNG Design <263-prng>
   2026: Custom JVP/VJP rules for JAX-transformable functions <2026-custom-derivatives>
   4008: Custom VJP and `nondiff_argnums` update <4008-custom-vjp-update>
   4410: Omnistaging <4410-omnistaging>
   9407: Design of Type Promotion Semantics for JAX <9407-type-promotion>
   9419: Jax and Jaxlib versioning <9419-jax-versioning>
   10657: Sequencing side-effects in JAX <10657-sequencing-effects>
   11830: `jax.remat` / `jax.checkpoint` new implementation <11830-new-remat-checkpoint>
   12049: Type Annotation Roadmap for JAX <12049-type-annotations>
   14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>
   15856: `jax.extend`, an extensions module <15856-jex>
   17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose>

Several early JEPs were converted in hindsight from other documentation,
issues, and pull requests, so they might not exactly reflect the process
outlined above.