13 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
d623d04172 JEP 18137: Scope of JAX NumPy & SciPy Wrappers 2023-11-03 19:19:21 -07:00
Jake VanderPlas
e21286d6f0 JEP 9263: Typed keys & pluggable RNGs
Co-authored-by: Roy Frostig <frostig@google.com>
2023-09-20 11:26:13 -07:00
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Roy Frostig
ce840a9cd8 JEP: jax.extend, a module for extensions 2023-05-05 13:50:22 -07:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Jake VanderPlas
358363e17f JEP 12049: Type Annotation Roadmap 2022-09-13 09:14:48 -07:00
Roy Frostig
c6ab3a6a60 convert custom VJP update guide to a retroactive JEP 2022-08-10 13:45:57 -07:00
Matthew Johnson
be6f6bfe9f set new jax.remat / jax.checkpoint to be on-by-default 2022-08-10 10:29:38 -07:00
Roy Frostig
4c18f1a580 link to both closed and open enhancement proposals
PiperOrigin-RevId: 466212251
2022-08-08 18:48:38 -07:00
Peter Hawkins
71b29b1cc6 Create JAX Enhancement Proposals (JEPs).
Migrate existing design documents to JEPs.
2022-08-08 16:13:58 -04:00