26414 Commits

Author SHA1 Message Date
charleshofer
1f2fe33091
Enable upstream CI on release branches (#296) 2025-03-18 17:22:37 -05:00
charleshofer
dd7f96b27c
Fix ROCm build README (#284) 2025-03-18 14:35:36 -05:00
charleshofer
c46b4fc02b
Merge pull request #294 from ROCm/ci-upstream-sync-151_1
CI: 03/18/25 upstream sync
2025-03-18 11:10:16 -05:00
Charles Hofer
c7b407c9f0 Merge branch 'rocm-main' into ci-upstream-sync-151_1 2025-03-18 15:27:35 +00:00
Benjamin Chetioui
1e36cbe597 [Mosaic GPU] Raise a NotImplementedError if swizzle=16.
Unswizzled MMAs don't lower correctly, and are not currently intended to be
supported.

PiperOrigin-RevId: 737981373
2025-03-18 06:29:13 -07:00
Adam Paszke
8da93249d2 [Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS,
which is currently cluttered with LLVM trying to align slices to
start at bit 0 and failing to CSE the right shifts.

PiperOrigin-RevId: 737967890
2025-03-18 05:38:49 -07:00
jax authors
7a459f0ed1 Update XLA dependency to use revision
3bb7654721.

PiperOrigin-RevId: 737959582
2025-03-18 05:01:30 -07:00
Benjamin Chetioui
ba2f7c9ad9 [Mosaic GPU] Add transform inference rule for mgpu.slice_smem.
PiperOrigin-RevId: 737957778
2025-03-18 04:53:54 -07:00
Adam Paszke
d4bd2570ae [Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts
PiperOrigin-RevId: 737956598
2025-03-18 04:47:51 -07:00
Chris Jones
38d52a19ef [mosaic_gpu] Force flush all cupti activity, then unsubscribe.
With default flushing, it is possible for events to be missed. We should only unsubscribe after we are finished with cupti.

PiperOrigin-RevId: 737939327
2025-03-18 03:35:03 -07:00
Adam Paszke
34cd5b0d74 [Mosaic GPU] Remove sub-byte conversion restriction
XLA:GPU recently changed its endianness to little endian to better match LLVM
and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.
PiperOrigin-RevId: 737934373
2025-03-18 03:13:21 -07:00
Yash Katariya
549973dec6 Allow pspec to be passed to device_put if there is a mesh in the surrounding context
PiperOrigin-RevId: 737812111
2025-03-17 17:47:56 -07:00
Changhui Lin
f174b00f23 Replace the uses of PjRtClient::Compile() with PjRtClient::CompileAndLoad().
This is to prepare for updating `PjRtClient::Compile()` to return an unloaded executable [1/N]

PiperOrigin-RevId: 737805623
2025-03-17 17:18:31 -07:00
Emily Fertig
8c35191725 Enable jax.device_put to a sharding with no local devices.
PiperOrigin-RevId: 737797815
2025-03-17 16:49:46 -07:00
Sergei Lebedev
051687dc4c [pallas] pallas_call_p is now parameterized by a mesh
The mesh is necessary to add support for clusters to the Mosaic GPU backend.

PiperOrigin-RevId: 737792129
2025-03-17 16:30:40 -07:00
jax authors
b4966130a3 Compute tile index using tile-based coordinates
This reduces the chances of overflowing a 32-bit integer when computing tile indices.
Add unit test to reproduce the overflow with the previous implementation of `blocked_fold_in`.

PiperOrigin-RevId: 737778853
2025-03-17 15:46:27 -07:00
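A minimal sketch of why tile-based coordinates help, assuming a 2D array split into square tiles. The function names, array shape, and tile size are illustrative and are not taken from `blocked_fold_in`. Python integers do not overflow, so the sketch compares the results against the 32-bit limit.

```python
INT32_MAX = 2**31 - 1

def element_based_index(row, col, num_cols):
    """Linear index of the element at (row, col), e.g. a tile's origin."""
    return row * num_cols + col

def tile_based_index(row, col, num_cols, tile_rows, tile_cols):
    """Linear index of the tile that contains element (row, col)."""
    tiles_per_row = -(-num_cols // tile_cols)  # ceiling division
    return (row // tile_rows) * tiles_per_row + (col // tile_cols)

# Hypothetical sizes: a 2**17 x 2**17 array with 128 x 128 tiles.
num_rows = num_cols = 2**17
tile = 128
last_row, last_col = num_rows - tile, num_cols - tile  # origin of the last tile

elem_idx = element_based_index(last_row, last_col, num_cols)
tile_idx = tile_based_index(last_row, last_col, num_cols, tile, tile)

print(elem_idx > INT32_MAX)  # the element-based index needs more than 32 bits
print(tile_idx <= INT32_MAX)  # the tile-based index stays small
```

Indexing by tile divides each coordinate by the tile size before multiplying, so the product grows with the number of tiles and not with the number of elements.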
jax authors
b74b16f9b9 Merge pull request #27164 from MichaelHudgins:a4-testing
PiperOrigin-RevId: 737733904
2025-03-17 13:36:04 -07:00
Michael Hudgins
ecf7fde714 Add B200 testing to continuous workflow 2025-03-17 20:19:20 +00:00
jax authors
4f70471310 Fix error in pallas tutorial
PiperOrigin-RevId: 737727935
2025-03-17 13:19:12 -07:00
Peter Hawkins
20658fabb3 Replace cached function get_replicated_hlo_sharding() with a constant.
Small cleanup, no functional changes intended.

PiperOrigin-RevId: 737727727
2025-03-17 13:17:33 -07:00
jax authors
ebcae0d30a Merge pull request #26980 from carlosgmartin:categorical_replace
PiperOrigin-RevId: 737720590
2025-03-17 12:58:01 -07:00
Peter Hawkins
be5d13af77 Remove code that preserved _original_py_fns on C++ classes.
This no longer appears to be used.

PiperOrigin-RevId: 737715578
2025-03-17 12:43:04 -07:00
Benjamin Chetioui
9a686e0bf3 [Mosaic GPU] Add initial transform inference rules for vector.{load,store}.
PiperOrigin-RevId: 737703568
2025-03-17 12:08:07 -07:00
carlosgmartin
3f59fa6888 Add replace option to random.categorical to enable sampling without replacement. 2025-03-17 13:41:46 -04:00
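One common way to sample categories without replacement is the Gumbel top-k trick: perturb each logit with independent Gumbel noise and keep the indices of the k largest values. The sketch below uses only the Python standard library. It illustrates the general technique, not the code in this commit, and `sample_without_replacement` is a hypothetical name.

```python
import math
import random

def sample_without_replacement(logits, k, rng):
    """Draw k distinct category indices using the Gumbel top-k trick."""
    if k > len(logits):
        raise ValueError("cannot draw more distinct samples than categories")

    def gumbel():
        u = rng.random()
        while u == 0.0:  # rng.random() is in [0, 1); avoid log(0)
            u = rng.random()
        return -math.log(-math.log(u))

    perturbed = [logit + gumbel() for logit in logits]
    # Indices sorted by perturbed logit, largest first.
    order = sorted(range(len(logits)), key=lambda i: perturbed[i], reverse=True)
    return order[:k]

rng = random.Random(0)
picks = sample_without_replacement([0.1, 2.0, -1.0, 0.5], k=3, rng=rng)
print(picks)  # three distinct indices in the range 0..3
```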
Mathew Odden
d864b4fbf4
Fix auditwheel version issue (#288)
Auditwheel 6.3.0 changed/removed the lddtree function
so cap constraint to 6.2.x
2025-03-17 12:30:56 -05:00
jax authors
de9ad6bad9 Merge pull request #27157 from mar-muel:improve-random-choice-performance
PiperOrigin-RevId: 737665351
2025-03-17 10:30:15 -07:00
Nitin Srinivasan
031614c22b Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt
PiperOrigin-RevId: 737632771
2025-03-17 08:59:06 -07:00
Adam Paszke
3649da56fc [Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length
We can now perform the conversion in groups of 2, 4 or even 8 elements at a time.

PiperOrigin-RevId: 737626600
2025-03-17 08:37:17 -07:00
Sergei Lebedev
0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any clearer and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
jax authors
55812c5d02 Update XLA dependency to use revision
fcf97e619e.

PiperOrigin-RevId: 737581187
2025-03-17 05:44:17 -07:00
Sergei Lebedev
a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
Adam Paszke
89b21de62a [Mosaic GPU] Add support for changing the layout before the upcast
This lets us save on 2 ALU instructions (3x select becomes 1x prmt).

PiperOrigin-RevId: 737550598
2025-03-17 03:26:48 -07:00
Adam Paszke
2bdd9c8797 [Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16-bit upcast
PiperOrigin-RevId: 737542885
2025-03-17 02:52:16 -07:00
jax authors
761b35c59e Merge pull request #27176 from jakevdp:lax-docs
PiperOrigin-RevId: 737338493
2025-03-16 05:39:55 -07:00
jax authors
e8b683aee0 Update XLA dependency to use revision
936a727db7.

PiperOrigin-RevId: 737338103
2025-03-16 05:37:58 -07:00
Joan Puigcerver
466ef6a132 Change the way that batching.spec_types is updated.
There's no reason why two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions.

Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).

When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.

PiperOrigin-RevId: 737280270
2025-03-15 22:58:44 -07:00
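The failure mode and the fix described in this commit message can be sketched with a small stand-in registry. `VmappableRegistry` and its methods are hypothetical and do not mirror the real `batching` module API. Only the idea of regenerating the set from the remaining registrations comes from the message.

```python
class VmappableRegistry:
    """Toy registry mapping data types to (possibly shared) spec types."""

    def __init__(self):
        self._spec_type_of = {}  # data_type -> spec_type
        self.spec_types = set()

    def register(self, data_type, spec_type):
        self._spec_type_of[data_type] = spec_type
        self.spec_types.add(spec_type)

    def unregister(self, data_type):
        del self._spec_type_of[data_type]
        # Rebuild the set from what is still registered. Calling
        # spec_types.remove(spec_type) here would drop a spec type that
        # another data type still relies on, and would raise KeyError on
        # the second unregistration.
        self.spec_types = set(self._spec_type_of.values())


class SharedSpec: ...
class TypeA: ...
class TypeB: ...

registry = VmappableRegistry()
registry.register(TypeA, SharedSpec)
registry.register(TypeB, SharedSpec)

registry.unregister(TypeA)
print(SharedSpec in registry.spec_types)  # still needed by TypeB

registry.unregister(TypeB)
print(SharedSpec in registry.spec_types)  # nothing uses it any more
```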
Jake VanderPlas
de8b0564ce Better docs for jax.lax add/sub/mul/div 2025-03-15 11:49:51 -07:00
jax authors
f360e19194 Update XLA dependency to use revision
f52d5e03ce.

PiperOrigin-RevId: 737143950
2025-03-15 05:10:35 -07:00
Ayaka
9b0ace4a11 Support error checking in explicit mode
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00
jax authors
d07d642d6f Merge pull request #27177 from jax-ml:mixing_modes
PiperOrigin-RevId: 737047069
2025-03-14 18:34:27 -07:00
Yash Katariya
3c0027af3b mixing modes 2025-03-14 18:23:27 -07:00
jax authors
7db59cdcca Merge pull request #27174 from mattjj:opt-barrier-ad-rules
PiperOrigin-RevId: 737040381
2025-03-14 17:59:07 -07:00
Peter Hawkins
14cb7453f0 Add a C++ implementation of a topological sort.
This is an exact port of the current Python implementation to C++ for speed.

I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.

PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
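For readers unfamiliar with the operation, here is a generic topological sort using Kahn's algorithm. It is not the Python implementation that this commit ports, and it makes no claim to return the same order. The input format, a mapping from each node to its parents, is an assumption made for the example.

```python
from collections import deque

def topological_sort(parents):
    """Order nodes so that every node appears after all of its parents.

    `parents` maps each node to the nodes that must precede it; every
    parent must itself be a key of the mapping.
    """
    # De-duplicate parents while keeping a deterministic order.
    unique_parents = {n: list(dict.fromkeys(ps)) for n, ps in parents.items()}
    remaining = {n: len(ps) for n, ps in unique_parents.items()}
    children = {n: [] for n in parents}
    for node, ps in unique_parents.items():
        for p in ps:
            children[p].append(node)

    ready = deque(n for n, count in remaining.items() if count == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for child in children[node]:
            remaining[child] -= 1
            if remaining[child] == 0:
                ready.append(child)

    if len(order) != len(parents):
        raise ValueError("graph contains a cycle")
    return order

graph = {"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"]}
order = topological_sort(graph)
print(order)  # ['a', 'b', 'c', 'd']
```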
github-actions[bot]
7a172d7010
Merge pull request #285 from ROCm/ci-upstream-sync-147_1
CI: 03/14/25 upstream sync
2025-03-14 17:59:16 -05:00
Charles Hofer
7a6940bb7b Trivial change for CI 2025-03-14 22:45:17 +00:00
GitHub Actions
e275d5cf6c Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-147_1 2025-03-14 22:42:07 +00:00
Matthew Johnson
dadc68b6c1 add experimental lax.optimization_barrier autodiff rules 2025-03-14 22:40:55 +00:00
jax authors
b00a3a1986 Merge pull request #27015 from mattjj:direct-linearize-fixes-4
PiperOrigin-RevId: 737003323
2025-03-14 15:24:11 -07:00
Sergei Lebedev
64230d1c93 [pallas:mosaic_gpu] WG lowering now supports while_p
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
charleshofer
022da913e6
Count test totals correctly for dashboards (#280)
* Account for test totals correctly for dashboards

* Add blurb to the dev guide on skipping tests

* Remove extra newline

* Default to 0 if "skipped" isn't found

Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>

2025-03-14 16:57:08 -05:00
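The "default to 0 if `skipped` isn't found" bullet in the commit above can be illustrated with a short sketch. The report structure and the `count_totals` helper are hypothetical and are not taken from the repository's dashboard scripts.

```python
def count_totals(summaries):
    """Sum per-run test counts, treating a missing key as zero."""
    totals = {"passed": 0, "failed": 0, "skipped": 0}
    for summary in summaries:
        for key in totals:
            # dict.get returns 0 when a report omits the key, e.g. when
            # no tests were skipped and "skipped" is absent.
            totals[key] += summary.get(key, 0)
    totals["total"] = totals["passed"] + totals["failed"] + totals["skipped"]
    return totals

reports = [
    {"passed": 10, "failed": 1, "skipped": 2},
    {"passed": 5, "failed": 0},  # no "skipped" entry in this report
]
totals = count_totals(reports)
print(totals)  # {'passed': 15, 'failed': 1, 'skipped': 2, 'total': 18}
```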