9810 Commits

Author SHA1 Message Date
charleshofer
13d88b6340
Add back raw totals in JSON reports (#281) 2025-03-24 11:26:31 -05:00
rocm-repo-management-api-2[bot]
b505df9973
Merge pull request #299 from ROCm/ci-upstream-sync-152_1
CI: 03/19/25 upstream sync
2025-03-19 07:20:19 -05:00
jax authors
e9ce8fb92d Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4
PiperOrigin-RevId: 738235363
2025-03-18 20:22:27 -07:00
jax authors
01a110c4c9 Better Mosaic lowering for dynamic shapes: extend an interpreter into shape_poly dimexpr and lower them alongside the graph if we are in a dynamic export regime.
PiperOrigin-RevId: 738171437
2025-03-18 15:51:15 -07:00
Parker Schuh
0fb59747f0 Support tuples in custom_partitioning.
PiperOrigin-RevId: 738154413
2025-03-18 14:57:08 -07:00
jax authors
080804c78d Fix logging_test fails on Linux with NVIDIA Driver only.
Some GPU tests in //tests/logging_test fail on Linux with NVIDIA driver only when we use hermetic CUDA (CUDA isn't installed on Linux).

Reason: the method `tsl::Env::Default()->GetExecutablePath()` doesn't work properly with the command flag (-c). As a result, the subprocess couldn't get the path to the logging_test.py file and convert it to the path of the runtime where the hermetic CUDA libraries are placed.

Solution: save the Python program to a file in the runtime directory, then run the script from that file.
PiperOrigin-RevId: 738152663
2025-03-18 14:51:35 -07:00
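
The fix described in 080804c78d (run the program from a file rather than passing it with `-c`) can be illustrated with a minimal sketch. This is not the code from the commit: the file name `program_under_test.py`, the helper `run_from_file`, and the toy program are all hypothetical, and a temporary directory stands in for the runtime directory.

```python
import os
import subprocess
import sys
import tempfile

# Hypothetical stand-in for the program that the test launches.
PROGRAM = "import sys; print('argv0 is a file:', sys.argv[0].endswith('.py'))"


def run_from_file(program: str, directory: str) -> str:
    """Write `program` to a file in `directory` and run it as a script.

    Running a real file (instead of `python -c ...`) gives the child
    process a script path on disk that path-based lookups can use.
    """
    path = os.path.join(directory, "program_under_test.py")  # hypothetical name
    with open(path, "w") as f:
        f.write(program)
    result = subprocess.run(
        [sys.executable, path], capture_output=True, text=True, check=True
    )
    return result.stdout


# A temporary directory stands in for the runtime directory here.
with tempfile.TemporaryDirectory() as tmp:
    output = run_from_file(PROGRAM, tmp)
```

With `-c`, `sys.argv[0]` in the child is the string `-c`, so there is no script path to resolve; with a file, it is the path that was passed on the command line.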
Gleb Pobudzey
54691b125a [Mosaic GPU] Support reads/writes from SMEM to WGMMARowFragLayout arrays.
PiperOrigin-RevId: 738121106
2025-03-18 13:23:07 -07:00
Yash Katariya
76d9890bb7 Run the stream annotation tests on 2 devices so that it can be tested in TAP
PiperOrigin-RevId: 738113725
2025-03-18 13:01:48 -07:00
Jacob Burnim
47e8effdce Adds option to initialize buffers to NaNs or zeros in TPU interpret mode. 2025-03-18 12:24:45 -07:00
Benjamin Chetioui
875099b25d [Mosaic GPU] Enable the new transform inference pass in the warpgroup lowering.
A couple of dummy transform inference rules needed to be added in order to
contend with parts of the lowering that do not use the dialect yet, along with
a transform inference rule for `memref.view`.

PiperOrigin-RevId: 738089782
2025-03-18 11:51:43 -07:00
Yash Katariya
a5c0f200e7 set_mesh should return the prev_mesh instead of nothing. Users can choose to use the return value or ignore it.
PiperOrigin-RevId: 738039559
2025-03-18 09:43:25 -07:00
jax authors
13541e9f12 Make blocked_fold_in consistent when the block sizes induce padding
Add coverage for padded shapes to unit tests.

PiperOrigin-RevId: 738029476
2025-03-18 09:12:11 -07:00
Charles Hofer
c7b407c9f0 Merge branch 'rocm-main' into ci-upstream-sync-151_1 2025-03-18 15:27:35 +00:00
Benjamin Chetioui
ba2f7c9ad9 [Mosaic GPU] Add transform inference rule for mgpu.slice_smem.
PiperOrigin-RevId: 737957778
2025-03-18 04:53:54 -07:00
Adam Paszke
d4bd2570ae [Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGMMA friendly layouts
PiperOrigin-RevId: 737956598
2025-03-18 04:47:51 -07:00
Adam Paszke
34cd5b0d74 [Mosaic GPU] Remove sub-byte conversion restriction
XLA:GPU recently changed its endianness to little endian to better match LLVM
and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.
PiperOrigin-RevId: 737934373
2025-03-18 03:13:21 -07:00
Yash Katariya
549973dec6 Allow pspec to be passed to device_put if there is a mesh in the surrounding context
PiperOrigin-RevId: 737812111
2025-03-17 17:47:56 -07:00
jax authors
b4966130a3 Compute tile index using tile-based coordinates
This reduces the chances of overflowing a 32-bit integer when computing tile indices.
Add unit test to reproduce the overflow with the previous implementation of `blocked_fold_in`.

PiperOrigin-RevId: 737778853
2025-03-17 15:46:27 -07:00
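
A rough sketch of why tile-based coordinates (b4966130a3) reduce the overflow risk, assuming a 2D array split into rectangular blocks. The function `tile_index` and the shapes are hypothetical and are not taken from `blocked_fold_in`; Python integers do not overflow, so the comparison against the 32-bit limit is made explicit.

```python
INT32_MAX = 2**31 - 1


def tile_index(row, col, num_cols, block_rows, block_cols):
    """Index of the tile containing element (row, col), in row-major tile order.

    Only tile coordinates are multiplied, so intermediate values stay
    near the number of tiles rather than the number of elements.
    """
    tiles_per_row = -(-num_cols // block_cols)  # ceiling division
    return (row // block_rows) * tiles_per_row + (col // block_cols)


# Hypothetical large array: about 4.9e9 elements, more than int32 can index.
rows, cols = 70_000, 70_000
block = 1024

# Going through a flat element index needs a value above INT32_MAX...
last_flat = (rows - 1) * cols + (cols - 1)

# ...while the tile index of the same element is small.
last_tile = tile_index(rows - 1, cols - 1, cols, block, block)
```

Here `last_flat` exceeds `INT32_MAX`, whereas `last_tile` is only a few thousand, so a 32-bit integer holds it comfortably.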
jax authors
ebcae0d30a Merge pull request #26980 from carlosgmartin:categorical_replace
PiperOrigin-RevId: 737720590
2025-03-17 12:58:01 -07:00
Benjamin Chetioui
9a686e0bf3 [Mosaic GPU] Add initial transform inference rules for vector.{load,store}.
PiperOrigin-RevId: 737703568
2025-03-17 12:08:07 -07:00
carlosgmartin
3f59fa6888 Add replace option to random.categorical to enable sampling without replacement. 2025-03-17 13:41:46 -04:00
Adam Paszke
3649da56fc [Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length
We can now perform the conversion in groups of 2, 4 or even 8 elements at a time.

PiperOrigin-RevId: 737626600
2025-03-17 08:37:17 -07:00
Sergei Lebedev
0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any more clear and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
Sergei Lebedev
a7e5eaee56 [pallas:mosaic_gpu] jnp.reduce_sum now works for >1D arrays
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
Adam Paszke
89b21de62a [Mosaic GPU] Add support for changing the layout before the upcast
This lets us save on 2 ALU instructions (3x select becomes 1x prmt).

PiperOrigin-RevId: 737550598
2025-03-17 03:26:48 -07:00
Adam Paszke
2bdd9c8797 [Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16-bit upcast
PiperOrigin-RevId: 737542885
2025-03-17 02:52:16 -07:00
Joan Puigcerver
466ef6a132 Change the way that batching.spec_types is updated.
There's no reason why two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions.

Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).

When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.

PiperOrigin-RevId: 737280270
2025-03-15 22:58:44 -07:00
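
The failure mode described in 466ef6a132 and the regenerate-on-unregister fix can be reproduced with a small sketch. The class `VmappableRegistry` and the placeholder types are hypothetical and only mirror the description in the commit message; they are not the actual `batching` code.

```python
class VmappableRegistry:
    """Hypothetical registry mapping data types to their spec types."""

    def __init__(self, rebuild_on_unregister):
        self.rebuild_on_unregister = rebuild_on_unregister
        self.data_to_spec = {}
        self.spec_types = set()

    def register(self, data_type, spec_type):
        self.data_to_spec[data_type] = spec_type
        self.spec_types.add(spec_type)

    def unregister(self, data_type):
        spec_type = self.data_to_spec.pop(data_type)
        if self.rebuild_on_unregister:
            # New behaviour: regenerate the set from the remaining types.
            self.spec_types = set(self.data_to_spec.values())
        else:
            # Old behaviour: remove the spec type outright.
            self.spec_types.remove(spec_type)


class SharedSpec: ...
class TypeA: ...
class TypeB: ...


# Old behaviour: two data types share one spec type.
old = VmappableRegistry(rebuild_on_unregister=False)
old.register(TypeA, SharedSpec)
old.register(TypeB, SharedSpec)
old.unregister(TypeA)
spec_lost = SharedSpec not in old.spec_types  # TypeB is left unsupported
try:
    old.unregister(TypeB)
    second_remove_failed = False
except KeyError:
    second_remove_failed = True

# New behaviour: the set always reflects the registered types.
new = VmappableRegistry(rebuild_on_unregister=True)
new.register(TypeA, SharedSpec)
new.register(TypeB, SharedSpec)
new.unregister(TypeA)
spec_kept = SharedSpec in new.spec_types
new.unregister(TypeB)
spec_cleared = not new.spec_types
```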
Ayaka
9b0ace4a11 Support error checking in explicit mode
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00
jax authors
7db59cdcca Merge pull request #27174 from mattjj:opt-barrier-ad-rules
PiperOrigin-RevId: 737040381
2025-03-14 17:59:07 -07:00
Peter Hawkins
14cb7453f0 Add a C++ implementation of a topological sort.
This is an exact port of the current Python implementation to C++ for speed.

I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.

PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
GitHub Actions
e275d5cf6c Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-147_1 2025-03-14 22:42:07 +00:00
Matthew Johnson
dadc68b6c1 add experimental lax.optimization_barrier autodiff rules 2025-03-14 22:40:55 +00:00
Sergei Lebedev
64230d1c93 [pallas:mosaic_gpu] WG lowering now supports while_p
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
charleshofer
022da913e6
Count test totals correctly for dashboards (#280)
* Account test totals correctly for dashboards

* Add blurb to the dev guide on skipping tests

* Remove extra newline

* Default to 0 if "skipped" isn't found

---------

Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>
2025-03-14 16:57:08 -05:00
Tzu-Wei Sung
21f5f2d45e [Pallas] Increase #rows when casting to x2.
There is a bug in XLA on v5p.

PiperOrigin-RevId: 736987667
2025-03-14 14:32:33 -07:00
Justin Fu
dbd8d92075 [Pallas] Add legacy PRNG key support to Pallas PRNG
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
Zac Mustin
0c8e601f90 Support convolution in roofline.
So far we support only `unfused_hbm_bytes` and don't account for `{feature, batch}_group_count`s due to complexity.

PiperOrigin-RevId: 736948528
2025-03-14 12:26:20 -07:00
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Emily Fertig
bdb6d03322 Allow make_array_from_callback to construct nonaddressable arrays.
PiperOrigin-RevId: 736922870
2025-03-14 11:10:32 -07:00
Sergei Lebedev
97bbc37e83 [dlpack] Support more DLPack dtypes now that we target DLPack 1.1
I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor
TensorFlow currently support importing DLPack arrays with the new dtypes.

PiperOrigin-RevId: 736882904
2025-03-14 09:10:56 -07:00
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
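
The idea behind `foreach()` in 8ab33669e2 can be sketched in plain Python. This is an assumption-laden illustration, not the JAX implementation: it assumes that, like `safe_map()`, the function checks that all inputs have equal length, and it requires at least one iterable.

```python
def foreach(f, *args):
    """Call `f` on corresponding elements of `args` for side effects only.

    Sketch: materialize the inputs, check that their lengths agree, then
    iterate. No result list is built and nothing is returned.
    """
    if not args:
        raise TypeError("foreach() needs at least one iterable")
    args = [list(a) for a in args]
    n = len(args[0])
    for a in args[1:]:
        if len(a) != n:
            raise ValueError(f"length mismatch: {[len(x) for x in args]}")
    for xs in zip(*args):
        f(*xs)


# Usage: collect sums as a side effect; the return value is None.
seen = []
result = foreach(lambda x, y: seen.append(x + y), [1, 2], [10, 20])
```

Checking lengths before the loop means a mismatch is reported before any side effect has run.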
Ilya Tikhonovskiy
43b78c539f [JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we can do the last step: enable it in JAX.

PiperOrigin-RevId: 736799241
2025-03-14 02:57:55 -07:00
jax authors
cbece0b00b Add explicit support for float8_e4m3b11fnuz in pl.dot
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Benjamin Chetioui
d09df7c8ab [Mosaic GPU] Add transform inference rules for mgpu.async_{load,store}.
PiperOrigin-RevId: 736795784
2025-03-14 02:37:55 -07:00
Benjamin Chetioui
d028354abb [Mosaic GPU] Introduce an initial transform inference pass.
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.

The pass isn't called anywhere yet.

PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Emily Fertig
d79472101d Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff

PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Tzu-Wei Sung
e235fb9760 [Mosaic] Allow part of x2 int casts.
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.

PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Yash Katariya
e1b62cede1 Raise an error if jax.config.update('jax_num_cpu_devices', val) is called after backend is initialized
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00
jax authors
47bf22e37d [pallas][Mosaic][Easy] Add batch dot dim test, remove check
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00