22361 Commits

Author SHA1 Message Date
Sergei Lebedev
d8eafc8ee3 Disabled nn_test under asan on TPU as well, since it also times out
PiperOrigin-RevId: 660950262
2024-08-08 13:02:31 -07:00
Dan Foreman-Mackey
efb7721671 Remove unnecessary constraint on keyword-only arguments in custom_vjp with optimize_remat=True.
PiperOrigin-RevId: 660945559
2024-08-08 12:49:27 -07:00
jax authors
93d4629846 Merge pull request #22903 from jakevdp:update-array-api
PiperOrigin-RevId: 660941835
2024-08-08 12:39:56 -07:00
Jake VanderPlas
d999208863 [array API] update test suite to most recent commit 2024-08-08 12:33:30 -07:00
Jieying Luo
751b5742fd Deprecate using build_cuda_plugin_from_source flag and rely on jaxlib_build config.
If jaxlib needs to be built from source, cuda plugin will be built from source as well.

PiperOrigin-RevId: 660926791
2024-08-08 11:58:13 -07:00
Yash Katariya
e6303244bf If the memory kind is the default kind throughout the jaxpr, then revert back to the previous device_put behavior which was a no-op inside jit.
This is also the same behavior for arguments and outputs, where we don't insert `mhlo.memory_kind` attributes in the stableHLO if the entire jaxpr only has the default memory kind.

PiperOrigin-RevId: 660913387
2024-08-08 11:24:25 -07:00
jax authors
bdd8f74efe Merge pull request #22916 from jakevdp:piecewise-doc
PiperOrigin-RevId: 660896267
2024-08-08 10:45:09 -07:00
jax authors
0309adf2a5 Merge pull request #22937 from dfm:custom-vmap-errors
PiperOrigin-RevId: 660880442
2024-08-08 10:05:34 -07:00
jax authors
647a2f75d3 Merge pull request #22947 from mattjj:22944
PiperOrigin-RevId: 660874340
2024-08-08 09:49:54 -07:00
Jieying Luo
ccc27a7a5f Remove PJRT version check in memories_test.py that is no longer needed.
0.43 is the version at 2024 Feb. Cloud TPU CI uses 20240228 so it should contain the PJRT C API needed for the test d3b6066f91/.github/workflows/cloud-tpu-ci-nightly.yml (L35).

PiperOrigin-RevId: 660869710
2024-08-08 09:35:41 -07:00
Matthew Johnson
44ae9b30ec fix #22944 2024-08-08 16:19:19 +00:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Adam Paszke
04a753ad02 [Mosaic TPU] Improve an error message in case someone tries to extract a non-32-bit scalar.
PiperOrigin-RevId: 660826696
2024-08-08 07:22:10 -07:00
Dan Foreman-Mackey
595ca0affa Improve error message for missing vmap rule in custom_vmap.
This is a partial re-land of https://github.com/google/jax/pull/22869
after it was rolled back to fix internal users. This part of the change
didn't cause the issues, and I'll follow up with the rest of the changes
in a second PR.
2024-08-08 14:08:51 +01:00
Jake VanderPlas
551f72979c Rollback of #22869
This is causing breakages due to overly-restrictive checks on kwargs

Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be

PiperOrigin-RevId: 660803968
2024-08-08 06:00:17 -07:00
Jake VanderPlas
4ca341701f Improve documentation for jnp.piecewise & jnp.select 2024-08-08 05:53:03 -07:00
jax authors
9fbc51bfad Merge pull request #22923 from Rifur13:faster
PiperOrigin-RevId: 660736990
2024-08-08 01:44:42 -07:00
Adam Paszke
42fe45f34b [Mosaic TPU] Add support for removal of implicit 2nd minor for all 32-bit tilings
PiperOrigin-RevId: 660724215
2024-08-08 01:00:32 -07:00
jax authors
0630139da2 Merge pull request #22925 from google:doc_update
PiperOrigin-RevId: 660721790
2024-08-08 00:50:12 -07:00
Yash Katariya
7f8a4c84d3 Remove PositionalSharding from distributed array doc 2024-08-07 21:25:24 -07:00
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Gleb Pobudzey
e6425a2c67 Small performance improvement to pallas MHA 2024-08-07 23:20:19 +00:00
jax authors
7efca0490f Merge pull request #22920 from jakevdp:fix-lint
PiperOrigin-RevId: 660570457
2024-08-07 16:01:09 -07:00
jax authors
a57d6591ee Update XLA dependency to use revision
3bf7e1ae48.

