16405 Commits

Author SHA1 Message Date
Benjamin Chetioui
d09df7c8ab [Mosaic GPU] Add transform inference rules for mgpu.async_{load,store}.
PiperOrigin-RevId: 736795784
2025-03-14 02:37:55 -07:00
Benjamin Chetioui
d028354abb [Mosaic GPU] Introduce an initial transform inference pass.
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.

The pass isn't called anywhere yet.

PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Emily Fertig
d79472101d Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff

PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Yash Katariya
d3a41d8448 get_sharding doesn't need to be conditioned on the context mesh
PiperOrigin-RevId: 736710468
2025-03-13 18:59:31 -07:00
Matthew Johnson
34d6bb2e16 fix shard_map manual mesh axis names with vmap spmd_axis_name
PiperOrigin-RevId: 736707234
2025-03-13 18:41:46 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see a duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Peter Hawkins
1507754408 Precompute the __hash__ of AbstractMesh.
We use this frequently and it saves time to precompute it.

PiperOrigin-RevId: 736650750
2025-03-13 15:01:31 -07:00
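A minimal generic-Python sketch of the pattern this commit applies (not JAX's actual AbstractMesh code): hash the immutable fields once at construction and return the cached value from `__hash__`.

```
class CachedHashMesh:
    """Illustrative only: precompute the hash of immutable fields."""

    def __init__(self, axis_names, axis_sizes):
        self.axis_names = tuple(axis_names)
        self.axis_sizes = tuple(axis_sizes)
        # Instances are hashed frequently, so compute the hash once here.
        self._hash = hash((self.axis_names, self.axis_sizes))

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        return (isinstance(other, CachedHashMesh)
                and self.axis_names == other.axis_names
                and self.axis_sizes == other.axis_sizes)
```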
Zac Mustin
acd6c40f2f Remove obsolete fallback for cost analysis.
This fallback does not seem to be needed as all executables have a cost-analysis implementation.

PiperOrigin-RevId: 736647203
2025-03-13 14:49:40 -07:00
Yash Katariya
e1b62cede1 Raise an error if jax.config.update('jax_num_cpu_devices', val) is called after the backend is initialized
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00
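A short example of the usage this implies; the device count below is arbitrary, and the point is that the update must happen before anything initializes the backend.

```
import jax

# Must run before the first computation or device query; per the commit
# above, updating this flag after backend initialization now raises an error.
jax.config.update("jax_num_cpu_devices", 4)

print(jax.devices("cpu"))  # expect 4 CPU devices
```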
jax authors
47bf22e37d [pallas][Mosaic][Easy] Add batch dot dim test, remove check
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
jax authors
726f49cbca Merge pull request #26944 from wenscarl:wenscarl/nvfp4
PiperOrigin-RevId: 736620378
2025-03-13 13:30:46 -07:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Peter Hawkins
8effa19734 [JAX] Change jax.core.Trace subclasses to call super().__init__().
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is that we wanted to avoid updating all the subclasses to call super().__init__().

hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)

PiperOrigin-RevId: 736555016
2025-03-13 10:27:52 -07:00
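A generic sketch of the pattern being adopted (not JAX's actual Trace code; the class names are hypothetical): initialize the flag in the base `__init__`, have subclasses call `super().__init__()`, and read the attribute directly instead of probing with `hasattr`.

```
class TraceBase:
    def __init__(self):
        # The base class owns the flag, so every instance has it.
        self._invalidated = False

    def invalidate(self):
        self._invalidated = True


class MyTrace(TraceBase):  # hypothetical subclass
    def __init__(self):
        super().__init__()  # required so _invalidated always exists


t = MyTrace()
# Direct attribute read: no hasattr probe (which raises and formats an
# AttributeError internally on every miss) is needed.
if t._invalidated:
    raise RuntimeError("trace has been invalidated")
```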
Yash Katariya
14b9f48535 Allow late binding of out_shardings and in_shardings in the auto_axes and explicit_axes APIs
PiperOrigin-RevId: 736535562
2025-03-13 09:37:24 -07:00
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
Yash Katariya
a4ca0dbc6c Change the signature of AbstractMesh to AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) from AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types), so that we are consistent across all Mesh APIs: Mesh, AbstractMesh, and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
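A small sketch of the new constructor call described above; the import path is an assumption.

```
from jax.sharding import AbstractMesh  # import path assumed

# Sizes first, then names, matching Mesh and jax.make_mesh.
abstract_mesh = AbstractMesh((2, 4), ("x", "y"))
```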
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes an `AxisTypes | tuple[AxisTypes, ...] | None`. It no longer takes a dictionary.

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
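A hedged sketch of the reworked argument via jax.make_mesh; the AxisTypes name comes from the message above, but its import location and member names are assumptions.

```
import jax
from jax.sharding import AxisTypes  # location and member names assumed

# axis_types is a single value or a per-axis tuple, no longer a dictionary.
mesh = jax.make_mesh((2, 4), ("x", "y"),
                     axis_types=(AxisTypes.Explicit, AxisTypes.Auto))
```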
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
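A hypothetical usage sketch; the (key, n, p) argument order mirrors numpy.random.multinomial and is an assumption rather than the confirmed signature.

```
import jax
import jax.numpy as jnp

key = jax.random.key(0)
n = 10                              # number of trials
p = jnp.array([0.2, 0.3, 0.5])      # outcome probabilities
counts = jax.random.multinomial(key, n, p)  # per-outcome counts summing to n
```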
Yash Katariya
47480b4493 Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it, i.e. it is just the __enter__ of a context manager without the __exit__
PiperOrigin-RevId: 736261724
2025-03-12 14:12:47 -07:00
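A minimal sketch of the described behavior, assuming set_mesh lives in jax.sharding and accepts a Mesh; unlike a context manager, there is nothing to exit.

```
import jax

mesh = jax.make_mesh((len(jax.devices()),), ("x",))
# Sets the mesh for the rest of the program; per the message above there is
# no matching "unset".
jax.sharding.set_mesh(mesh)
```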
Yash Katariya
8674495fd7 [sharding_in_types] Make reshard work with np.array.
PiperOrigin-RevId: 736250504
2025-03-12 13:41:42 -07:00
Justin Fu
6978f35293 [Pallas] Plumb compiler flags through source mapper.
PiperOrigin-RevId: 736199966
2025-03-12 11:19:58 -07:00
Christos Perivolaropoulos
b34f56bfd7 [mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
jax authors
3de7ecf6da Merge pull request #27092 from pearu:pearu/gammainc-bug-fix
PiperOrigin-RevId: 736177398
2025-03-12 10:20:39 -07:00
jax authors
e7d10a2310 Merge pull request #27041 from carlosgmartin:fix_binomial_value_error
PiperOrigin-RevId: 736171463
2025-03-12 10:05:18 -07:00
Pearu Peterson
f608a8c502 Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan. 2025-03-12 17:48:45 +02:00
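A small illustration of the documented behavior (mirroring scipy 1.16): a NaN in either operand propagates to the result.

```
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaincc

print(gammainc(jnp.nan, 1.0))   # expected: nan
print(gammaincc(0.5, jnp.nan))  # expected: nan
```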
Yash Katariya
abcc7fdf4c [sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName] to ShapedArray. Also add a jax_varying_axes_in_types config to hide this option behind while we develop it.
PiperOrigin-RevId: 736141670
2025-03-12 08:29:16 -07:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes the WGMMA_ROW layout, which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Matthew Johnson
66a6eb299e add autodiff rules for jax.lax.ragged_all_to_all collective
Also update the ragged_all_to_all docstring. Pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; I'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof, which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
Sharad Vikram
c6b164dc09 [Pallas/Fuser] Add custom evaluate to allow/disallow transposes
PiperOrigin-RevId: 735931978
2025-03-11 16:35:49 -07:00
Yash Katariya
f45cbf3342 Fix a bug where `full` with `use_mesh` outside jit did not work because the shard passed to make_array_from_callback was sharded on all devices instead of just one device.
This happened because `convert_element_type` returned an output on all devices of the mesh due to the surrounding `use_mesh` context.

PiperOrigin-RevId: 735909962
2025-03-11 15:25:46 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
jax authors
7ac088c14f Merge pull request #20699 from pearu:pearu/gammainc
PiperOrigin-RevId: 735878582
2025-03-11 13:53:20 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Jevin Jiang
eff612a3b6 Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
PiperOrigin-RevId: 735851301
2025-03-11 12:36:33 -07:00
shuw
f9aef8a189 Support nvfp4 2025-03-11 19:33:25 +00:00
Pearu Peterson
82b2591b21 Fix scipy.special.gammainc/gammaincc evaluation at boundary points 2025-03-11 21:18:47 +02:00
jax authors
c2c68c018f Merge pull request #27059 from jakevdp:fix-while-loop
PiperOrigin-RevId: 735828960
2025-03-11 11:32:00 -07:00
Gunhyun Park
d191927b24 Fix syntax error and typos for composite primitive docstring.
PiperOrigin-RevId: 735808000
2025-03-11 10:37:07 -07:00
Jake VanderPlas
4ae3211ea2 jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version 2025-03-11 09:53:34 -07:00
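A small sketch of the scenario this targets: under jax.disable_jit(), lax.while_loop runs as a Python loop but should now match the behavior of the traced version.

```
import jax
from jax import lax

def cond(carry):
    i, _ = carry
    return i < 5

def body(carry):
    i, total = carry
    return i + 1, total + 2.0

with jax.disable_jit():
    result = lax.while_loop(cond, body, (0, 0.0))

print(result)  # expected: (5, 10.0)
```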
Adam Paszke
30a9e1b3bf [Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
This one is particularly annoying, because we have to break up the MMA
into two collective N=256 MMAs. However, TensorCore only updates a contiguous
chunk of columns in TMEM and so after executing two of those we end up with
a TMEM layout that looks like this:

```
Contributing CTA |    0    |    1    |    0    |    1    |
N local          |   0:128 |   0:128 | 128:256 | 128:256 |
N                |   0:128 | 256:384 | 128:256 | 384:512 |
```

You can see that the TMEM columns no longer sweep monotonically over all
columns up to N=512; instead they include a number of jumps!

We could fix this on the load side, by ensuring that each CTA in the group
does a strided load along the tiled dimension, but that just seems more
trouble than it's worth (and is not that well supported by TMA unless we
increase the number of striding levels).

Instead, we encode this weirdness in the TMEM layout we use and make sure
to rearrange the data properly while loading the tiles into registers.

PiperOrigin-RevId: 735791426
2025-03-11 09:53:20 -07:00
jax authors
1aca76fc13 Update the :build_jaxlib flag to control whether we should add py_import dependencies to the test targets.
This change enables testing the wheels produced by the build rules in the presubmit using a single `bazel test` command.

There are three options for running the tests:

1) `build_jaxlib=true`: the tests depend on JAX targets.
2) `build_jaxlib=false`: the tests depend on the wheel files located in the `dist` folder.
3) `build_jaxlib=wheel`: the tests depend on the py_import targets.

PiperOrigin-RevId: 735765819
2025-03-11 08:31:43 -07:00
Yash Katariya
76dec38286 Under pjit, the `with mesh:` context will use `use_mesh(mesh): jit` instead of tracking the mesh separately via resource_env.
This also makes it easier to deprecate the `with mesh: pjit` path from user code in the future, since the new path will be completely tested.
It will also allow us to remove `resource_env` from JAX and the internal API accesses of `resource_env.physical_mesh` spread throughout codebases internally and externally.

PiperOrigin-RevId: 735602187
2025-03-10 20:21:02 -07:00
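For context, a sketch of the user-facing pattern this refactor covers; the change is in how the mesh is tracked internally, not in how code like the following is written.

```
import jax
import jax.numpy as jnp

mesh = jax.make_mesh((len(jax.devices()),), ("x",))

@jax.jit
def f(x):
    return x * 2

with mesh:  # previously tracked via resource_env; now routed through use_mesh(mesh)
    y = f(jnp.arange(4 * mesh.size))
```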
jax authors
02505fa757 [Pallas TPU] Remove next_slot SMEM tensor from pipeline emitter
PiperOrigin-RevId: 735564365
2025-03-10 17:19:39 -07:00
Ayaka
988a1208a9 Better error message when raise_if_error() is called within a traced context
PiperOrigin-RevId: 735557928
2025-03-10 16:55:06 -07:00
jax authors
aceae84fab [Pallas] Enable skipping of floating-point operations when interpreting Pallas TPU kernels on CPU.
PiperOrigin-RevId: 735527650
2025-03-10 15:14:00 -07:00
Sharad Vikram
81dde225b0 [Pallas/Fuser] Add select_n push rule
PiperOrigin-RevId: 735510713
2025-03-10 14:23:01 -07:00
jax authors
261e6e5fdc Merge pull request #27038 from jakevdp:vmap-sentinel
PiperOrigin-RevId: 735510065
2025-03-10 14:21:11 -07:00
jax authors
c942b0fef0 Merge pull request #26977 from jakevdp:fix-expn
PiperOrigin-RevId: 735506133
2025-03-10 14:09:32 -07:00
Sharad Vikram
87272fbe93 [Pallas/Fuser] Add debug option to fuser.fuse that prints out jaxpr
PiperOrigin-RevId: 735505460
2025-03-10 14:07:26 -07:00