Benjamin Chetioui
9a686e0bf3
[Mosaic GPU] Add initial transform inference rules for vector.{load,store}
.
...
PiperOrigin-RevId: 737703568
2025-03-17 12:08:07 -07:00
carlosgmartin
3f59fa6888
Add replace option to random.categorical to enable sampling without replacement.
2025-03-17 13:41:46 -04:00
Mathew Odden
d864b4fbf4
Fix auditwheel version issue ( #288 )
...
Auditwheel 6.3.0 changed/removed the lddtree function
so cap constraint to 6.2.x
2025-03-17 12:30:56 -05:00
jax authors
de9ad6bad9
Merge pull request #27157 from mar-muel:improve-random-choice-performance
...
PiperOrigin-RevId: 737665351
2025-03-17 10:30:15 -07:00
Nitin Srinivasan
031614c22b
Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt
...
PiperOrigin-RevId: 737632771
2025-03-17 08:59:06 -07:00
Adam Paszke
3649da56fc
[Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes to vector length
...
We can now perform the conversion in groups of 2, 4 or even 8 elements at a time.
PiperOrigin-RevId: 737626600
2025-03-17 08:37:17 -07:00
Sergei Lebedev
0ff234049b
Removed trivial docstrings from JAX tests
...
These docstrings do not make the tests any more clear and typically just duplicate the test module name.
PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00
jax authors
55812c5d02
Update XLA dependency to use revision
...
fcf97e619e
.
PiperOrigin-RevId: 737581187
2025-03-17 05:44:17 -07:00
Sergei Lebedev
a7e5eaee56
[pallas:mosaic_gpu] jnp.reduce_sum
now works for >1D arrays
...
PiperOrigin-RevId: 737578598
2025-03-17 05:32:07 -07:00
Adam Paszke
89b21de62a
[Mosaic GPU] Add support for changing the layout before the upcast
...
This lets us save on 2 ALU instructions (3x select becomes 1x prmt).
PiperOrigin-RevId: 737550598
2025-03-17 03:26:48 -07:00
Adam Paszke
2bdd9c8797
[Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16-bit upcast
...
PiperOrigin-RevId: 737542885
2025-03-17 02:52:16 -07:00
jax authors
761b35c59e
Merge pull request #27176 from jakevdp:lax-docs
...
PiperOrigin-RevId: 737338493
2025-03-16 05:39:55 -07:00
jax authors
e8b683aee0
Update XLA dependency to use revision
...
936a727db7
.
PiperOrigin-RevId: 737338103
2025-03-16 05:37:58 -07:00
Joan Puigcerver
466ef6a132
Change the way that batching.spec_types is updated.
...
There's no reason why not two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which can cause bugs / exceptions.
Suppose that I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. Then, the spec_type is no longer in the set to support the second data_type. Also, an exception will be raised if I try to unregister the two vmappable types (the second call to spec_types.remove).
When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.
PiperOrigin-RevId: 737280270
2025-03-15 22:58:44 -07:00
Jake VanderPlas
de8b0564ce
Better docs for jax.lax add/sub/mul/div
2025-03-15 11:49:51 -07:00
jax authors
f360e19194
Update XLA dependency to use revision
...
f52d5e03ce
.
PiperOrigin-RevId: 737143950
2025-03-15 05:10:35 -07:00
Ayaka
9b0ace4a11
Support error checking in explicit mode
...
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00
jax authors
d07d642d6f
Merge pull request #27177 from jax-ml:mixing_modes
...
PiperOrigin-RevId: 737047069
2025-03-14 18:34:27 -07:00
Yash Katariya
3c0027af3b
mixing modes
2025-03-14 18:23:27 -07:00
jax authors
7db59cdcca
Merge pull request #27174 from mattjj:opt-barrier-ad-rules
...
PiperOrigin-RevId: 737040381
2025-03-14 17:59:07 -07:00
Peter Hawkins
14cb7453f0
Add a C++ implementation of a toplogical sort.
...
This is an exact port of the current Python implementation to C++ for speed.
I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.
PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
github-actions[bot]
7a172d7010
Merge pull request #285 from ROCm/ci-upstream-sync-147_1
...
CI: 03/14/25 upstream sync
2025-03-14 17:59:16 -05:00
Charles Hofer
7a6940bb7b
Trivial change for CI
2025-03-14 22:45:17 +00:00
GitHub Actions
e275d5cf6c
Merge remote-tracking branch 'origin/rocm-main' into ci-upstream-sync-147_1
2025-03-14 22:42:07 +00:00
Matthew Johnson
dadc68b6c1
add experimental lax.optimization_barrier autodiff rules
2025-03-14 22:40:55 +00:00
jax authors
b00a3a1986
Merge pull request #27015 from mattjj:direct-linearize-fixes-4
...
PiperOrigin-RevId: 737003323
2025-03-14 15:24:11 -07:00
Sergei Lebedev
64230d1c93
[pallas:mosaic_gpu] WG lowering now supports while_p
...
PiperOrigin-RevId: 736996154
2025-03-14 14:59:29 -07:00
charleshofer
022da913e6
Count test totals correctly for dashboards ( #280 )
...
* Account test totals correctly for dashboards
* Add blurb to the dev guide on skipping tests
* Remove extra newline
* Default to 0 if "skipped" isn't found
Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>
---------
Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>
2025-03-14 16:57:08 -05:00
Matthew Johnson
174dcc771a
[direct-linearize] shmap fixes
2025-03-14 21:38:50 +00:00
jax authors
95791fa9e4
Merge pull request #27173 from jakevdp:fix-ipynb
...
PiperOrigin-RevId: 736987967
2025-03-14 14:34:31 -07:00
Tzu-Wei Sung
21f5f2d45e
[Pallas] Increase #rows when casting to x2.
...
There is a bug in XLA on v5p.
PiperOrigin-RevId: 736987667
2025-03-14 14:32:33 -07:00
Jake VanderPlas
412b2e3acb
Fix notebook formatting
2025-03-14 14:20:50 -07:00
Daniel Suo
39e8ee93b0
Add experimental/serialize_executable.py
to BUILD
.
...
PiperOrigin-RevId: 736975882
2025-03-14 13:54:39 -07:00
Yash Katariya
aa9480a441
Expose get_abstract_mesh
via the jax.sharding
namespace
...
PiperOrigin-RevId: 736972976
2025-03-14 13:45:32 -07:00
jax authors
a11d8891ce
Merge pull request #27165 from jax-ml:sharding-in-types-doc
...
PiperOrigin-RevId: 736971523
2025-03-14 13:40:47 -07:00
Dougal
e8f43d1cef
Explicit sharding docs
2025-03-14 16:33:30 -04:00
Justin Fu
dbd8d92075
[Pallas] Add legacy PRNG key support to Pallas PRNG
...
PiperOrigin-RevId: 736949584
2025-03-14 12:30:04 -07:00
Zac Mustin
0c8e601f90
Support convolution in roofline.
...
So far we support only `unfused_hmb_bytes` and don't account for `{feature, batch}_group_count`s due to complexity.
PiperOrigin-RevId: 736948528
2025-03-14 12:26:20 -07:00
Yash Katariya
88d4bc3d45
Rename AxisTypes enum to AxisType
...
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Emily Fertig
bdb6d03322
Allow make_array_from_callback
to construct nonaddressable arrays.
...
PiperOrigin-RevId: 736922870
2025-03-14 11:10:32 -07:00
Martin Muller
4a82fe94de
Use lax.top_k
instead of jnp.argsort
in Gumbel top-k trick for weighted sampling without replacement in jax.random.choice
2025-03-14 19:02:24 +01:00
Sergei Lebedev
97bbc37e83
[dlpack] Support more DLPack dtypes now that we target DLPack 1.1
...
I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor
TensorFlow currently support importing DLPack arrays with the new dtypes.
PiperOrigin-RevId: 736882904
2025-03-14 09:10:56 -07:00
Ilya Tikhonovskiy
c9ac82c826
[XLA:GPU] Add missing BF16_BF16_F32_X9 matmul option in config.py
...
Extend the list of possible default algorithms that dot could use.
PiperOrigin-RevId: 736879149
2025-03-14 08:58:59 -07:00
Nitin Srinivasan
5944c9ed65
Install test dependencies from test-requirements.txt instead of requirements.in
...
PiperOrigin-RevId: 736878834
2025-03-14 08:57:20 -07:00
Peter Hawkins
6fa98fc0a4
Use "x is y" rather than "id(x) == id(y)".
...
The latter involves at least two object constructions.
PiperOrigin-RevId: 736878098
2025-03-14 08:54:46 -07:00
jax authors
8fbe3b1333
Remove internal_test_util
folder and packages from jax
wheel.
...
PiperOrigin-RevId: 736861450
2025-03-14 07:52:03 -07:00
Peter Hawkins
8ab33669e2
Add a variant of safe_map() that has no return value, named foreach().
...
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.
I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.
PiperOrigin-RevId: 736859418
2025-03-14 07:42:48 -07:00
Peter Hawkins
074216e07a
Precompute a weakref to a Trace≥
...
We use Trace weakrefs frequently, so we may as well construct one eagerly.
PiperOrigin-RevId: 736841778
2025-03-14 06:26:17 -07:00
jax authors
92c57a51b9
Update XLA dependency to use revision
...
4c4aa96f9f
.
PiperOrigin-RevId: 736824693
2025-03-14 05:04:35 -07:00
Benjamin Chetioui
5098d2ef49
[Mosaic GPU][NFC] Simplify implementation for in_{layout,transforms}_for_operand
utils.
...
PiperOrigin-RevId: 736809960
2025-03-14 03:52:10 -07:00