22388 Commits

Author SHA1 Message Date
Roy Frostig
b8f8b7b07f docs: sentence case page titles, section headings, some content 2024-08-12 18:12:17 -07:00
Roy Frostig
2644299f7e docs: sentence case index and sub-index headings
We currently use both forms, so for consistency (and easier reading),
pick this one.
2024-08-12 13:52:43 -07:00
jax authors
9e86416a32 Update XLA dependency to use revision
dfb2a8b498.

PiperOrigin-RevId: 661668331
2024-08-10 14:20:14 -07:00
Yash Katariya
c08656c61d [Rollback] We still want to allow multiple meshes in the user program
Reverts dd958adc39550d2758ecdb13809c6d85df7658a2

PiperOrigin-RevId: 661537233
2024-08-09 23:17:46 -07:00
Yash Katariya
abc9ba00e9 Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
jax authors
7a75c96aa9 Update XLA dependency to use revision
46e205a0b6.

PiperOrigin-RevId: 661412627
2024-08-09 14:53:42 -07:00
Parker Schuh
4863a568f9 Fix array_test.py when jax_pmap_no_rank_reduction is flipped to true.
The problem is that squeezing was happening on noncommitted arrays
so list(x) was moving all the shards to device 0. This will potentially
cause ooms.

PiperOrigin-RevId: 661408226
2024-08-09 14:40:52 -07:00
Jieying Luo
a3ae5e18d3 Remove build_cuda_plugin_from_source flag which is no longe used.
751b5742fd

PiperOrigin-RevId: 661370449
2024-08-09 12:54:14 -07:00
jax authors
35c2454ea6 Merge pull request #22887 from justinjfu:pallas_distr_docs
PiperOrigin-RevId: 661368285
2024-08-09 12:47:48 -07:00
jax authors
3bd3597703 Improves error message in case of invalid sharding mesh
PiperOrigin-RevId: 661358450
2024-08-09 12:18:16 -07:00
jax authors
aa334145b4 Merge pull request #22958 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 661340981
2024-08-09 11:30:16 -07:00
jax authors
c207ad4c04 Merge pull request #22960 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 661319343
2024-08-09 10:37:44 -07:00
rajasekharporeddy
ff1f199d09 Improved docs for jnp.fft.rfftn and jnp.fft.irfftn 2024-08-09 23:07:17 +05:30
Justin Fu
47842b358d
Merge pull request #22887 from justinjfu/pallas_distr_docs
[Pallas] Add pallas distributed computation tutorial
2024-08-09 09:56:07 -07:00
Tom Hennigan
5ced6db692 Cache _get_tpu_generation to avoid repeated calls to jax.devices().
We use `util.cache` such that if the default backend changes this function
will (correctly) be re-evaluated.

PiperOrigin-RevId: 661293560
2024-08-09 09:33:45 -07:00
jax authors
40e67c73ee Merge pull request #22373 from gspschmid:gschmid/jax-local-device-ids-from-env
PiperOrigin-RevId: 661291774
2024-08-09 09:28:39 -07:00
Tomás Longeri
77afe251e7 [Mosaic TPU][Python] Check validity of VectorLayout on init
PiperOrigin-RevId: 661226283
2024-08-09 05:28:00 -07:00
Georg Stefan Schmid
f9bc4c643b [jax.distributed] Allow setting local device ids via env var 2024-08-09 10:23:17 +00:00
rajasekharporeddy
6ee1555d21 Fix broken links in jnp.fft.fftfreq and jnp.fft.rfftfreq 2024-08-09 14:23:56 +05:30
Tomás Longeri
e57a7e3f05 [Mosaic] Column shift relayouts for non-native tilings and packed types, except for (1, n) and packed
PiperOrigin-RevId: 661091012
2024-08-08 20:14:08 -07:00
Justin Fu
bdb03309b5
Merge branch 'main' into pallas_distr_docs 2024-08-08 17:36:36 -07:00
Justin Fu
dcd186f552 [Pallas] Add pallas distributed computation tutorial 2024-08-08 17:34:35 -07:00
jax authors
f2068bb4ad Update XLA dependency to use revision
76978f280d.

PiperOrigin-RevId: 660999885
2024-08-08 15:18:44 -07:00
Jieying Luo
9c2caedab1 Add subdirectories to the output path when building editable wheels for jaxlib and GPU plugin.
When `build_gpu_plugin` is true, three wheels will be produced (jaxlib, jax-cuda-pjrt and jax-cuda-plugin). If they are editable, they need to be placed in subdirectories to avoid overwrite.

Tested on GPU. After the editable wheels are built, they can be installed with `pip install -e /jax/dist/jax_gpu_pjrt /jax/dist/jaxlib /jax/dist/jax_gpu_plugin`.

