charleshofer
1f93b4b9b5
Add Python and ROCm version matrix for rocm-main ( #314 )
2025-04-02 11:49:03 -05:00
charleshofer
7b20b7c0b7
Use rebase for upstream sync ( #325 )
2025-03-27 11:58:59 -05:00
Mathew Odden
ec4b8ee1ed
Fixes for 0.5.0 build ported to rocm-main
...
(cherry picked from commit c23a81461192a2b6da3d364076a261714d2dc64f)
2025-03-25 17:51:30 -05:00
charleshofer
13d88b6340
Add back raw totals in JSON reports ( #281 )
2025-03-24 11:26:31 -05:00
rocm-repo-management-api-2[bot]
b505df9973
Merge pull request #299 from ROCm/ci-upstream-sync-152_1
...
CI: 03/19/25 upstream sync
2025-03-19 07:20:19 -05:00
jax authors
e9ce8fb92d
Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4
...
PiperOrigin-RevId: 738235363
2025-03-18 20:22:27 -07:00
jax authors
f3b7c5cb9e
Integrate LLVM at llvm/llvm-project@0230d63b4a
...
Updates LLVM usage to match
[0230d63b4a8b](https://github.com/llvm/llvm-project/commit/0230d63b4a8b )
PiperOrigin-RevId: 738222096
2025-03-18 19:23:20 -07:00
Sharad Vikram
e949effcda
[Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary computations being staged out in block functions)
...
PiperOrigin-RevId: 738218113
2025-03-18 19:00:41 -07:00
Sharad Vikram
4d715753c4
Make sure to DCE read effects
...
PiperOrigin-RevId: 738215055
2025-03-18 18:42:14 -07:00
jax authors
8c7a55ea82
Update XLA dependency to use revision
...
df971129bd
.
PiperOrigin-RevId: 738213047
2025-03-18 18:33:23 -07:00
Yash Katariya
663ef7ae01
Check the type of mesh in use_abstract_mesh
and use_concrete_mesh
...
PiperOrigin-RevId: 738190879
2025-03-18 16:57:40 -07:00
Peter Hawkins
3f91b4b43a
Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
...
Move the common jaxlib/gpu_plugin_extension into jaxlib/gpu/
Cleanup only, no functional changes intended.
PiperOrigin-RevId: 738183402
2025-03-18 16:29:37 -07:00
jax authors
01a110c4c9
Better mosaic lowering for dynamic shapes, extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.
...
PiperOrigin-RevId: 738171437
2025-03-18 15:51:15 -07:00
charleshofer
1f2fe33091
Enable upstream CI on release branches ( #296 )
2025-03-18 17:22:37 -05:00
Parker Schuh
0fb59747f0
Support tuples in custom_partitioning.
...
PiperOrigin-RevId: 738154413
2025-03-18 14:57:08 -07:00
jax authors
080804c78d
Fix logging_test fails on Linux with NVIDIA Driver only.
...
Some GPU tests in //tests/logging_test fail on Linux with NVIDIA driver only when we use hermetic CUDA (CUDA isn't installed on Linux).
Reason: method tsl::Env::Default()->GetExecutablePath()` doesn't work properly with command flag (-c). As result subprocessor couldn't get path to logging_test.py file and convert it to path of runtime where CUDA hermetic libraries are placed.
Solution: Save python program to file in runtime directory then run script from the file.
PiperOrigin-RevId: 738152663
2025-03-18 14:51:35 -07:00
Gleb Pobudzey
54691b125a
[Mosaic GPU] Support reads/writes from SMEM to WGMMARowFragLayout arrays.
...
PiperOrigin-RevId: 738121106
2025-03-18 13:23:07 -07:00
Yash Katariya
76d9890bb7
Run the stream annotation tests on 2 devices so that it can be tested in TAP
...
PiperOrigin-RevId: 738113725
2025-03-18 13:01:48 -07:00
Matthew Johnson
942ff38e36
fix to ragged_all_to_all transpose
...
PiperOrigin-RevId: 738110447
2025-03-18 12:51:21 -07:00
charleshofer
dd7f96b27c
Fix ROCm build README ( #284 )
2025-03-18 14:35:36 -05:00
Jacob Burnim
47e8effdce
Adds option to initialize buffers to NaNs or zeros in TPU interpret mode.
2025-03-18 12:24:45 -07:00
Benjamin Chetioui
875099b25d
[Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering.
...
A couple of dummy transform inference rules needed to be added in order to
contend with parts of the lowering that do not use the dialect yet, along with
a transform inference rule for `memref.view`.
PiperOrigin-RevId: 738089782
2025-03-18 11:51:43 -07:00
Peter Hawkins
547d602760
Remove //jaxlib:cpu_kernels and //jaxlib:gpu_kernels forwarding Bazel targets.
...
These were temporary forwarding targets that are no longer needed; use //jaxlib/cpu:cpu_kernels and //jaxlib/cuda:cuda_gpu_kernels instead.
PiperOrigin-RevId: 738085234
2025-03-18 11:39:00 -07:00
jax authors
ee0073e605
Merge pull request #27094 from vfdev-5:fix-tsan-numpy-install-patch
...
PiperOrigin-RevId: 738080051
2025-03-18 11:24:15 -07:00
Yash Katariya
a5c0f200e7
set_mesh
should return the prev_mesh instead of nothing. Users can choose to use the return value or ignore it.
...
PiperOrigin-RevId: 738039559
2025-03-18 09:43:25 -07:00
jax authors
7c5871f464
[Pallas TPU] Hoist prologue and epilogue outside of pipeline loop
...
PiperOrigin-RevId: 738038138
2025-03-18 09:40:43 -07:00
jax authors
30941480a1
Merge pull request #27198 from jakevdp:lax-docs
...
PiperOrigin-RevId: 738038116
2025-03-18 09:38:58 -07:00
jax authors
13541e9f12
Make blocked_fold_in consistent when the block sizes induce padding
...
Add coverage for padded shapes to unit tests.
PiperOrigin-RevId: 738029476
2025-03-18 09:12:11 -07:00
charleshofer
c46b4fc02b
Merge pull request #294 from ROCm/ci-upstream-sync-151_1
...
CI: 03/18/25 upstream sync
2025-03-18 11:10:16 -05:00
Jake VanderPlas
8b46e53a4f
jax.lax: improve docs for several APIs
2025-03-18 08:55:38 -07:00
vfdev-5
9145d617b8
Added exit 1 if git patch is failed + other checks
2025-03-18 16:47:34 +01:00
Charles Hofer
c7b407c9f0
Merge branch 'rocm-main' into ci-upstream-sync-151_1
2025-03-18 15:27:35 +00:00
Benjamin Chetioui
1e36cbe597
[Mosaic GPU] Raise a NotImplementedError
if swizzle=16
.
...
Unswizzled MMAs don't lower correctly, and are not currently intended to be
supported.
PiperOrigin-RevId: 737981373
2025-03-18 06:29:13 -07:00
Adam Paszke
8da93249d2
[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
...
This allows us to significantly simplify the generated PTX/SASS,
which is currently cluttered with LLVM trying to align slices to
start at bit 0 and failing to CSE the right shifts.
PiperOrigin-RevId: 737967890
2025-03-18 05:38:49 -07:00
jax authors
7a459f0ed1
Update XLA dependency to use revision
...
3bb7654721
.
PiperOrigin-RevId: 737959582
2025-03-18 05:01:30 -07:00
Benjamin Chetioui
ba2f7c9ad9
[Mosaic GPU] Add transform inference rule for mgpu.slice_smem
.
...
PiperOrigin-RevId: 737957778
2025-03-18 04:53:54 -07:00
Adam Paszke
d4bd2570ae
[Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts
...
PiperOrigin-RevId: 737956598
2025-03-18 04:47:51 -07:00
Chris Jones
38d52a19ef
[mosaic_gpu] Force flush all cupti activity, then unsubscribe.
...
With default flushing, it is possible for events to be missed. We should only unsubscribe after we are finished with cupti.
PiperOrigin-RevId: 737939327
2025-03-18 03:35:03 -07:00
Adam Paszke
34cd5b0d74
[Mosaic GPU] Remove sub-byte conversion restriction
...
XLA:GPU recently changed its endianness to little endian to better match LLVM
and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.
PiperOrigin-RevId: 737934373
2025-03-18 03:13:21 -07:00
Yash Katariya
549973dec6
Allow pspec to be passed to device_put if there is a mesh in the surrounding context
...
PiperOrigin-RevId: 737812111
2025-03-17 17:47:56 -07:00
Changhui Lin
f174b00f23
Replace the uses of PjRtClient::Compile()
with PjRtClient::CompileAndLoad()
.
...
This is to prepare for updating `PjRtClient::Compile()` to return an unloaded executable [1/N]
PiperOrigin-RevId: 737805623
2025-03-17 17:18:31 -07:00
Emily Fertig
8c35191725
Enable jax.device_put
to a sharding with no local devices.
...
PiperOrigin-RevId: 737797815
2025-03-17 16:49:46 -07:00
Sergei Lebedev
051687dc4c
[pallas] pallas_call_p
is now parameterized by a mesh
...
The mesh is necessary to add support for clusters to the Mosaic GPU backend.
PiperOrigin-RevId: 737792129
2025-03-17 16:30:40 -07:00
jax authors
b4966130a3
Compute tile index using tile-based coordinates
...
This reduces the chances of overflowing a 32-bit integer when computing tile indices.
Add unit test to reproduce the overflow with the previous implementation of `blocked_fold_in`.
PiperOrigin-RevId: 737778853
2025-03-17 15:46:27 -07:00
jax authors
b74b16f9b9
Merge pull request #27164 from MichaelHudgins:a4-testing
...
PiperOrigin-RevId: 737733904
2025-03-17 13:36:04 -07:00
Michael Hudgins
ecf7fde714
Add B200 testing to continuous workflow
2025-03-17 20:19:20 +00:00
jax authors
4f70471310
Fix error in pallas tutorial
...
PiperOrigin-RevId: 737727935
2025-03-17 13:19:12 -07:00
Peter Hawkins
20658fabb3
Replace cached function get_replicated_hlo_sharding() with a constant.
...
Small cleanup, no functional changes intended.
PiperOrigin-RevId: 737727727
2025-03-17 13:17:33 -07:00
jax authors
ebcae0d30a
Merge pull request #26980 from carlosgmartin:categorical_replace
...
PiperOrigin-RevId: 737720590
2025-03-17 12:58:01 -07:00
Peter Hawkins
be5d13af77
Remove code that preserved _original_py_fns on C++ classes.
...
This no longer appears to be used.
PiperOrigin-RevId: 737715578
2025-03-17 12:43:04 -07:00