23755 Commits

Author SHA1 Message Date
Yash Katariya
2153de4ce0 [sharding_in_types] If out_aval.sharding is not None and the user-specified out_sharding is None, concretize it with the available device assignment and add it to the final out_shardings used for lowering and compilation.
This will allow us to return the exact sharding spec that sharding propagation rules figured out.

PiperOrigin-RevId: 687349015
2024-10-18 10:27:58 -07:00
Christos Perivolaropoulos
f8a3c0366b [pallas] run_scoped now supports partial discharge.
PiperOrigin-RevId: 687347284
2024-10-18 10:22:31 -07:00
Benjamin Chetioui
ade480ff05 Add a dialect for Mosaic GPU.
PiperOrigin-RevId: 687325692
2024-10-18 09:11:31 -07:00
jax authors
eba5748094 Disable breaking test-case
PiperOrigin-RevId: 687320199
2024-10-18 08:54:36 -07:00
Adam Paszke
e138e8e49d [Pallas:MGPU] Fix docstring for commit_shared
PiperOrigin-RevId: 687308732
2024-10-18 08:16:55 -07:00
Tom Hennigan
86155561fb nit: Use frozen dataclasses rather than unsafe_hash.
PiperOrigin-RevId: 687267707
2024-10-18 05:35:54 -07:00
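The change above swaps `unsafe_hash=True` for `frozen=True`. A minimal sketch of why frozen dataclasses are the safer route to hashability (class names here are illustrative, not from the JAX change):

```python
from dataclasses import dataclass

# unsafe_hash=True makes a mutable class hashable: mutating a field
# after insertion silently corrupts any set/dict the instance lives in.
@dataclass(unsafe_hash=True)
class MutableKey:
    x: int

# frozen=True provides __hash__ *and* rejects mutation, so the hash
# can never drift after the instance is used as a key.
@dataclass(frozen=True)
class FrozenKey:
    x: int

m = MutableKey(1)
s = {m}
m.x = 2          # allowed -- and now lookup in `s` silently fails
assert m not in s

f = FrozenKey(1)
assert f in {FrozenKey(1)}  # value-based hash/eq still works
try:
    f.x = 2      # frozen instances refuse mutation
except Exception as e:
    assert type(e).__name__ == "FrozenInstanceError"
```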
Adam Paszke
4094564815 [Pallas:MGPU] Force alignment of SMEM allocations to 1024 bytes
This is to avoid issues when small buffers throw off the alignment for large TMA and WGMMA
operands. We should make this more refined in the future, but this should be enough for now.

PiperOrigin-RevId: 687264994
2024-10-18 05:21:53 -07:00
Adam Paszke
0ee9531ef2 [Pallas:MGPU] Add support for indexed refs to WGMMA
PiperOrigin-RevId: 687258992
2024-10-18 04:55:34 -07:00
Adam Paszke
f2edc83af3 [Pallas:MGPU] Properly commute indexing with other transforms
Doing so requires us to modify the other transforms when we attempt to
move indexing before them.

PiperOrigin-RevId: 687240515
2024-10-18 03:39:51 -07:00
Yash Katariya
4db212d2c6 Add _sharding argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.
This is required because `jax.nn.one_hot` calls into `broadcasted_iota`.

PiperOrigin-RevId: 687152343
2024-10-17 21:16:51 -07:00
jax authors
dd5426301a Allow a simple host call that uses a host tensor as a parameter/result in
linear layout. This CL only handles very simple host call patterns.
A more thorough implementation of propagation of T(1)S(5) will be done
later.

This CL doesn't handle host calls that pass/return tensors living
on device with a linear layout either; that will also be implemented
separately.

PiperOrigin-RevId: 687113203
2024-10-17 18:22:46 -07:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions, with some caveats. By default, we use some heuristics based on the matrix sizes to select the algorithm, and the three algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
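One of the caveats above is the `full_matrices=False` mode that the batched Jacobi kernel doesn't support. Illustrated here with NumPy, whose `svd` convention `jax.numpy` follows (a sketch of the shape difference, not the GPU kernel itself):

```python
import numpy as np

a = np.arange(15.0).reshape(5, 3)

# full_matrices=True: U is square (5, 5). Per the note above, this is
# the only mode the batched Jacobi GPU kernel supports.
u_full, s, vh = np.linalg.svd(a, full_matrices=True)

# full_matrices=False: U is economy-sized (5, 3), matching the number
# of singular values, which is enough to reconstruct `a` exactly.
u_econ, s2, vh2 = np.linalg.svd(a, full_matrices=False)

print(u_full.shape, u_econ.shape, vh.shape)  # (5, 5) (5, 3) (3, 3)
assert np.allclose(u_econ @ np.diag(s2) @ vh2, a)
```

The other behavioral split mentioned above (QR returning V^H where Jacobi returns V) is an internal kernel convention that the lowering has to normalize before results reach the user.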
Yash Katariya
3e634d9530 [sharding_in_types] Add lax.transpose sharding propagation rule
PiperOrigin-RevId: 687094297
2024-10-17 17:08:04 -07:00
Yash Katariya
57a95a77ff [sharding_in_types] Support jnp.array with sharding_in_types. When the input array has a sharding, propagate it through without dropping the sharding.
PiperOrigin-RevId: 687089357
2024-10-17 16:51:41 -07:00
Yash Katariya
5df4878ad0 [sharding_in_types] Add reduce max, integer_pow and standard_unop sharding rules
PiperOrigin-RevId: 687073144
2024-10-17 15:55:29 -07:00
Yash Katariya
e92e1191b3 [sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
2024-10-17 14:55:10 -07:00
jax authors
93389ab5f4 Update XLA dependency to use revision
70df652679.

PiperOrigin-RevId: 687045334
2024-10-17 14:29:44 -07:00
jax authors
919f7c8684 Merge pull request #24345 from phu0ngng:cuda_custom_call
PiperOrigin-RevId: 687034466
2024-10-17 13:57:15 -07:00
Adam Paszke
2d78b17226 [Pallas:MGPU] Add support for transforms in user-specified async copies
PiperOrigin-RevId: 687019020
2024-10-17 13:10:45 -07:00
jax authors
6c2649fdf2 Rewrite Mosaic concat to support operand shapes that do not align with native shapes; expand tests to cover multi-operand concat, batch-dim concat, etc.
PiperOrigin-RevId: 687003778
2024-10-17 12:24:51 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, a tristate config that sets up a `jax.Array` garbage collection guard. The possible values are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed, a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant for mature code bases that do tight memory management and are free of reference cycles.

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
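JAX's guard hooks into array destruction internally; the "log on collection" idea itself can be sketched with the standard library's `weakref.finalize`, which runs a callback when an object is garbage collected (an illustrative stand-in, not the JAX implementation):

```python
import gc
import weakref

class Tracked:
    """Stand-in for an object whose collection we want to observe."""

collected = []

def make_and_drop():
    obj = Tracked()
    # Register a finalizer: it fires when `obj` is garbage collected,
    # which is where a "log" or "fatal" guard policy would trigger.
    weakref.finalize(obj, collected.append, "Tracked instance collected")
    # `obj` goes out of scope here with no surviving references.

make_and_drop()
gc.collect()  # force collection so the finalizer has certainly run
print(collected)  # ['Tracked instance collected']
```

In JAX itself the policy is chosen by setting `jax.config.array_garbage_collection_guard` to `allow`, `log`, or `fatal` as described in the commit message.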
jax authors
1b5cf5a494 Fix breaking test-case
PiperOrigin-RevId: 686932281
2024-10-17 08:57:15 -07:00
Sergei Lebedev
de7beb91a7 [pallas:mosaic_gpu] Added layout_cast
PiperOrigin-RevId: 686917796
2024-10-17 08:08:05 -07:00
Adam Paszke
0519db15ab [Pallas:MGPU] Add lowerings for more ops
PiperOrigin-RevId: 686910947
2024-10-17 07:42:56 -07:00
Adam Paszke
83b61cc988 [Pallas:MGPU] Expose mgpu.commit_shared
PiperOrigin-RevId: 686906355
2024-10-17 07:27:06 -07:00
Adam Paszke
f72376ae0a [Pallas:MGPU] Add support for debug_print of arrays that use the WGMMA layout
PiperOrigin-RevId: 686885229
2024-10-17 06:06:16 -07:00
Adam Paszke
ef361f05a4 [Mosaic GPU] Add support for launching multiple warpgroups using core_map
PiperOrigin-RevId: 686876014
2024-10-17 05:30:48 -07:00
jax authors
3bdc57dd29 Merge pull request #24300 from ROCm:ci_rocm_readme
PiperOrigin-RevId: 686872994
2024-10-17 05:21:13 -07:00
jax authors
36ec513a8d Merge pull request #24355 from gnecula:exp_fix_doc
PiperOrigin-RevId: 686836909
2024-10-17 02:50:15 -07:00
Tzu-Wei Sung
a01f187ec4 Reverts 30ba7f37e027dfe80f9aaad6cd52d9387dba8612
PiperOrigin-RevId: 686833554
2024-10-17 02:36:50 -07:00
jax authors
9027fb38fe Fix segfault
PiperOrigin-RevId: 686821923
2024-10-17 01:52:44 -07:00
jax authors
96d5542aae Support single-process AutoPGLE usage.
PiperOrigin-RevId: 686819261
2024-10-17 01:43:58 -07:00
jax authors
2e5920db76 Merge pull request #24326 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 686807960
2024-10-17 00:59:32 -07:00
George Necula
9aa79bffba [export] Fix github links in the export documentation
Reflects the repo change google/jax -> jax-ml/jax.
Also changes the error message to put the link to the documentation
in a more visible place.
2024-10-17 08:30:28 +01:00
rajasekharporeddy
aaefa82230 Better docs for jax.numpy:equal and not_equal 2024-10-17 07:59:04 +05:30
jax authors
f332dd561a Merge pull request #24351 from jakevdp:sinc-doc
PiperOrigin-RevId: 686715954
2024-10-16 18:40:56 -07:00
jax authors
7d7b619ec2 Merge pull request #24348 from jakevdp:fix-unused-imports
PiperOrigin-RevId: 686706274
2024-10-16 18:03:52 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
jax authors
6a00055980 Merge pull request #24347 from jakevdp:ruff-rules
PiperOrigin-RevId: 686693247
2024-10-16 17:10:36 -07:00
Jake VanderPlas
e1f280c843 CI: enable additional ruff formatting checks 2024-10-16 16:09:54 -07:00
Jake VanderPlas
6ca7228d0d Improve docs for jnp.sinc 2024-10-16 16:08:39 -07:00
jax authors
4d6064beee Merge pull request #24343 from jakevdp:fix-pyi
PiperOrigin-RevId: 686674487
2024-10-16 16:05:26 -07:00
Ruturaj4
3c3b08dfd6 [ROCm] Fix README.md to update AMD JAX installation instructions 2024-10-16 17:15:32 -05:00
jax authors
ebac2e4421 Merge pull request #24323 from selamw1:tile_doc
PiperOrigin-RevId: 686637204
2024-10-16 14:11:44 -07:00
jax authors
a5b312378f Update XLA dependency to use revision
a0a21d14c1.

PiperOrigin-RevId: 686636541
2024-10-16 14:09:36 -07:00
jax authors
1e2d4bc527 Merge pull request #24322 from andportnoy:aportnoy/skip-flash-attention-unless-sm90
PiperOrigin-RevId: 686635331
2024-10-16 14:06:37 -07:00
jax authors
089e4aa904 Merge pull request #24341 from phu0ngng:cuda_graph_ex
PiperOrigin-RevId: 686577115
2024-10-16 11:23:28 -07:00
jax authors
ead1c05ada Merge pull request #23831 from nouiz:doc_policies
PiperOrigin-RevId: 686576725
2024-10-16 11:21:41 -07:00
Jake VanderPlas
b574d2ceb1 Fix aliases in jax.numpy type interface file.
This includes removing some alias declarations for functions that were
previously removed.
2024-10-16 10:40:56 -07:00
Phuong Nguyen
d4bbb4fd84 added cmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:37:49 -07:00