25125 Commits

Author SHA1 Message Date
Jake VanderPlas
f83175fc94 [key reuse] fix signature for device_put 2025-01-17 09:47:50 -08:00
jax authors
a4a657bc43 Merge pull request #25952 from johannahaffner:patch-1
PiperOrigin-RevId: 716678666
2025-01-17 08:26:16 -08:00
Johanna Haffner
df6140e875
Tweak documentation of jnp.cov to include scalar return for M = 1
Fixes https://github.com/jax-ml/jax/issues/25951
2025-01-17 16:16:06 +01:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Yash Katariya
ce85b89884 [sharding_in_types] Error out for reshape for splits like this: (4, 6, 8) -> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
2025-01-17 06:58:29 -08:00
jax authors
7cac76d346 Update XLA dependency to use revision
a91d9c30c5.

PiperOrigin-RevId: 716651296
2025-01-17 06:51:32 -08:00
Benjamin Chetioui
d3be190efb [Mosaic GPU] Delete unused declarations of mosaic_gpu_memcpy_async_h2d.
PiperOrigin-RevId: 716616807
2025-01-17 04:34:48 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Adam Paszke
bd22bfef71 [Mosaic TPU] Use large to compact 2nd minor retiling for conversions going both ways
This specific retiling is its own inverse and it faster than alternatives.

PiperOrigin-RevId: 716360070
2025-01-16 13:35:26 -08:00
Robert Dyro
aa9cea0a55 Add quotes to pip commands in docs around option install for zsh
PiperOrigin-RevId: 716358130
2025-01-16 13:30:43 -08:00
Nitin Srinivasan
5e52031dac Store GCS upload URI as a step output
This commit stores the GCS upload URI as a step output and then maps the job output to it so that we can make it available when a workflow call is made to `build_artifacts.yml`.

This helps in the case where we want to debug a workflow such as https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_continuous.yml where artifact build jobs and test jobs are run separately. Previously, we could not re-try a failing test workflow as the upload URI contains `${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}`. In GitHub, `github.run_attempt` is mapped to a workflow and not individual jobs so even a re-trigger of a test job alone would lead to the `run_attempt` value being increased which in turn invalidates the GCS download URI that it reads.

By storing the the upload URI as a step output, we freeze the upload/download URIs until the build artifact jobs are re-run. However, note that this still has a edge case where things can break - `run-pytest-cuda` job in `wheel_tests_continuous.yml` depends on both `build-jaxlib-artifact` and  `build-cuda-artifacts` but consumes the upload URI from the output of `build-jaxlib-artifact` alone. This is done on the assumption that both these jobs will have uploaded to the same location. However, that would not be the case if one of these jobs fail and have to re-run. We are working on a longterm solution for this case but in the meantime, the recommendation for now is just to re-run the whole set of jobs again.

PiperOrigin-RevId: 716348745
2025-01-16 13:05:52 -08:00
Parker Schuh
f2f552c108 Allow resharding between tokens on a single device
and multiple devices.

Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.

Fixes https://github.com/jax-ml/jax/issues/25671.

PiperOrigin-RevId: 716309978
2025-01-16 11:24:22 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
jax authors
994c3f59e2 Update XLA dependency to use revision
43d0d40456.

PiperOrigin-RevId: 716275042
2025-01-16 09:58:56 -08:00
Tzu-Wei Sung
5c020ee317 [Mosaic] Fix infer/apply extensions.
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.

PiperOrigin-RevId: 716274818
2025-01-16 09:57:14 -08:00
Yash Katariya
0df4475aeb Make result_handler of _DeferredShardArg a method instead of a property. Also play some code golf.
PiperOrigin-RevId: 716273533
2025-01-16 09:53:48 -08:00
Adam Paszke
8954e71d73 [Mosaic TPU] Improve support for int16->int32 casts in TPUv4
PiperOrigin-RevId: 716250236
2025-01-16 08:44:10 -08:00
Dimitar (Mitko) Asenov
5e27efd0e0 [MosaicGPU] Cleanup imports in dialect_lowering.py
PiperOrigin-RevId: 716244938
2025-01-16 08:26:02 -08:00
Benjamin Chetioui
6746d63364 [Mosaic GPU][NFC] Clean up import to align with stylistic guidance.
PiperOrigin-RevId: 716233876
2025-01-16 07:50:04 -08:00
Benjamin Chetioui
d3bf243342 [Mosaic GPU] Add layout inference for splat arith.ConstantOps and vector.SplatOps.
PiperOrigin-RevId: 716224880
2025-01-16 07:18:35 -08:00
Dimitar (Mitko) Asenov
24884071b9 [MosaicGPU] Remove the single_thread context from top-level dialect code.
- Change the `async_load` lowering to manage the single thread context.
- Use a predicate for the top-level arrive_expect. If we want to hide this further, we can have a warp-group level op that lowers to a single-threaded context.

PiperOrigin-RevId: 716219730
2025-01-16 06:59:32 -08:00
Benjamin Chetioui
3366c92782 [Mosaic GPU][NFC] Simplify and clean up layout inference tests to use FuncOps.
PiperOrigin-RevId: 716216260
2025-01-16 06:48:57 -08:00
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merging into 1 dimension only and all the merging dimensions should be unsharded.

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Split into singleton dimensions i.e. adding extra dims of size 1

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merge singleton dimensions i.e. removing extra dims of size 1

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
Dimitar (Mitko) Asenov
ce03cf976e [MosaicGPU] Move gpu_address_space_to_nvptx inside utils.py and use it.
PiperOrigin-RevId: 716214822
2025-01-16 06:41:51 -08:00
Adam Paszke
ef4dbd9cb9 [Mosaic TPU] Add support for packing to 16-bit integers on TPUv4
And refactor some test conditions to better match what we really support.
The tests were failing on older TPUs.