PiperOrigin-RevId: 660984311
2024-08-08 14:34:55 -07:00
Justin Fu
deefbdd626 Temporarily disable broken tests in tpu_pallas_pipeline_test.py
PiperOrigin-RevId: 660972804
2024-08-08 14:04:04 -07:00
Sergei Lebedev
12a9c8cfd4 Pallas Mosaic GPU lowering now supports (at least the basic) pl.BlockSpecs
Note that we still don't do any pipelining whatsoever, but it can be done once
this change lands.

PiperOrigin-RevId: 660969393
2024-08-08 13:55:07 -07:00
Gleb Pobudzey
d28d14917e Fix error message in dot_product_attention
PiperOrigin-RevId: 660960409
2024-08-08 13:30:21 -07:00
Sergei Lebedev
d8eafc8ee3 Disabled nn_test under asan on TPU as well, since it also times out
PiperOrigin-RevId: 660950262
2024-08-08 13:02:31 -07:00
Dan Foreman-Mackey
efb7721671 Remove unnecessary constraint on keyword-only arguments in custom_vjp with optimize_remat=True.
PiperOrigin-RevId: 660945559
2024-08-08 12:49:27 -07:00
jax authors
93d4629846 Merge pull request #22903 from jakevdp:update-array-api
PiperOrigin-RevId: 660941835
2024-08-08 12:39:56 -07:00
Jake VanderPlas
d999208863 [array API] update test suite to most recent commit 2024-08-08 12:33:30 -07:00
Jieying Luo
751b5742fd Deprecate using build_cuda_plugin_from_source flag and rely on jaxlib_build config.
If jaxlib needs to be built from source, cuda plugin will be built from source as well.

PiperOrigin-RevId: 660926791
2024-08-08 11:58:13 -07:00
Yash Katariya
e6303244bf If the memory kind is the default kind throughout the jaxpr, then revert back to the previous device_put behavior which was a no-op inside jit.
This is also the same behavior for arguments and outputs, where we don't insert `mhlo.memory_kind` attributes in the stableHLO if the entire jaxpr only has the default memory kind.

PiperOrigin-RevId: 660913387
2024-08-08 11:24:25 -07:00
jax authors
bdd8f74efe Merge pull request #22916 from jakevdp:piecewise-doc
PiperOrigin-RevId: 660896267
2024-08-08 10:45:09 -07:00
jax authors
0309adf2a5 Merge pull request #22937 from dfm:custom-vmap-errors
PiperOrigin-RevId: 660880442
2024-08-08 10:05:34 -07:00
jax authors
647a2f75d3 Merge pull request #22947 from mattjj:22944
PiperOrigin-RevId: 660874340
2024-08-08 09:49:54 -07:00
Jieying Luo
ccc27a7a5f Remove PJRT version check in memories_test.py that is no longer needed.
0.43 is the version at 2024 Feb. Cloud TPU CI uses 20240228 so it should contain the PJRT C API needed for the test d3b6066f91/.github/workflows/cloud-tpu-ci-nightly.yml (L35).

PiperOrigin-RevId: 660869710
2024-08-08 09:35:41 -07:00
Matthew Johnson
44ae9b30ec fix #22944 2024-08-08 16:19:19 +00:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Adam Paszke
04a753ad02 [Mosaic TPU] Improve an error message in case someone tries to extract a non-32-bit scalar.
PiperOrigin-RevId: 660826696
2024-08-08 07:22:10 -07:00
Dan Foreman-Mackey
595ca0affa Improve error message for missing vmap rule in custom_vmap.
This is a partial re-land of https://github.com/google/jax/pull/22869
after it was rolled back to fix internal users. This part of the change
didn't cause the issues, and I'll follow up with the rest of the changes
in a second PR.
2024-08-08 14:08:51 +01:00
Jake VanderPlas
551f72979c Rollback of #22869
This is causing breakages due to overly-restrictive checks on kwargs

Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be

PiperOrigin-RevId: 660803968
2024-08-08 06:00:17 -07:00
Jake VanderPlas
4ca341701f Improve documentation for jnp.piecewise & jnp.select 2024-08-08 05:53:03 -07:00
jax authors
9fbc51bfad Merge pull request #22923 from Rifur13:faster
PiperOrigin-RevId: 660736990
2024-08-08 01:44:42 -07:00
Adam Paszke
42fe45f34b [Mosaic TPU] Add support for removal of implicit 2nd minor for all 32-bit tilings
PiperOrigin-RevId: 660724215
2024-08-08 01:00:32 -07:00
jax authors
0630139da2 Merge pull request #22925 from google:doc_update
PiperOrigin-RevId: 660721790
2024-08-08 00:50:12 -07:00
Yash Katariya
7f8a4c84d3 Remove PositionalSharding from distributed array doc 2024-08-07 21:25:24 -07:00
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Gleb Pobudzey
e6425a2c67 Small performance improvement to pallas MHA 2024-08-07 23:20:19 +00:00
jax authors
7efca0490f Merge pull request #22920 from jakevdp:fix-lint
PiperOrigin-RevId: 660570457
2024-08-07 16:01:09 -07:00