Roy Frostig
371935cc10
update README and several docs to typed RNG keys
2024-08-11 08:09:47 -07:00
jax authors
9e86416a32
Update XLA dependency to use revision
...
dfb2a8b498
.
PiperOrigin-RevId: 661668331
2024-08-10 14:20:14 -07:00
Yash Katariya
c08656c61d
[Rollback] We still want to allow multiple meshes in the user program
...
Reverts dd958adc39550d2758ecdb13809c6d85df7658a2
PiperOrigin-RevId: 661537233
2024-08-09 23:17:46 -07:00
Yash Katariya
abc9ba00e9
Rename count_jit_and_pmap_compiles
to count_jit_and_pmap_lowerings
...
PiperOrigin-RevId: 661496993
2024-08-09 20:03:43 -07:00
jax authors
7a75c96aa9
Update XLA dependency to use revision
...
46e205a0b6
.
PiperOrigin-RevId: 661412627
2024-08-09 14:53:42 -07:00
Parker Schuh
4863a568f9
Fix array_test.py when jax_pmap_no_rank_reduction is flipped to true.
...
The problem is that squeezing was happening on noncommitted arrays
so list(x) was moving all the shards to device 0. This will potentially
cause ooms.
PiperOrigin-RevId: 661408226
2024-08-09 14:40:52 -07:00
Jieying Luo
a3ae5e18d3
Remove build_cuda_plugin_from_source
flag which is no longe used.
...
751b5742fd
PiperOrigin-RevId: 661370449
2024-08-09 12:54:14 -07:00
jax authors
35c2454ea6
Merge pull request #22887 from justinjfu:pallas_distr_docs
...
PiperOrigin-RevId: 661368285
2024-08-09 12:47:48 -07:00
jax authors
3bd3597703
Improves error message in case of invalid sharding mesh
...
PiperOrigin-RevId: 661358450
2024-08-09 12:18:16 -07:00
jax authors
aa334145b4
Merge pull request #22958 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 661340981
2024-08-09 11:30:16 -07:00
jax authors
c207ad4c04
Merge pull request #22960 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 661319343
2024-08-09 10:37:44 -07:00
rajasekharporeddy
ff1f199d09
Improved docs for jnp.fft.rfftn and jnp.fft.irfftn
2024-08-09 23:07:17 +05:30
Justin Fu
47842b358d
Merge pull request #22887 from justinjfu/pallas_distr_docs
...
[Pallas] Add pallas distributed computation tutorial
2024-08-09 09:56:07 -07:00
Tom Hennigan
5ced6db692
Cache _get_tpu_generation
to avoid repeated calls to jax.devices().
...
We use `util.cache` such that if the default backend changes this function
will (correctly) be re-evaluated.
PiperOrigin-RevId: 661293560
2024-08-09 09:33:45 -07:00
jax authors
40e67c73ee
Merge pull request #22373 from gspschmid:gschmid/jax-local-device-ids-from-env
...
PiperOrigin-RevId: 661291774
2024-08-09 09:28:39 -07:00
Tomás Longeri
77afe251e7
[Mosaic TPU][Python] Check validity of VectorLayout on init
...
PiperOrigin-RevId: 661226283
2024-08-09 05:28:00 -07:00
Georg Stefan Schmid
f9bc4c643b
[jax.distributed] Allow setting local device ids via env var
2024-08-09 10:23:17 +00:00
rajasekharporeddy
6ee1555d21
Fix broken links in jnp.fft.fftfreq and jnp.fft.rfftfreq
2024-08-09 14:23:56 +05:30
Tomás Longeri
e57a7e3f05
[Mosaic] Column shift relayouts for non-native tilings and packed types, except for (1, n) and packed
...
PiperOrigin-RevId: 661091012
2024-08-08 20:14:08 -07:00
Justin Fu
bdb03309b5
Merge branch 'main' into pallas_distr_docs
2024-08-08 17:36:36 -07:00
Justin Fu
dcd186f552
[Pallas] Add pallas distributed computation tutorial
2024-08-08 17:34:35 -07:00
jax authors
f2068bb4ad
Update XLA dependency to use revision
...
76978f280d
.
PiperOrigin-RevId: 660999885
2024-08-08 15:18:44 -07:00
Jieying Luo
9c2caedab1
Add subdirectories to the output path when building editable wheels for jaxlib and GPU plugin.
...
When `build_gpu_plugin` is true, three wheels will be produced (jaxlib, jax-cuda-pjrt and jax-cuda-plugin). If they are editable, they need to be placed in subdirectories to avoid overwrite.
Tested on GPU. After the editable wheels are built, they can be installed with `pip install -e /jax/dist/jax_gpu_pjrt /jax/dist/jaxlib /jax/dist/jax_gpu_plugin`.
PiperOrigin-RevId: 660984311
2024-08-08 14:34:55 -07:00
Justin Fu
deefbdd626
Temporarily disable broken tests in tpu_pallas_pipeline_test.py
...
PiperOrigin-RevId: 660972804
2024-08-08 14:04:04 -07:00
Sergei Lebedev
12a9c8cfd4
Pallas Mosaic GPU lowering now supports (at least the basic) pl.BlockSpecs
...
Note that we still don't do any pipelining whatsoever, but it can be done once
this change lands.
PiperOrigin-RevId: 660969393
2024-08-08 13:55:07 -07:00
Gleb Pobudzey
d28d14917e
Fix error message in dot_product_attention
...
PiperOrigin-RevId: 660960409
2024-08-08 13:30:21 -07:00
Sergei Lebedev
d8eafc8ee3
Disabled nn_test under asan on TPU as well, since it also times out
...
PiperOrigin-RevId: 660950262
2024-08-08 13:02:31 -07:00
Dan Foreman-Mackey
efb7721671
Remove unnecessary constraint on keyword-only arguments in custom_vjp
with optimize_remat=True
.
...
PiperOrigin-RevId: 660945559
2024-08-08 12:49:27 -07:00
jax authors
93d4629846
Merge pull request #22903 from jakevdp:update-array-api
...
PiperOrigin-RevId: 660941835
2024-08-08 12:39:56 -07:00
Jake VanderPlas
d999208863
[array API] update test suite to most recent commit
2024-08-08 12:33:30 -07:00
Jieying Luo
751b5742fd
Deprecate using build_cuda_plugin_from_source flag and rely on jaxlib_build config.
...
If jaxlib needs to be built from source, cuda plugin will be built from source as well.
PiperOrigin-RevId: 660926791
2024-08-08 11:58:13 -07:00
Yash Katariya
e6303244bf
If the memory kind is the default kind throughout the jaxpr, then revert back to the previous device_put behavior which was a no-op inside jit.
...
This is also the same behavior for arguments and outputs, where we don't insert `mhlo.memory_kind` attributes in the stableHLO if the entire jaxpr only has the default memory kind.
PiperOrigin-RevId: 660913387
2024-08-08 11:24:25 -07:00
jax authors
bdd8f74efe
Merge pull request #22916 from jakevdp:piecewise-doc
...
PiperOrigin-RevId: 660896267
2024-08-08 10:45:09 -07:00
jax authors
0309adf2a5
Merge pull request #22937 from dfm:custom-vmap-errors
...
PiperOrigin-RevId: 660880442
2024-08-08 10:05:34 -07:00
jax authors
647a2f75d3
Merge pull request #22947 from mattjj:22944
...
PiperOrigin-RevId: 660874340
2024-08-08 09:49:54 -07:00
Jieying Luo
ccc27a7a5f
Remove PJRT version check in memories_test.py that is no longer needed.
...
0.43 is the version at 2024 Feb. Cloud TPU CI uses 20240228 so it should contain the PJRT C API needed for the test d3b6066f91/.github/workflows/cloud-tpu-ci-nightly.yml (L35)
.
PiperOrigin-RevId: 660869710
2024-08-08 09:35:41 -07:00
Matthew Johnson
44ae9b30ec
fix #22944
2024-08-08 16:19:19 +00:00
Dan Foreman-Mackey
11d9c2de2c
Update GPU implementation of lu_pivots_to_permutation
to infer the permutation size directly from the input dimensions, instead of using an input parameter.
...
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.
In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.
PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Adam Paszke
04a753ad02
[Mosaic TPU] Improve an error message in case someone tries to extract a non-32-bit scalar.
...
PiperOrigin-RevId: 660826696
2024-08-08 07:22:10 -07:00
Dan Foreman-Mackey
595ca0affa
Improve error message for missing vmap rule in custom_vmap.
...
This is a partial re-land of https://github.com/google/jax/pull/22869
after it was rolled back to fix internal users. This part of the change
didn't cause the issues, and I'll follow up with the rest of the changes
in a second PR.
2024-08-08 14:08:51 +01:00
Jake VanderPlas
551f72979c
Rollback of #22869
...
This is causing breakages due to overly-restrictive checks on kwargs
Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be
PiperOrigin-RevId: 660803968
2024-08-08 06:00:17 -07:00
Jake VanderPlas
4ca341701f
Improve documentation for jnp.piecewise & jnp.select
2024-08-08 05:53:03 -07:00
jax authors
9fbc51bfad
Merge pull request #22923 from Rifur13:faster
...
PiperOrigin-RevId: 660736990
2024-08-08 01:44:42 -07:00
Adam Paszke
42fe45f34b
[Mosaic TPU] Add support for removal of implicit 2nd minor for all 32-bit tilings
...
PiperOrigin-RevId: 660724215
2024-08-08 01:00:32 -07:00
jax authors
0630139da2
Merge pull request #22925 from google:doc_update
...
PiperOrigin-RevId: 660721790
2024-08-08 00:50:12 -07:00
Yash Katariya
7f8a4c84d3
Remove PositionalSharding from distributed array doc
2024-08-07 21:25:24 -07:00
Yash Katariya
be53ee10b1
Set jax_enable_memories
flag to True
by default
...
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Gleb Pobudzey
e6425a2c67
Small performance improvement to pallas MHA
2024-08-07 23:20:19 +00:00
jax authors
7efca0490f
Merge pull request #22920 from jakevdp:fix-lint
...
PiperOrigin-RevId: 660570457
2024-08-07 16:01:09 -07:00
jax authors
a57d6591ee
Update XLA dependency to use revision
...
3bf7e1ae48
.
PiperOrigin-RevId: 660570144
2024-08-07 15:57:41 -07:00