Sergei Lebedev
d8eafc8ee3
Disabled nn_test under asan on TPU as well, since it also times out
...
PiperOrigin-RevId: 660950262
2024-08-08 13:02:31 -07:00
Dan Foreman-Mackey
efb7721671
Remove unnecessary constraint on keyword-only arguments in custom_vjp
with optimize_remat=True
.
...
PiperOrigin-RevId: 660945559
2024-08-08 12:49:27 -07:00
jax authors
93d4629846
Merge pull request #22903 from jakevdp:update-array-api
...
PiperOrigin-RevId: 660941835
2024-08-08 12:39:56 -07:00
Jake VanderPlas
d999208863
[array API] update test suite to most recent commit
2024-08-08 12:33:30 -07:00
Jieying Luo
751b5742fd
Deprecate using build_cuda_plugin_from_source flag and rely on jaxlib_build config.
...
If jaxlib needs to be built from source, cuda plugin will be built from source as well.
PiperOrigin-RevId: 660926791
2024-08-08 11:58:13 -07:00
Yash Katariya
e6303244bf
If the memory kind is the default kind throughout the jaxpr, then revert back to the previous device_put behavior which was a no-op inside jit.
...
This is also the same behavior for arguments and outputs, where we don't insert `mhlo.memory_kind` attributes in the stableHLO if the entire jaxpr only has the default memory kind.
PiperOrigin-RevId: 660913387
2024-08-08 11:24:25 -07:00
jax authors
bdd8f74efe
Merge pull request #22916 from jakevdp:piecewise-doc
...
PiperOrigin-RevId: 660896267
2024-08-08 10:45:09 -07:00
jax authors
0309adf2a5
Merge pull request #22937 from dfm:custom-vmap-errors
...
PiperOrigin-RevId: 660880442
2024-08-08 10:05:34 -07:00
jax authors
647a2f75d3
Merge pull request #22947 from mattjj:22944
...
PiperOrigin-RevId: 660874340
2024-08-08 09:49:54 -07:00
Jieying Luo
ccc27a7a5f
Remove PJRT version check in memories_test.py that is no longer needed.
...
0.43 is the version at 2024 Feb. Cloud TPU CI uses 20240228 so it should contain the PJRT C API needed for the test d3b6066f91/.github/workflows/cloud-tpu-ci-nightly.yml (L35)
.
PiperOrigin-RevId: 660869710
2024-08-08 09:35:41 -07:00
Matthew Johnson
44ae9b30ec
fix #22944
2024-08-08 16:19:19 +00:00
Dan Foreman-Mackey
11d9c2de2c
Update GPU implementation of lu_pivots_to_permutation
to infer the permutation size directly from the input dimensions, instead of using an input parameter.
...
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.
In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.
PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Adam Paszke
04a753ad02
[Mosaic TPU] Improve an error message in case someone tries to extract a non-32-bit scalar.
...
PiperOrigin-RevId: 660826696
2024-08-08 07:22:10 -07:00
Dan Foreman-Mackey
595ca0affa
Improve error message for missing vmap rule in custom_vmap.
...
This is a partial re-land of https://github.com/google/jax/pull/22869
after it was rolled back to fix internal users. This part of the change
didn't cause the issues, and I'll follow up with the rest of the changes
in a second PR.
2024-08-08 14:08:51 +01:00
Jake VanderPlas
551f72979c
Rollback of #22869
...
This is causing breakages due to overly-restrictive checks on kwargs
Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be
PiperOrigin-RevId: 660803968
2024-08-08 06:00:17 -07:00
Jake VanderPlas
4ca341701f
Improve documentation for jnp.piecewise & jnp.select
2024-08-08 05:53:03 -07:00
jax authors
9fbc51bfad
Merge pull request #22923 from Rifur13:faster
...
PiperOrigin-RevId: 660736990
2024-08-08 01:44:42 -07:00
Adam Paszke
42fe45f34b
[Mosaic TPU] Add support for removal of implicit 2nd minor for all 32-bit tilings
...
PiperOrigin-RevId: 660724215
2024-08-08 01:00:32 -07:00
jax authors
0630139da2
Merge pull request #22925 from google:doc_update
...
PiperOrigin-RevId: 660721790
2024-08-08 00:50:12 -07:00
Yash Katariya
7f8a4c84d3
Remove PositionalSharding from distributed array doc
2024-08-07 21:25:24 -07:00
Yash Katariya
be53ee10b1
Set jax_enable_memories
flag to True
by default
...
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Gleb Pobudzey
e6425a2c67
Small performance improvement to pallas MHA
2024-08-07 23:20:19 +00:00
jax authors
7efca0490f
Merge pull request #22920 from jakevdp:fix-lint
...
PiperOrigin-RevId: 660570457
2024-08-07 16:01:09 -07:00
jax authors
a57d6591ee
Update XLA dependency to use revision
...
3bf7e1ae48
.
PiperOrigin-RevId: 660570144
2024-08-07 15:57:41 -07:00
Jake VanderPlas
53af0d4d90
CI: fix mypy errors
2024-08-07 15:15:45 -07:00
jax authors
de02988e94
Merge pull request #22909 from ROCm:ci_fix_solver_paths
...
PiperOrigin-RevId: 660515208
2024-08-07 13:26:17 -07:00
jax authors
cce725059a
Merge pull request #22830 from kaixih:support_vmap
...
PiperOrigin-RevId: 660509938
2024-08-07 13:12:59 -07:00
jax authors
d3b6066f91
Merge pull request #22820 from Rifur13:mha-faster
...
PiperOrigin-RevId: 660461104
2024-08-07 11:11:15 -07:00
jax authors
32131d0288
Merge pull request #22897 from jakevdp:bool-indexing
...
PiperOrigin-RevId: 660444193
2024-08-07 10:30:41 -07:00
Sergei Lebedev
6fc57c0eb6
Rolling forward #22836
...
This version, proposed by @dfm, does not have a custom JVP for the whole
logsumexp and instead fixes #22398 directly.
Reverts e416c6675acfd82866a6e83e8c221640c4d02f29
PiperOrigin-RevId: 660438802
2024-08-07 10:17:55 -07:00
jax authors
893ae6eb80
Merge pull request #22869 from dfm:custom-batching-polish
...
PiperOrigin-RevId: 660421503
2024-08-07 09:40:46 -07:00
jax authors
5cb9510f60
Merge pull request #22908 from gnecula:pallas_warn
...
PiperOrigin-RevId: 660421476
2024-08-07 09:37:15 -07:00
jax authors
930c8ca791
Merge pull request #22914 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 660421322
2024-08-07 09:33:43 -07:00
Ruturaj4
a2d79936df
[ROCM] Fix BUILD.bazel library source paths
2024-08-07 09:18:20 -05:00
Sergei Lebedev
3a1567f57a
Do not run nn_test under asan -- it times out
...
PiperOrigin-RevId: 660377176
2024-08-07 07:14:27 -07:00
rajasekharporeddy
3095c570b8
Better docs for jnp.fft.rfft2 and jnp.fft.irfft2
2024-08-07 17:59:53 +05:30
George Necula
3e5e947542
Move some backwards compatibility tests from jax_triton to jax/pallas.
...
While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu`
PiperOrigin-RevId: 660341331
2024-08-07 05:00:29 -07:00
Sergei Lebedev
28ca734d9b
Added another boxDim check to mosaic_gpu_init_tma_desc
...
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
George Necula
64eb8e9639
[pallas] Add a warning message about experimental and incomplete status
2024-08-07 08:38:56 +03:00
Sharad Vikram
803453ed74
[Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop
...
PiperOrigin-RevId: 660238073
2024-08-06 22:33:15 -07:00
Yash Katariya
dd958adc39
Add mesh_shape
to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.
...
Since `shardy`, sharding in types work, world 2 dagger is going in a direction of making Mesh and PartitionSpec a first class sharding type, let's pull the trigger right now to start fixing these bad user interactions.
Some things that will break due to this change: Before passing NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. But now we will cache miss. In other words: `f(ns); f(ps) # cache hit before`
In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses
PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
Yue Sheng
7f44edc01e
Change log level of clearing JAX backend caches from info to debug.
...
PiperOrigin-RevId: 660141868
2024-08-06 16:27:56 -07:00
jax authors
798297af98
Update XLA dependency to use revision
...
08b8d938eb
.
PiperOrigin-RevId: 660133285
2024-08-06 16:05:14 -07:00
jax authors
53ab5eb24f
Merge pull request #22900 from jakevdp:dep-bfloat16
...
PiperOrigin-RevId: 660102762
2024-08-06 14:42:43 -07:00
jax authors
9074e8544f
Add test for zero-sized host memory parameter
...
PiperOrigin-RevId: 660097039
2024-08-06 14:31:41 -07:00
jax authors
aec6efb44b
Merge pull request #22649 from ROCm:ci_jax_export_harness
...
PiperOrigin-RevId: 660096296
2024-08-06 14:27:13 -07:00
jax authors
cc9665749f
Merge pull request #22901 from ROCm:ci_test_harness_vmap
...
PiperOrigin-RevId: 660089572
2024-08-06 14:04:57 -07:00
Jieying Luo
abe7982d65
Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc.
...
The plugin is released and the flag is no longer needed.
Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.
PiperOrigin-RevId: 660059432
2024-08-06 12:39:45 -07:00
Kanglan Tang
ae541203bc
Skip flaky test_weight_offload_with_dp_on_output test on GPU backend.
...
PiperOrigin-RevId: 660057950
2024-08-06 12:35:53 -07:00
Ruturaj4
707cdd4706
[ROCM] Fix hipsolverSsyevd tests due to align with the rocm behavior.
2024-08-06 14:10:09 -05:00