Ilya Tikhonovskiy
43b78c539f
[JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
...
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we could do the last step - enable it in JAX.
PiperOrigin-RevId: 736799241
2025-03-14 02:57:55 -07:00
jax authors
cbece0b00b
Add explicit support for float8_e4m3b11fnuz in pl.dot
...
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Benjamin Chetioui
d09df7c8ab
[Mosaic GPU] Add transform inference rules for mgpu.async_{load,store}
.
...
PiperOrigin-RevId: 736795784
2025-03-14 02:37:55 -07:00
Benjamin Chetioui
d028354abb
[Mosaic GPU] Introduce an initial transform inference pass.
...
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.
The pass isn't called anywhere yet.
PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Emily Fertig
d79472101d
Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
...
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff
PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Yash Katariya
d3a41d8448
get_sharding
doesn't need to be conditioned on the context mesh
...
PiperOrigin-RevId: 736710468
2025-03-13 18:59:31 -07:00
Tzu-Wei Sung
e235fb9760
[Mosaic] Allow part of x2 int casts.
...
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.
PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
Matthew Johnson
34d6bb2e16
fix shard_map manual mesh axis names with vmap spmd_axis_name
...
PiperOrigin-RevId: 736707234
2025-03-13 18:41:46 -07:00
Hyeontaek Lim
73b8f6aee2
[JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant
...
To prepare for the upcoming `BatchedDevicePut` implementation changes, this
change makes `make_array_from_callback_*` benchmark code to be more
homogeneous. Also it adds a variant that uses a partially replicated sharding.
PiperOrigin-RevId: 736665856
2025-03-13 15:50:46 -07:00
Yash Katariya
e615e2acb3
Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
...
Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`
Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`
PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Peter Hawkins
1507754408
Precompute the __hash__ of AbstractMesh.
...
We use this frequently and it saves time to precompute it.
PiperOrigin-RevId: 736650750
2025-03-13 15:01:31 -07:00
jax authors
538a2be7fe
Reverts 74b4d868e3751c1b4efa315ff8cf771faeb0b663
...
PiperOrigin-RevId: 736650031
2025-03-13 14:59:09 -07:00
Zac Mustin
acd6c40f2f
Remove obsolete fallback for cost analysis.
...
This fallback does not seem to be needed as all executables have a cost-analysis implementation.
PiperOrigin-RevId: 736647203
2025-03-13 14:49:40 -07:00
Yash Katariya
e1b62cede1
Raise an error if jax.config.update('jax_num_cpu_devices', val)
is called after backend is initialized
...
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00
jax authors
47bf22e37d
[pallas][Mosaic][Easy] Add batch dot dim test, remove check
...
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
jax authors
726f49cbca
Merge pull request #26944 from wenscarl:wenscarl/nvfp4
...
PiperOrigin-RevId: 736620378
2025-03-13 13:30:46 -07:00
Tzu-Wei Sung
a0f1be123d
[Mosaic] Improve error messages.
...
PiperOrigin-RevId: 736580673
2025-03-13 11:35:33 -07:00
jax authors
bf829ff612
Merge pull request #26524 from carlosgmartin:random_multinomial
...
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Peter Hawkins
8effa19734
[JAX] Change jax.core.Trace subclasses to call super().__init__().
...
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().
hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)
PiperOrigin-RevId: 736555016
2025-03-13 10:27:52 -07:00
Yash Katariya
14b9f48535
Allow late binding out_shardings
and in_shardings
in auto_axes
and explicit_axes
API
...
PiperOrigin-RevId: 736535562
2025-03-13 09:37:24 -07:00
Nitin Srinivasan
12760af236
Add custom job names to group different matrix combinations in the Actions dashboard
...
PiperOrigin-RevId: 736481804
2025-03-13 06:23:04 -07:00
jax authors
c07d839e6e
Update XLA dependency to use revision
...
45e12978d5
.
PiperOrigin-RevId: 736452873
2025-03-13 04:09:44 -07:00
Yash Katariya
2d01226b3b
Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
...
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
Yash Katariya
a4ca0dbc6c
Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types)
instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types)
so that we are consistent across all Mesh APIs: Mesh
, AbstractMesh
and make_mesh
...
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Yash Katariya
c6dcbb6759
[sharding_in_types] Rework the axis_types
argument in Mesh and AbstractMesh APIs. The changes are:
...
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore
2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.
PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
carlosgmartin
6b69a136aa
Add jax.random.multinomial.
2025-03-12 18:15:14 -04:00
jax authors
ba367cdead
Merge pull request #27044 from carlosgmartin:add_breadcrumbs_to_docs
...
PiperOrigin-RevId: 736283028
2025-03-12 15:14:17 -07:00
Jevin Jiang
12c0987e2f
[Mosaic TPU][NFC] Throw NYI error instead of crash when squeeze ref to 1d.
...
PiperOrigin-RevId: 736263705
2025-03-12 14:18:33 -07:00
Yash Katariya
47480b4493
Add a set_mesh API to jax.sharding
. set_mesh
sets the sharding and never unsets it i.e. this is just __enter__
of a ctx manager without __exit__
...
PiperOrigin-RevId: 736261724
2025-03-12 14:12:47 -07:00
carlosgmartin
bc43b00d8f
Add navigation breadcrumbs to docs.
2025-03-12 16:52:34 -04:00
Yash Katariya
8674495fd7
[sharding_in_types] Make reshard
work with np.array.
...
PiperOrigin-RevId: 736250504
2025-03-12 13:41:42 -07:00
github-actions[bot]
9cc545254c
Merge pull request #276 from ROCm/ci-upstream-sync-144_1
...
CI: 03/12/25 upstream sync
2025-03-12 13:23:13 -05:00
Justin Fu
6978f35293
[Pallas] Plumb compiler flags through source mapper.
...
PiperOrigin-RevId: 736199966
2025-03-12 11:19:58 -07:00
Christos Perivolaropoulos
b34f56bfd7
[mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
...
PiperOrigin-RevId: 736187550
2025-03-12 10:47:48 -07:00
jax authors
3de7ecf6da
Merge pull request #27092 from pearu:pearu/gammainc-bug-fix
...
PiperOrigin-RevId: 736177398
2025-03-12 10:20:39 -07:00
Charles Hofer
db8ba1b598
Change to run CI
2025-03-12 17:06:35 +00:00
jax authors
e7d10a2310
Merge pull request #27041 from carlosgmartin:fix_binomial_value_error
...
PiperOrigin-RevId: 736171463
2025-03-12 10:05:18 -07:00
GitHub Actions
a0edd3fbb2
Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-144_1
2025-03-12 16:57:18 +00:00
charleshofer
f14a1d0b71
Add JSON output to multi-GPU tests ( #274 )
2025-03-12 11:30:55 -05:00
Pearu Peterson
f608a8c502
Update gammainc and gammaincc against scipy 1.16: return nan whenever one of operands is nan.
2025-03-12 17:48:45 +02:00
Yash Katariya
abcc7fdf4c
[sharding_in_types] Initial commit to add varying_manual_axes: frozenset[AxisName]
to ShapedArray. Also add jax_varying_axes_in_types
config to hide this option under while we develop it.
...
PiperOrigin-RevId: 736141670
2025-03-12 08:29:16 -07:00
Dan Foreman-Mackey
8b7cfcb33c
Fix integer overflow in workspace size computations for experimental.rnn.*.
...
PiperOrigin-RevId: 736139471
2025-03-12 08:22:04 -07:00
Sergei Lebedev
e33f3fc48b
[pallas:mosaic_gpu] Added support for reductions to the WG lowering
...
Note that
* we have no easy way of testing multi-reductions at the moment;
* `reduce_max` assumes WGMMA_ROW layout which is not currently supported by
the dialect lowering AFAICT.
PiperOrigin-RevId: 736138554
2025-03-12 08:18:31 -07:00
Nitin Srinivasan
d89835acba
Fix matrix exclude syntax in TPU tests block
...
Also, skip Python 3.13 for now due to missing dependency error.
PiperOrigin-RevId: 736120590
2025-03-12 07:12:52 -07:00
Nitin Srinivasan
a6ab6bbc20
Ignore Pallas TPU tests when testing with the oldest supported libtpu
...
I missed adding this in from https://github.com/jax-ml/jax/blob/main/.github/workflows/cloud-tpu-ci-nightly.yml when I added the TPU jobs to the new CI workflows
PiperOrigin-RevId: 736094492
2025-03-12 05:20:42 -07:00
jax authors
61ba2b2603
Update XLA dependency to use revision
...
c270a6ce45
.
PiperOrigin-RevId: 736088162
2025-03-12 04:52:20 -07:00
Chris Jones
74b4d868e3
Add support for scratch buffers in jax_triton
.
...
This is required to use device-side TMA descriptors.
PiperOrigin-RevId: 735985603
2025-03-11 20:49:33 -07:00
Nitin Srinivasan
ff751ecc7b
Run single python version for v4-8 and min & max for v5e-8 for TPU tests in nightly/release test workflow
...
PiperOrigin-RevId: 735975004
2025-03-11 20:03:05 -07:00
Matthew Johnson
66a6eb299e
add autodiff rules for jax.lax.ragged_all_to_all collective
...
also update the ragged_all_to_all docstring. pseudocode in the style of the shard_map tutorial would be better and cleaner, but it needs the context of the tutorial to explain; i'll add ra2a to the shmap tutorial in the future.
PiperOrigin-RevId: 735957604
2025-03-11 18:22:02 -07:00
Yash Katariya
3a26804c68
Rename get_ty
to typeof
which is an alias of get_aval
...
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00