PiperOrigin-RevId: 716214098
2025-01-16 06:39:23 -08:00
Dimitar (Mitko) Asenov
22417ae28e [MosaicGPU] Extract code into a new method BarrierRef.from_dialect_barrier_memref and implement support for 1D barrier memrefs.
PiperOrigin-RevId: 716180182
2025-01-16 04:30:43 -08:00
Benjamin Chetioui
bc7204f003 [Mosaic GPU] Allow querying layouts from a FuncOp's block arguments if set.
The motivation behind this change is twofold:

1. it simplifies test writing (no need to produce arbitrary, manual, non-splat
   constants to produce arguments with a strided layout);
2. it'll allow running layout inference on different `FuncOp`s in isolation,
   before inlining.

While the primary motivation is to simplify test writing for upcoming changes,
`2.` is useful if we ever intend to call functions whose body's layout we have
inferred from other functions. It's not clear to me that we have a use case for
that, but the theoretical benefit is worth pointing out.

Crucially, layout inference does not set default layouts for `FuncOp`s, since
the caller may choose a different layout for its arguments. As a result, there
is also no layout inference rule for `func.FuncOp`.

PiperOrigin-RevId: 716158516
2025-01-16 03:05:41 -08:00
Sergei Lebedev
4221f109d1 [mosaic] Extracted serialization pass traversal logic into a reusable function
I will use it to implement Mosaic GPU serialization pass in a follow up.

PiperOrigin-RevId: 716156650
2025-01-16 02:58:06 -08:00
jax authors
9a60e6fce4 Merge pull request #25917 from ROCm:ci_fix_multi_gpu_test_logic-upstream
PiperOrigin-RevId: 716153760
2025-01-16 02:45:54 -08:00
Sharad Vikram
0ac63157f5 [Pallas TPU] Add helpers file with copy_ref function
PiperOrigin-RevId: 716030813
2025-01-15 18:34:58 -08:00
Tzu-Wei Sung
4a9cc9ffc1 [Mosaic] Allow passing ApplyVectorLayoutCtx to tpu.apply_layout_op.
To make it the same with C++ API. While I'm here, fix a bug in test_concatenate.

PiperOrigin-RevId: 716016244
2025-01-15 17:47:36 -08:00
Ruturaj4
8e88adcd3f Fix run_multi_gpu script multi-gpu issue and refactor code 2025-01-15 22:33:03 +00:00
Naums Mogers
d3ba1eb339 [Mosaic] Add a macro to convert abseil StatusOr to LLVM FailureOr
PiperOrigin-RevId: 715943314
2025-01-15 14:19:29 -08:00
Nitin Srinivasan
8a053af1ce Move halt for testing step to be just before running tests
This lets all the setup steps to finish before a halt for connection request is made.

PiperOrigin-RevId: 715887557
2025-01-15 11:54:36 -08:00
jax authors
cf67e28f79 Merge pull request #25906 from ROCm:ci_add_new_gfx-upstream
PiperOrigin-RevId: 715883737
2025-01-15 11:45:09 -08:00
jax authors
2fa1002054 Merge pull request #25911 from hawkinsp:version
PiperOrigin-RevId: 715882985
2025-01-15 11:43:23 -08:00
Zachary Garrett
f7d097f7cc Make utils for reporting function name work with functools.partial by using the inner .func attribute if the object doesn't have a __name__ attribute. functools.partial objects do not have __name__ attributes by default.
PiperOrigin-RevId: 715881812
2025-01-15 11:40:59 -08:00
Peter Hawkins
3a8f31aa83 Update the JAX version to 0.5.0.
This is because of the breaking change to PRNG key semantics, and the version follows JAX's new effver versioning scheme (https://jax.readthedocs.io/en/latest/jep/25516-effver.html).
2025-01-15 14:08:15 -05:00
jax authors
41993fdb24 Merge pull request #25755 from ROCm:ci_rnn_final-upstream
PiperOrigin-RevId: 715856939
2025-01-15 10:40:54 -08:00
jax authors
ca012d7ad6 Merge pull request #25864 from jax-ml:yet-more-linearization-fixes
PiperOrigin-RevId: 715840148
2025-01-15 10:00:31 -08:00
jax authors
51f2310069 Update XLA dependency to use revision
370a76e2d5.

PiperOrigin-RevId: 715838120
2025-01-15 09:55:41 -08:00
Zac Mustin
2d72e8de84 Jax: Stop returning a list of cost-analyses.
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
jax authors
70c1ee5d9c Merge pull request #25876 from gnecula:debug_info_3
PiperOrigin-RevId: 715831527
2025-01-15 09:35:03 -08:00
Ruturaj4
435edf1f8c Add gfx12xx archs 2025-01-15 16:14:40 +00:00
jax authors
2e5e4799fd Merge pull request #25880 from jakevdp:fix-gather
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Dougal
9fe553ca49 More linearization fixes 2025-01-15 10:27:21 -05:00
Sergei Lebedev
afcb21ddf1 [pallas:mosaic_gpu] Fixed a crash in MLIR Python bindings
The error message produced by MLIR is not really clear, but AFAICT the crash
was caused by the "temporary module" hack we use in the lax.cond lowering
rule.

PiperOrigin-RevId: 715785632
2025-01-15 07:09:43 -08:00