Ayaka
9b0ace4a11
Support error checking in explicit mode
...
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00
jax authors
d07d642d6f
Merge pull request #27177 from jax-ml:mixing_modes
...
PiperOrigin-RevId: 737047069
2025-03-14 18:34:27 -07:00
Yash Katariya
3c0027af3b
mixing modes
2025-03-14 18:23:27 -07:00
jax authors
7db59cdcca
Merge pull request #27174 from mattjj:opt-barrier-ad-rules
...
PiperOrigin-RevId: 737040381
2025-03-14 17:59:07 -07:00
Peter Hawkins
14cb7453f0
Add a C++ implementation of a toplogical sort.
...
This is an exact port of the current Python implementation to C++ for speed.
I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.
PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
Matthew Johnson
dadc68b6c1
add experimental lax.optimization_barrier autodiff rules
2025-03-14 22:40:55 +00:00
jax authors
b00a3a1986
Merge pull request #27015 from mattjj:direct-linearize-fixes-4
...
PiperOrigin-RevId: 737003323
2025-03-14 15:24:11 -07:00
Sergei Lebedev
64230d1c93
[pallas:mosaic_gpu] WG lowering now supports while_p
...
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
Matthew Johnson
174dcc771a
[direct-linearize] shmap fixes
2025-03-14 21:38:50 +00:00
jax authors
95791fa9e4
Merge pull request #27173 from jakevdp:fix-ipynb
...
PiperOrigin-RevId: 736987967
2025-03-14 14:34:31 -07:00
Tzu-Wei Sung
21f5f2d45e
[Pallas] Increase #rows when casting to x2.
...
There is a bug in XLA on v5p.
PiperOrigin-RevId: 736987667
2025-03-14 14:32:33 -07:00
Jake VanderPlas
412b2e3acb
Fix notebook formatting
2025-03-14 14:20:50 -07:00
Daniel Suo
39e8ee93b0
Add experimental/serialize_executable.py
to BUILD
.
...
PiperOrigin-RevId: 736975882
2025-03-14 13:54:39 -07:00
Yash Katariya
aa9480a441
Expose get_abstract_mesh
via the jax.sharding
namespace
...
PiperOrigin-RevId: 736972976
2025-03-14 13:45:32 -07:00
jax authors
a11d8891ce
Merge pull request #27165 from jax-ml:sharding-in-types-doc
...
PiperOrigin-RevId: 736971523
2025-03-14 13:40:47 -07:00
Dougal
e8f43d1cef
Explicit sharding docs
2025-03-14 16:33:30 -04:00
Justin Fu
dbd8d92075
[Pallas] Add legacy PRNG key support to Pallas PRNG
...
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
Zac Mustin
0c8e601f90
Support convolution in roofline.
...
So far we support only `unfused_hmb_bytes` and don't account for `{feature, batch}_group_count`s due to complexity.
PiperOrigin-RevId: 736948528
2025-03-14 12:26:20 -07:00
Yash Katariya
88d4bc3d45
Rename AxisTypes enum to AxisType
...
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Emily Fertig
bdb6d03322
Allow make_array_from_callback
to construct nonaddressable arrays.
...
PiperOrigin-RevId: 736922870
2025-03-14 11:10:32 -07:00
Sergei Lebedev
97bbc37e83
[dlpack] Support more DLPack dtypes now that we target DLPack 1.1
...
I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor
TensorFlow currently support importing DLPack arrays with the new dtypes.
PiperOrigin-RevId: 736882904
2025-03-14 09:10:56 -07:00
Ilya Tikhonovskiy
c9ac82c826
[XLA:GPU] Add missing BF16_BF16_F32_X9 matmul option in config.py
...
Extend the list of possible default algorithms that dot could use.
PiperOrigin-RevId: 736879149
2025-03-14 08:58:59 -07:00
Nitin Srinivasan
5944c9ed65
Install test dependencies from test-requirements.txt instead of requirements.in
...
PiperOrigin-RevId: 736878834
2025-03-14 08:57:20 -07:00
Peter Hawkins
6fa98fc0a4
Use "x is y" rather than "id(x) == id(y)".
...
The latter involves at least two object constructions.
PiperOrigin-RevId: 736878098
2025-03-14 08:54:46 -07:00
jax authors
8fbe3b1333
Remove internal_test_util
folder and packages from jax
wheel.
...
PiperOrigin-RevId: 736861450
2025-03-14 07:52:03 -07:00
Peter Hawkins
8ab33669e2
Add a variant of safe_map() that has no return value, named foreach().
...
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.
I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.
PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
Peter Hawkins
074216e07a
Precompute a weakref to a Trace≥
...
We use Trace weakrefs frequently, so we may as well construct one eagerly.
PiperOrigin-RevId: 736841778
2025-03-14 06:26:17 -07:00
jax authors
92c57a51b9
Update XLA dependency to use revision
...
4c4aa96f9f
.
PiperOrigin-RevId: 736824693
2025-03-14 05:04:35 -07:00
Benjamin Chetioui
5098d2ef49
[Mosaic GPU][NFC] Simplify implementation for in_{layout,transforms}_for_operand
utils.
...
PiperOrigin-RevId: 736809960
2025-03-14 03:52:10 -07:00
Ilya Tikhonovskiy
43b78c539f
[JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
...
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we could do the last step - enable it in JAX.
PiperOrigin-RevId: 736799241
2025-03-14 02:57:55 -07:00
jax authors
cbece0b00b
Add explicit support for float8_e4m3b11fnuz in pl.dot
...
PiperOrigin-RevId: 736798315
2025-03-14 02:51:55 -07:00
Benjamin Chetioui
d09df7c8ab
[Mosaic GPU] Add transform inference rules for mgpu.async_{load,store}
.
...
PiperOrigin-RevId: 736795784
2025-03-14 02:37:55 -07:00
Benjamin Chetioui
d028354abb
[Mosaic GPU] Introduce an initial transform inference pass.
...
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.
The pass isn't called anywhere yet.
PiperOrigin-RevId: 736758754
2025-03-13 23:22:59 -07:00
Emily Fertig
d79472101d
Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
...
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff
PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Yash Katariya
d3a41d8448
get_sharding
doesn't need to be conditioned on the context mesh
...
PiperOrigin-RevId: 736710468
2025-03-13 18:59:31 -07:00
Tzu-Wei Sung
e235fb9760
[Mosaic] Allow part of x2 int casts.
...
This should at least allow int2 -> int4 for native tiling vregs. Skip many tests due to XLA compatibility.
PiperOrigin-RevId: 736710186
2025-03-13 18:57:36 -07:00
Matthew Johnson
34d6bb2e16
fix shard_map manual mesh axis names with vmap spmd_axis_name
...
PiperOrigin-RevId: 736707234
2025-03-13 18:41:46 -07:00
Hyeontaek Lim
73b8f6aee2
[JAX] Clean up make_array_from_callback_* API benchmarks and add a partially replicated sharding variant
...
To prepare for the upcoming `BatchedDevicePut` implementation changes, this
change makes `make_array_from_callback_*` benchmark code to be more
homogeneous. Also it adds a variant that uses a partially replicated sharding.
PiperOrigin-RevId: 736665856
2025-03-13 15:50:46 -07:00
Yash Katariya
e615e2acb3
Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
...
Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`
Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`
PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Peter Hawkins
1507754408
Precompute the __hash__ of AbstractMesh.
...
We use this frequently and it saves time to precompute it.
PiperOrigin-RevId: 736650750
2025-03-13 15:01:31 -07:00
jax authors
538a2be7fe
Reverts 74b4d868e3751c1b4efa315ff8cf771faeb0b663
...
PiperOrigin-RevId: 736650031
2025-03-13 14:59:09 -07:00
Zac Mustin
acd6c40f2f
Remove obsolete fallback for cost analysis.
...
This fallback does not seem to be needed as all executables have a cost-analysis implementation.
PiperOrigin-RevId: 736647203
2025-03-13 14:49:40 -07:00
Yash Katariya
e1b62cede1
Raise an error if jax.config.update('jax_num_cpu_devices', val)
is called after backend is initialized
...
PiperOrigin-RevId: 736646012
2025-03-13 14:45:53 -07:00
jax authors
47bf22e37d
[pallas][Mosaic][Easy] Add batch dot dim test, remove check
...
PiperOrigin-RevId: 736623531
2025-03-13 13:38:44 -07:00
jax authors
726f49cbca
Merge pull request #26944 from wenscarl:wenscarl/nvfp4
...
PiperOrigin-RevId: 736620378
2025-03-13 13:30:46 -07:00
Tzu-Wei Sung
a0f1be123d
[Mosaic] Improve error messages.
...
PiperOrigin-RevId: 736580673
2025-03-13 11:35:33 -07:00
jax authors
bf829ff612
Merge pull request #26524 from carlosgmartin:random_multinomial
...
PiperOrigin-RevId: 736569564
2025-03-13 11:05:17 -07:00
Peter Hawkins
8effa19734
[JAX] Change jax.core.Trace subclasses to call super().__init__().
...
Test the value of Trace._invalidated directly rather than using a hasattr test. I'm assuming the reason we did this is because we wanted to avoid updating all the subclasses to call super().__init__().
hasattr() tests are unnecessarily slow (did you know the one in jax.core.Trace builds an error message every time it fails?)
PiperOrigin-RevId: 736555016
2025-03-13 10:27:52 -07:00
Yash Katariya
14b9f48535
Allow late binding out_shardings
and in_shardings
in auto_axes
and explicit_axes
API
...
PiperOrigin-RevId: 736535562
2025-03-13 09:37:24 -07:00
Nitin Srinivasan
12760af236
Add custom job names to group different matrix combinations in the Actions dashboard
...
PiperOrigin-RevId: 736481804
2025-03-13 06:23:04 -07:00