mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 04:16:07 +00:00

The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests! The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those! The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule. Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa PiperOrigin-RevId: 561805840
57 lines
2.4 KiB
ReStructuredText
57 lines
2.4 KiB
ReStructuredText
JAX Enhancement Proposals (JEPs)
|
|
================================
|
|
|
|
Most changes can be discussed with simple issues/discussions and pull requests.
|
|
|
|
Some changes though are a bit larger in scope or require more discussion, and
|
|
these should be implemented as JEP. This allows for writing longer documents
|
|
that can be discussed in a pull request themselves.
|
|
|
|
The structure of JEPs is kept as lightweight as possible to start and might
|
|
be extended later on.
|
|
|
|
When you should use a JEP
|
|
-------------------------
|
|
|
|
- When your change requires a design doc. We prefer collecting the designs as
|
|
JEPs for better discoverability and further reference.
|
|
|
|
- When your change requires extensive discussion. It's fine to have relatively
|
|
short discussions on issues or pull requests, but when the discussion gets
|
|
longer this becomes unpractical for later digestion. JEPs allow to update the
|
|
main document with a summary of the discussion and these updates can be
|
|
discussed themselves in the pull request adding the JEP.
|
|
|
|
How to start a JEP
|
|
------------------
|
|
|
|
First, create an issue with the `JEP label`_. All pull requests that relate to
|
|
the JEP (i.e. adding the JEP itself as well as any implementing pull requests)
|
|
should be linked to this issue.
|
|
|
|
Then create a pull request that adds a file named
|
|
`%d-{short-title}.md` - with the number being the issue number.
|
|
|
|
.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP
|
|
|
|
.. toctree::
|
|
:maxdepth: 1
|
|
|
|
263: JAX PRNG Design <263-prng>
|
|
2026: Custom JVP/VJP rules for JAX-transformable functions <2026-custom-derivatives>
|
|
4008: Custom VJP and `nondiff_argnums` update <4008-custom-vjp-update>
|
|
4410: Omnistaging <4410-omnistaging>
|
|
9407: Design of Type Promotion Semantics for JAX <9407-type-promotion>
|
|
9419: Jax and Jaxlib versioning <9419-jax-versioning>
|
|
10657: Sequencing side-effects in JAX <10657-sequencing-effects>
|
|
11830: `jax.remat` / `jax.checkpoint` new implementation <11830-new-remat-checkpoint>
|
|
12049: Type Annotation Roadmap for JAX <12049-type-annotations>
|
|
14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>
|
|
15856: `jax.extend`, an extensions module <15856-jex>
|
|
17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose>
|
|
|
|
|
|
Several early JEPs were converted in hindsight from other documentation,
|
|
issues, and pull requests, so they might not exactly reflect the process
|
|
outlined above.
|