PiperOrigin-RevId: 660570144
2024-08-07 15:57:41 -07:00
Jake VanderPlas
53af0d4d90 CI: fix mypy errors 2024-08-07 15:15:45 -07:00
jax authors
de02988e94 Merge pull request #22909 from ROCm:ci_fix_solver_paths
PiperOrigin-RevId: 660515208
2024-08-07 13:26:17 -07:00
jax authors
cce725059a Merge pull request #22830 from kaixih:support_vmap
PiperOrigin-RevId: 660509938
2024-08-07 13:12:59 -07:00
jax authors
d3b6066f91 Merge pull request #22820 from Rifur13:mha-faster
PiperOrigin-RevId: 660461104
2024-08-07 11:11:15 -07:00
jax authors
32131d0288 Merge pull request #22897 from jakevdp:bool-indexing
PiperOrigin-RevId: 660444193
2024-08-07 10:30:41 -07:00
Sergei Lebedev
6fc57c0eb6 Rolling forward #22836
This version, proposed by @dfm, does not have a custom JVP for the whole
logsumexp and instead fixes #22398 directly.

Reverts e416c6675acfd82866a6e83e8c221640c4d02f29

PiperOrigin-RevId: 660438802
2024-08-07 10:17:55 -07:00
jax authors
893ae6eb80 Merge pull request #22869 from dfm:custom-batching-polish
PiperOrigin-RevId: 660421503
2024-08-07 09:40:46 -07:00
jax authors
5cb9510f60 Merge pull request #22908 from gnecula:pallas_warn
PiperOrigin-RevId: 660421476
2024-08-07 09:37:15 -07:00
jax authors
930c8ca791 Merge pull request #22914 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 660421322
2024-08-07 09:33:43 -07:00
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
Sergei Lebedev
3a1567f57a Do not run nn_test under asan -- it times out
PiperOrigin-RevId: 660377176
2024-08-07 07:14:27 -07:00
rajasekharporeddy
3095c570b8 Better docs for jnp.fft.rfft2 and jnp.fft.irfft2 2024-08-07 17:59:53 +05:30
George Necula
3e5e947542 Move some backwards compatibility tests from jax_triton to jax/pallas.
While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu`

PiperOrigin-RevId: 660341331
2024-08-07 05:00:29 -07:00
Sergei Lebedev
28ca734d9b Added another boxDim check to mosaic_gpu_init_tma_desc
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
George Necula
64eb8e9639 [pallas] Add a warning message about experimental and incomplete status 2024-08-07 08:38:56 +03:00
Sharad Vikram
803453ed74 [Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop
PiperOrigin-RevId: 660238073
2024-08-06 22:33:15 -07:00
Yash Katariya
dd958adc39 Add mesh_shape to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.
Since `shardy`, sharding in types work, world 2 dagger is going in a direction of making Mesh and PartitionSpec a first class sharding type, let's pull the trigger right now to start fixing these bad user interactions.

Some things that will break due to this change: Before passing NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. But now we will cache miss. In other words: `f(ns); f(ps) # cache hit before`

In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses

PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
Yue Sheng
7f44edc01e Change log level of clearing JAX backend caches from info to debug.
PiperOrigin-RevId: 660141868
2024-08-06 16:27:56 -07:00
jax authors
798297af98 Update XLA dependency to use revision
08b8d938eb.

PiperOrigin-RevId: 660133285
2024-08-06 16:05:14 -07:00
jax authors
53ab5eb24f Merge pull request #22900 from jakevdp:dep-bfloat16
PiperOrigin-RevId: 660102762
2024-08-06 14:42:43 -07:00
jax authors
9074e8544f Add test for zero-sized host memory parameter
PiperOrigin-RevId: 660097039
2024-08-06 14:31:41 -07:00
jax authors
aec6efb44b Merge pull request #22649 from ROCm:ci_jax_export_harness
PiperOrigin-RevId: 660096296
2024-08-06 14:27:13 -07:00
jax authors
cc9665749f Merge pull request #22901 from ROCm:ci_test_harness_vmap
PiperOrigin-RevId: 660089572
2024-08-06 14:04:57 -07:00
Jieying Luo
abe7982d65 Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc.
The plugin is released and the flag is no longer needed.

Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.

PiperOrigin-RevId: 660059432
2024-08-06 12:39:45 -07:00
Kanglan Tang
ae541203bc Skip flaky test_weight_offload_with_dp_on_output test on GPU backend.
PiperOrigin-RevId: 660057950
2024-08-06 12:35:53 -07:00
Ruturaj4
707cdd4706 [ROCM] Fix hipsolverSsyevd tests due to align with the rocm behavior. 2024-08-06 14:10:09 -05:00