1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-23 04:36:07 +00:00

16376 Commits

Author SHA1 Message Date
Joan Puigcerver
466ef6a132 Change the way that batching.spec_types is updated.
There's no reason why not two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions.

Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).

When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.

PiperOrigin-RevId: 737280270
2025-03-15 22:58:44 -07:00
Ayaka
9b0ace4a11 Support error checking in explicit mode
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00
jax authors
7db59cdcca Merge pull request from mattjj:opt-barrier-ad-rules
PiperOrigin-RevId: 737040381
2025-03-14 17:59:07 -07:00
Peter Hawkins
14cb7453f0 Add a C++ implementation of a toplogical sort.
This is an exact port of the current Python implementation to C++ for speed.

I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.

PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
Matthew Johnson
dadc68b6c1 add experimental lax.optimization_barrier autodiff rules 2025-03-14 22:40:55 +00:00
jax authors
b00a3a1986 Merge pull request from mattjj:direct-linearize-fixes-4
PiperOrigin-RevId: 737003323
2025-03-14 15:24:11 -07:00
Sergei Lebedev
64230d1c93 [pallas:mosaic_gpu] WG lowering now supports while_p
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
Matthew Johnson
174dcc771a [direct-linearize] shmap fixes 2025-03-14 21:38:50 +00:00
Daniel Suo
39e8ee93b0 Add experimental/serialize_executable.py to BUILD.
PiperOrigin-RevId: 736975882
2025-03-14 13:54:39 -07:00
Yash Katariya
aa9480a441 Expose get_abstract_mesh via the jax.sharding namespace
PiperOrigin-RevId: 736972976
2025-03-14 13:45:32 -07:00
Justin Fu
dbd8d92075 [Pallas] Add legacy PRNG key support to Pallas PRNG
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
Zac Mustin
0c8e601f90 Support convolution in roofline.
So far we support only `unfused_hmb_bytes` and don't account for `{feature, batch}_group_count`s due to complexity.

PiperOrigin-RevId: 736948528
2025-03-14 12:26:20 -07:00
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Emily Fertig
bdb6d03322 Allow make_array_from_callback to construct nonaddressable arrays.
PiperOrigin-RevId: 736922870
2025-03-14 11:10:32 -07:00
Ilya Tikhonovskiy
c9ac82c826 [XLA:GPU] Add missing BF16_BF16_F32_X9 matmul option in config.py
Extend the list of possible default algorithms that dot could use.

PiperOrigin-RevId: 736879149
2025-03-14 08:58:59 -07:00
Peter Hawkins
6fa98fc0a4 Use "x is y" rather than "id(x) == id(y)".
The latter involves at least two object constructions.

PiperOrigin-RevId: 736878098
2025-03-14 08:54:46 -07:00
Peter Hawkins
8ab33669e2 Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
Peter Hawkins
074216e07a Precompute a weakref to a Trace≥
We use Trace weakrefs frequently, so we may as well construct one eagerly.

PiperOrigin-RevId: 736841778
2025-03-14 06:26:17 -07:00
Benjamin Chetioui
5098d2ef49 [Mosaic GPU][NFC] Simplify implementation for in_{layout,transforms}_for_operand utils.
PiperOrigin-RevId: 736809960
2025-03-14 03:52:10 -07:00
Ilya Tikhonovskiy
43b78c539f [JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we could do the last step - enable it in JAX.

PiperOrigin-RevId: 736799241
2025-03-14 02:57:55 -07:00
jax authors
cbece0b00b Add explicit support for float8_e4m3b11fnuz in pl.dot
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Benjamin Chetioui
d09df7c8ab [Mosaic GPU] Add transform inference rules for mgpu.async_{load,store}.
PiperOrigin-RevId: 736795784
2025-03-14 02:37:55 -07:00
Benjamin Chetioui
d028354abb [Mosaic GPU] Introduce an initial transform inference pass.
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.

The pass isn't called anywhere yet.

PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Emily Fertig
d79472101d Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff

PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Yash Katariya
d3a41d8448 get_sharding doesn't need to be conditioned on the context mesh
PiperOrigin-RevId: 736710468
2025-03-13 18:59:31 -07:00
Matthew Johnson
34d6bb2e16 fix shard_map manual mesh axis names with vmap spmd_axis_name
PiperOrigin-RevId: 736707234
2025-03-13 18:41:46 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Peter Hawkins
1507754408 Precompute the __hash__ of AbstractMesh.
We use this frequently and it saves time to precompute it.

PiperOrigin-RevId: 736650750
2025-03-13 15:01:31 -07:00
Zac Mustin
acd6c40f2f Remove obsolete fallback for cost analysis.
This fallback does not seem to be needed as all executables have a cost-analysis implementation.

PiperOrigin-RevId: 736647203
2025-03-13 14:49:40 -07:00
Yash Katariya
e1b62cede1 Raise an error if jax.config.update('jax_num_cpu_devices', val) is called after backend is initialized
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00
jax authors
47bf22e37d [pallas][Mosaic][Easy] Add batch dot dim test, remove check
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
jax authors
726f49cbca Merge pull request from wenscarl:wenscarl/nvfp4
PiperOrigin-RevId: 736620378
2025-03-13 13:30:46 -07:00
jax authors
bf829ff612 Merge pull request from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Peter Hawkins
8effa19734 [JAX] Change jax.core.Trace subclasses to call super().__init__().
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().

hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)

PiperOrigin-RevId: 736555016
2025-03-13 10:27:52 -07:00
Yash Katariya
14b9f48535 Allow late binding out_shardings and in_shardings in auto_axes and explicit_axes API
PiperOrigin-RevId: 736535562
2025-03-13 09:37:24 -07:00
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
Yash Katariya
a4ca0dbc6c Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
Yash Katariya
47480b4493 Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it i.e. this is just __enter__ of a ctx manager without __exit__
PiperOrigin-RevId: 736261724
2025-03-12 14:12:47 -07:00
Yash Katariya
8674495fd7 [sharding_in_types] Make reshard work with np.array.
PiperOrigin-RevId: 736250504
2025-03-12 13:41:42 -07:00
Justin Fu
6978f35293 [Pallas] Plumb compiler flags through source mapper.
PiperOrigin-RevId: 736199966
2025-03-12 11:19:58 -07:00
Christos Perivolaropoulos
b34f56bfd7 [mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
jax authors
3de7ecf6da Merge pull request from pearu:pearu/gammainc-bug-fix
PiperOrigin-RevId: 736177398
2025-03-12 10:20:39 -07:00
jax authors
e7d10a2310 Merge pull request from carlosgmartin:fix_binomial_value_error
PiperOrigin-RevId: 736171463
2025-03-12 10:05:18 -07:00
Pearu Peterson
f608a8c502 Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan. 2025-03-12 17:48:45 +02:00
Yash Katariya
abcc7fdf4c [sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName] to ShapedArray. Also add jax_varying_axes_in_types config to hide this option under while we develop it.
PiperOrigin-RevId: 736141670
2025-03-12 08:29:16 -07:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Matthew Johnson
66a6eb299e add autodiff rules for jax.lax.ragged_all_to_all collective
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00