26204 Commits

Author SHA1 Message Date
jax authors
47bf22e37d [pallas][Mosaic][Easy] Add batch dot dim test, remove check
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
jax authors
726f49cbca Merge pull request #26944 from wenscarl:wenscarl/nvfp4
PiperOrigin-RevId: 736620378
2025-03-13 13:30:46 -07:00
Tzu-Wei Sung
a0f1be123d [Mosaic] Improve error messages.
PiperOrigin-RevId: 736580673
2025-03-13 11:35:33 -07:00
jax authors
bf829ff612 Merge pull request #26524 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Peter Hawkins
8effa19734 [JAX] Change jax.core.Trace subclasses to call super().__init__().
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().

hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)

PiperOrigin-RevId: 736555016
2025-03-13 10:27:52 -07:00
Yash Katariya
14b9f48535 Allow late binding out_shardings and in_shardings in auto_axes and explicit_axes API
PiperOrigin-RevId: 736535562
2025-03-13 09:37:24 -07:00
Nitin Srinivasan
12760af236 Add custom job names to group different matrix combinations in the Actions dashboard
PiperOrigin-RevId: 736481804
2025-03-13 06:23:04 -07:00
jax authors
c07d839e6e Update XLA dependency to use revision
45e12978d5.

PiperOrigin-RevId: 736452873
2025-03-13 04:09:44 -07:00
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
Yash Katariya
a4ca0dbc6c Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
carlosgmartin
6b69a136aa Add jax.random.multinomial. 2025-03-12 18:15:14 -04:00
jax authors
ba367cdead Merge pull request #27044 from carlosgmartin:add_breadcrumbs_to_docs
PiperOrigin-RevId: 736283028
2025-03-12 15:14:17 -07:00
Jevin Jiang
12c0987e2f [Mosaic TPU][NFC] Throw NYI error instead of crash when squeeze ref to 1d.
PiperOrigin-RevId: 736263705
2025-03-12 14:18:33 -07:00
Yash Katariya
47480b4493 Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it i.e. this is just __enter__ of a ctx manager without __exit__
PiperOrigin-RevId: 736261724
2025-03-12 14:12:47 -07:00
carlosgmartin
bc43b00d8f Add navigation breadcrumbs to docs. 2025-03-12 16:52:34 -04:00
Yash Katariya
8674495fd7 [sharding_in_types] Make reshard work with np.array.
PiperOrigin-RevId: 736250504
2025-03-12 13:41:42 -07:00
Justin Fu
6978f35293 [Pallas] Plumb compiler flags through source mapper.
PiperOrigin-RevId: 736199966
2025-03-12 11:19:58 -07:00
Christos Perivolaropoulos
b34f56bfd7 [mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
jax authors
3de7ecf6da Merge pull request #27092 from pearu:pearu/gammainc-bug-fix
PiperOrigin-RevId: 736177398
2025-03-12 10:20:39 -07:00
jax authors
e7d10a2310 Merge pull request #27041 from carlosgmartin:fix_binomial_value_error
PiperOrigin-RevId: 736171463
2025-03-12 10:05:18 -07:00
Pearu Peterson
f608a8c502 Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan. 2025-03-12 17:48:45 +02:00
Yash Katariya
abcc7fdf4c [sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName] to ShapedArray. Also add jax_varying_axes_in_types config to hide this option under while we develop it.
PiperOrigin-RevId: 736141670
2025-03-12 08:29:16 -07:00
Dan Foreman-Mackey
8b7cfcb33c Fix integer overflow in workspace size computations for experimental.rnn.*.
PiperOrigin-RevId: 736139471
2025-03-12 08:22:04 -07:00
Sergei Lebedev
e33f3fc48b [pallas:mosaic_gpu] Added support for reductions to the WG lowering
Note that

* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
  the dialect lowering AFAICT.

PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Nitin Srinivasan
d89835acba Fix matrix exclude syntax in TPU tests block
Also, skip Python 3.13 for now due to missing dependency error.

PiperOrigin-RevId: 736120590
2025-03-12 07:12:52 -07:00
Nitin Srinivasan
a6ab6bbc20 Ignore Pallas TPU tests when testing with the oldest supported libtpu
I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows

PiperOrigin-RevId: 736094492
2025-03-12 05:20:42 -07:00
jax authors
61ba2b2603 Update XLA dependency to use revision
c270a6ce45.

PiperOrigin-RevId: 736088162
2025-03-12 04:52:20 -07:00
Chris Jones
74b4d868e3 Add support for scratch buffers in jax_triton.
This is required to use device-side TMA descriptors.

PiperOrigin-RevId: 735985603
2025-03-11 20:49:33 -07:00
Nitin Srinivasan
ff751ecc7b Run single python version for v4-8 and min & max for v5e-8 for TPU tests in nightly/release test workflow
PiperOrigin-RevId: 735975004
2025-03-11 20:03:05 -07:00
Matthew Johnson
66a6eb299e add autodiff rules for jax.lax.ragged_all_to_all collective
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.

PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
Sharad Vikram
c6b164dc09 [Pallas/Fuser] Add custom evaluate to allow/disallow transposes
PiperOrigin-RevId: 735931978
2025-03-11 16:35:49 -07:00
Yash Katariya
f45cbf3342 Fix a bug where full and use_mesh outside jit did not work because the shard passed to make_array_from_callback was sharded on all devices instead of just 1 device.
This is because `convert_element_type` returning an output on all devices of the mesh because of the surrounding `use_mesh` context.

PiperOrigin-RevId: 735909962
2025-03-11 15:25:46 -07:00
Jevin Jiang
29bfd00f9c [Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
2025-03-11 15:06:07 -07:00
jax authors
13eb8d3ae7 Upgrade ml-dtypes version in py3.10-py3.13 hermetic python lock files.
This change is needed to add testing of int2/uint2 dtypes via bazel in presubmit (see https://github.com/jax-ml/jax/pull/21395).

PiperOrigin-RevId: 735895293
2025-03-11 14:41:34 -07:00
Kanglan Tang
4df691ec00 Remove unsupported mac x86 CI build options
PiperOrigin-RevId: 735885305
2025-03-11 14:12:51 -07:00
jax authors
7ac088c14f Merge pull request #20699 from pearu:pearu/gammainc
PiperOrigin-RevId: 735878582
2025-03-11 13:53:20 -07:00
Dimitar (Mitko) Asenov
99c9106032 [Mosaic GPU] Replace WGMMAFragLayout with TiledLayout in the mlir dialect and use it in layout inference.
`WGMMAFragLayout` will be completely removed soon.

PiperOrigin-RevId: 735877661
2025-03-11 13:50:42 -07:00
Peter Hawkins
67aa997f84 Increase the number of iterations in a test that compares rolled versus unrolled HLO for length.
A change that avoids duplicating subcomputations in XLA causes this test to fail, but we can make it work again by increasing the number of iterations.

PiperOrigin-RevId: 735875835
2025-03-11 13:45:19 -07:00
jax authors
e0545a71eb Remove installation of NVIDIA wheels for CPU tests
PiperOrigin-RevId: 735875073
2025-03-11 13:43:13 -07:00
Jevin Jiang
eff612a3b6 Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
PiperOrigin-RevId: 735851301
2025-03-11 12:36:33 -07:00
jax authors
0db14aa342 Add NVIDIA wheel requirements only for Linux builds.
PiperOrigin-RevId: 735850240
2025-03-11 12:33:54 -07:00
shuw
f9aef8a189 Support nvfp4 2025-03-11 19:33:25 +00:00
Pearu Peterson
82b2591b21 Fix scipy.special.gammainc/gammaincc evaluation at boundary points 2025-03-11 21:18:47 +02:00
Nitin Srinivasan
7ac6355262 Add TPU test jobs to the new CI continuous and nightly/release test workflows
Also, modify the TPU presubmit workflow to reuse the `build_artifacts.yml` and `pytest_tpu.yml`

PiperOrigin-RevId: 735832964
2025-03-11 11:42:21 -07:00
jax authors
c2c68c018f Merge pull request #27059 from jakevdp:fix-while-loop
PiperOrigin-RevId: 735828960
2025-03-11 11:32:00 -07:00
Gunhyun Park
d191927b24 Fix syntax error and typos for composite primitive docstring.
PiperOrigin-RevId: 735808000
2025-03-11 10:37:07 -07:00
Adam Paszke
6f7ce9d048 Skip ASAN tests for the big Mosaic GPU tests
They are timing out.

PiperOrigin-RevId: 735804647
2025-03-11 10:30:04 -07:00
Jake VanderPlas
4ae3211ea2 jax.disable_jit: ensure while_loop behaves similarly to non-disable_jit version 2025-03-11 09:53:34 -07:00