45 Commits

Author SHA1 Message Date
Yash Katariya
ca2d1584f8 Remove mesh_utils.create_device_mesh from docs
PiperOrigin-RevId: 687695419
2024-10-19 15:48:42 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
09fd345de9 pre-commit: update hooks & pin using hashes 2024-08-27 15:23:13 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
George Necula
0223e830c1 [docs] Enable doctest for .md documentation files
The current invocation of doctest from GH actions picked up only .rst files.
We enable .md files also, and we make a few changes to ensure that the
doctest passes on the existing files.

The changes fall into several categories:
  * add a newline before the end of the code block, for doctest to
    pick up the expected output properly
  * update the expected values to match the current behavior
  * disable some doctests that raise expected exceptions, whenever
    I could not get doctest to match the exception details.
    Sometimes +IGNORE_EXCEPTION_DETAIL was enough, and other times
    I had to use +SKIP.
2024-06-03 11:27:34 +03:00
Sai-Suraj-27
29def4eefa Updated all the pre-commit hooks versions. 2024-04-08 00:59:02 +05:30
rajasekharporeddy
a81961c0cf Change vmap argument from axis to in_axes for Inadvertent transposing of key buffers in 9263-typed-keys.md 2024-03-28 15:10:52 +05:30
Jake VanderPlas
d6c07bdf51 DOC: read-through and edit the new jax tutorials 2024-03-20 18:18:08 -07:00
rajasekharporeddy
e32bac4b3a Fixed Typos in JEP doc files 2024-03-12 09:58:13 +05:30
rajasekharporeddy
61c64c10f8 Fixed Several Typos
Fixed Typos in JEP doc files

Revert "Fixed Typos in JEP doc files"

This reverts commit c2a16950e0fc1b32971168501d183991e2394b5d.

revert two changes

reverted one change in advanced-autodiff

revert one change in parallelism

sync notebooks
2024-03-12 00:37:46 +05:30
Jake VanderPlas
1947104212 JEP 9263: update deprecation status of KeyArray and PRNGKeyArray 2024-01-16 10:33:14 -08:00
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Jake VanderPlas
10eae3f93a CI: update jupytext version 2024-01-09 14:34:21 -08:00
Jan Hrček
4da56dcdd7 Fix duplicate word occurrences 2023-12-19 06:15:30 +01:00
Jake VanderPlas
cacadf43c0 JAX tutorials: pseudorandom numbers 2023-11-21 14:26:06 -08:00
Jake VanderPlas
d623d04172 JEP 18137: Scope of JAX NumPy & SciPy Wrappers 2023-11-03 19:19:21 -07:00
Jake VanderPlas
389eb97a7c CI: update pre-commit hooks to latest version 2023-10-30 09:12:24 -07:00
Jake VanderPlas
0da4be5e2a [random] make PRNG impl attributes private 2023-10-18 11:10:47 -07:00
parikshit adhikari
e21409fdeb fix: typo inside docs/jep/12049-type-annotations.md 2023-10-13 08:51:01 +05:45
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Jake VanderPlas
e21286d6f0 JEP 9263: Typed keys & pluggable RNGs
Co-authored-by: Roy Frostig <frostig@google.com>
2023-09-20 11:26:13 -07:00
Matthew Johnson
70b58bbd30 rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
2023-08-31 17:31:21 -07:00
Matthew Johnson
8a04dfd830 rolling back shard_map transposition change to fix a bug
Reverts 437d7be73534403f39fbee9d6391be1c532933a1

PiperOrigin-RevId: 561730581
2023-08-31 12:39:56 -07:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00
Matthew Johnson
fdd252f6ca [shard-map] add rewrite for efficient transposition 2023-08-30 15:08:11 -07:00
Jake VanderPlas
7bb8312f82 CI: update jupytext to v0.14.7 2023-07-24 11:51:45 -07:00
Roy Frostig
ce840a9cd8 JEP: jax.extend, a module for extensions 2023-05-05 13:50:22 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
71f120beed Add "Open in Kaggle" buttons to Jupyter notebooks. 2023-03-01 13:15:42 -05:00
Frederic Bastien
93c93133ea Use right fct name. 2023-02-14 11:21:16 -08:00
Frederic Bastien
d2bb1e089d Be consistent in the index used 2023-02-14 11:21:03 -08:00
Frederic Bastien
673510202d Small crash fixes 2023-02-14 11:14:26 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
Jake VanderPlas
20b55a119e CI: update jupytext version 2023-01-23 14:42:03 -08:00
Ikko Ashimine
28def736d1
Fix typo in 9419-jax-versioning.md
overriden -> overridden
2022-10-25 03:26:48 +09:00
Jake VanderPlas
fce1099997 Update JEP-12049 implementation discussion 2022-09-20 09:44:29 -07:00
Jake VanderPlas
5829c6ae9d Change case of typing.Dtype -> typing.DType
This follows the convention used in numpy.typing.DType.
2022-09-14 15:03:55 -07:00
Jake VanderPlas
358363e17f JEP 12049: Type Annotation Roadmap 2022-09-13 09:14:48 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
Jake VanderPlas
eeb9b5f1f6 pre-commit hook: update flake8, mypy, & jupytext 2022-08-15 15:32:45 -07:00
Roy Frostig
c6ab3a6a60 convert custom VJP update guide to a retroactive JEP 2022-08-10 13:45:57 -07:00
Matthew Johnson
be6f6bfe9f set new jax.remat / jax.checkpoint to be on-by-default 2022-08-10 10:29:38 -07:00
Roy Frostig
4c18f1a580 link to both closed and open enhancement proposals
PiperOrigin-RevId: 466212251
2022-08-08 18:48:38 -07:00
Peter Hawkins
71b29b1cc6 Create JAX Enhancement Proposals (JEPs).
Migrate existing design documents to JEPs.
2022-08-08 16:13:58 -04:00