25148 Commits

Author SHA1 Message Date
charleshofer
0a2f28056a
Update XLA commit hash (#342) 2025-04-07 13:29:47 -05:00
charleshofer
73f351886b
Port PR CI workflow from rocm-main (#312) 2025-04-02 11:48:31 -05:00
Mathew Odden
0074fed2ef
Fix C++23 build errors (#257) (#318)
Co-authored-by: charleshofer <Charles.Hofer@amd.com>
2025-03-26 15:04:49 -05:00
gabeweisz
82a2ec711c
Fix permissions on /tmp (#320) 2025-03-26 15:04:38 -05:00
Ruturaj Vaidya
7613a24b58
Update offload ROCm architectures (#311) 2025-03-25 09:39:22 -05:00
Ruturaj Vaidya
6a5394711b
Set higher tolerance for failing fp8 tests (#306) 2025-03-24 12:08:52 -05:00
Mathew Odden
887d1ede47
Fixes for 0.5.0 build (#308) 2025-03-22 09:21:42 +05:30
Mathew Odden
0218c7bd1e
Add support for ROCm wheel based install (#283)
This also requires some changes on the XLA side
for the paths and such to have any effect.

The plugin init now looks for a `rocm` python
package install and extracts ROCm toolkit paths
from the python packages that it finds. We hand
these into the XLA portions of the plugin via
environment variables to avoid changing any interfaces
like protobuf or PJRT C APIs.

We also have to patch the rpath in the shared object files
included in the plugin and kernel wheels so they look
relative to their install path, just like the CUDA-based
plugin objects do.

Other changes fix missing dynamic link libraries
and add an optional feature target to pull in the
ROCm Python dependencies for the plugin.
2025-03-21 11:39:56 -05:00
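As a rough illustration of the discovery step this commit describes, a minimal sketch follows; the package name `rocm` comes from the commit message, while `ROCM_PATH` and the helper names are assumptions for illustration, not the plugin's actual interface:

```python
import importlib.util
import os
import pathlib


def find_rocm_package_root():
    """Return the install directory of a pip-installed `rocm` package, if any."""
    spec = importlib.util.find_spec("rocm")
    if spec is None or spec.origin is None:
        return None
    return pathlib.Path(spec.origin).parent


def export_rocm_paths():
    root = find_rocm_package_root()
    if root is None:
        return  # No wheel install found; fall back to a system ROCm.
    # Hand the path to the XLA portions of the plugin through the
    # environment, avoiding changes to protobuf or PJRT C API interfaces.
    os.environ.setdefault("ROCM_PATH", str(root))
```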
Ruturaj Vaidya
41bcf9f1ae
Fix and enable RNN tests on JAX 0.5.0 branch (#301) 2025-03-20 15:16:03 -05:00
Ruturaj Vaidya
ce33728473
Fix the latest XLA hash for 0.5.0 JAX branch (#302) 2025-03-20 15:01:54 -05:00
Gulsum Gudukbay Akbulut
d9ec58fd44
Fix dev build script (#286)
* Fixed the error in the command-building part of the dev_build_rocm.py script

* Update build/rocm/dev_build_rocm.py

Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>

---------

Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>
2025-03-17 15:39:50 -05:00
Mathew Odden
1c96ac20e3
Fix auditwheel version issue (#290)
Auditwheel 6.3.0 changed/removed the lddtree function,
so cap the constraint to 6.2.x.

(cherry picked from commit 57e77ce3a82d93848ccba557b007084605b40f22)
2025-03-17 12:30:48 -05:00
JD
a496792a7a
Deprecate unused pre-production gfx versions (#272) 2025-03-13 16:11:55 -05:00
Dragan Mladjenovic
8445686bee
[pallas:triton] Fix atomic min/max lowering for unsigned integers and float types (#264) 2025-03-10 10:42:53 -05:00
Zahid Iqbal
9ce2aec622
Remove pytest-csv and CSV logs from unit testing (#261) 2025-03-05 13:37:08 -06:00
JD
48a361518a
Add gfx1101 target (#249) (#254) 2025-03-05 11:44:39 -06:00
Ruturaj Vaidya
772bd4ed4b
Fix sha256 hash to point to the latest xla commit (#251) 2025-03-03 09:23:28 -06:00
Ruturaj Vaidya
790fe78b82
Skip tests on ROCm (#235) 2025-02-20 09:41:11 -06:00
Mathew Odden
6b35155294
Fix invalid lowerings for ROCm in Pallas (#223)
popcount and clz were effectively broken on ROCm,
since math_dialect had incorrect lowerings.

Use the device intrinsics for these functions, as
well as for exp and absf, which fixes some accuracy issues in
the pallas tests.

Docs for OCML/OCKL

- https://github.com/ROCm/llvm-project/blob/amd-staging/amd/device-libs/doc/OCML.md
- https://github.com/ROCm/llvm-project/blob/amd-staging/amd/device-libs/doc/OCKL.md
2025-02-14 11:27:52 -06:00
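A minimal sketch (not from the commit) of the kind of Pallas kernel these lowerings cover, exercising `population_count` and `clz` on integers; `interpret=True` keeps it runnable without a GPU or TPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def bits_kernel(x_ref, pop_ref, clz_ref):
    x = x_ref[...]
    pop_ref[...] = jax.lax.population_count(x)  # popcount lowering
    clz_ref[...] = jax.lax.clz(x)               # count-leading-zeros lowering


def bits(x):
    out = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return pl.pallas_call(bits_kernel, out_shape=(out, out), interpret=True)(x)


pop, clz = bits(jnp.arange(8, dtype=jnp.int32))
```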
Ruturaj Vaidya
1fc3d15727
Use HIPBLAS_V2 (#222)
Co-authored-by: Harsha HS <Harsha.HavanurShamsundara@amd.com>
2025-02-05 18:12:13 -06:00
Zahid Iqbal
b6843ff873
Merge pull request #215 from ROCm/5.0_xla_sha_fix
Update workspace.bzl
2025-01-28 00:58:02 -06:00
Ruturaj Vaidya
1c9f907338
Update workspace.bzl 2025-01-28 00:56:43 -06:00
Ruturaj Vaidya
cdf1ef5aba
Update XLA hash for 0.5.0 branch (#212)
Co-authored-by: Ruturaj4 <ruvaidya@amd.com>
rocm-jax-v0.5.0
2025-01-23 20:06:35 -06:00
Zahid Iqbal
81ac86ba17
Merge pull request #211 from ROCm/jax_5.0_xla_sha_fix
Update XLA hash for 0.5.0 branch
2025-01-22 22:18:23 -06:00
Ruturaj4
b51b0c76a8 Update XLA hash for 0.5.0 branch 2025-01-22 18:52:31 -06:00
Peter Hawkins
c25fb92c44 Release JAX 0.5.0 2025-01-17 10:28:03 -05:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Yash Katariya
ce85b89884 [sharding_in_types] Error out for reshapes that split dimensions like this: (4, 6, 8) -> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
2025-01-17 06:58:29 -08:00
jax authors
7cac76d346 Update XLA dependency to use revision
a91d9c30c5.

PiperOrigin-RevId: 716651296
2025-01-17 06:51:32 -08:00
Benjamin Chetioui
d3be190efb [Mosaic GPU] Delete unused declarations of mosaic_gpu_memcpy_async_h2d.
PiperOrigin-RevId: 716616807
2025-01-17 04:34:48 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
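A minimal sketch of the renamed APIs, assuming a JAX build where this commit has landed; `jax.make_mesh` is used here only for illustration, and the example needs at least two devices:

```python
import jax
import jax.numpy as jnp
from jax.sharding import use_mesh  # exposed into the public API by this commit

mesh = jax.make_mesh((2,), ("x",))  # requires >= 2 devices

with use_mesh(mesh):  # formerly: with set_mesh(mesh):
    y = jnp.zeros((8, 8))
```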
Adam Paszke
bd22bfef71 [Mosaic TPU] Use large-to-compact 2nd-minor retiling for conversions going both ways
This specific retiling is its own inverse and is faster than the alternatives.

PiperOrigin-RevId: 716360070
2025-01-16 13:35:26 -08:00
Robert Dyro
aa9cea0a55 Add quotes to pip commands in docs around optional installs for zsh
PiperOrigin-RevId: 716358130
2025-01-16 13:30:43 -08:00
Nitin Srinivasan
5e52031dac Store GCS upload URI as a step output
This commit stores the GCS upload URI as a step output and then maps the job output to it so that we can make it available when a workflow call is made to `build_artifacts.yml`.

This helps in the case where we want to debug a workflow such as https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_continuous.yml where artifact build jobs and test jobs are run separately. Previously, we could not re-try a failing test workflow, as the upload URI contains `${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}`. In GitHub, `github.run_attempt` is mapped to a workflow and not to individual jobs, so even re-triggering a test job alone increases the `run_attempt` value, which in turn invalidates the GCS download URI that the test job reads.

By storing the upload URI as a step output, we freeze the upload/download URIs until the build artifact jobs are re-run. However, note that this still has an edge case where things can break: the `run-pytest-cuda` job in `wheel_tests_continuous.yml` depends on both `build-jaxlib-artifact` and `build-cuda-artifacts` but consumes the upload URI from the output of `build-jaxlib-artifact` alone. This is done on the assumption that both jobs will have uploaded to the same location, which would not be the case if one of them fails and has to re-run. We are working on a long-term solution for this case, but in the meantime the recommendation is to re-run the whole set of jobs.

PiperOrigin-RevId: 716348745
2025-01-16 13:05:52 -08:00
Parker Schuh
f2f552c108 Allow resharding tokens between a single device and multiple devices.

Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.

Fixes https://github.com/jax-ml/jax/issues/25671.

PiperOrigin-RevId: 716309978
2025-01-16 11:24:22 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation lowers to gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
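A hedged sketch of the escape hatch this commit describes; `out_spec` is this commit's parameter name (renamed to `out_sharding` by a later commit in this log), and the `P("x", None)` spec is an illustrative assumption:

```python
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(8, 2)
idx = jnp.array([0, 3, 5])

y = x.at[idx].get()  # fancy indexing; lowers to gather_p, same as x[idx]

# Under sharding-in-types, the bare gather above becomes an error that asks
# for an explicit output spec, roughly:
#   from jax.sharding import PartitionSpec as P
#   y = x.at[idx].get(out_spec=P("x", None))
# which drops the gather into full auto mode and adds a sharding constraint
# on the output, as described above.
```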
jax authors
994c3f59e2 Update XLA dependency to use revision
43d0d40456.

PiperOrigin-RevId: 716275042
2025-01-16 09:58:56 -08:00
Tzu-Wei Sung
5c020ee317 [Mosaic] Fix infer/apply extensions.
1. For apply, llvm::StringMap()::insert(MapEntryTy*) causes a dangling reference unless mlir::tpu::extensions::rules() is constructed with a const reference; but if we do construct it with a const reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get the extension rules by const reference.
2. Pass the default tiling to the infer rule; we need it to infer a single op. See the infer rule of tpu::MatmulOp.

PiperOrigin-RevId: 716274818
2025-01-16 09:57:14 -08:00
Yash Katariya
0df4475aeb Make result_handler of _DeferredShardArg a method instead of a property. Also play some code golf.
PiperOrigin-RevId: 716273533
2025-01-16 09:53:48 -08:00
Adam Paszke
8954e71d73 [Mosaic TPU] Improve support for int16->int32 casts in TPUv4
PiperOrigin-RevId: 716250236
2025-01-16 08:44:10 -08:00
Dimitar (Mitko) Asenov
5e27efd0e0 [MosaicGPU] Cleanup imports in dialect_lowering.py
PiperOrigin-RevId: 716244938
2025-01-16 08:26:02 -08:00
Benjamin Chetioui
6746d63364 [Mosaic GPU][NFC] Clean up import to align with stylistic guidance.
PiperOrigin-RevId: 716233876
2025-01-16 07:50:04 -08:00
Benjamin Chetioui
d3bf243342 [Mosaic GPU] Add layout inference for splat arith.ConstantOps and vector.SplatOps.
PiperOrigin-RevId: 716224880
2025-01-16 07:18:35 -08:00
Dimitar (Mitko) Asenov
24884071b9 [MosaicGPU] Remove the single_thread context from top-level dialect code.
- Change the `async_load` lowering to manage the single thread context.
- Use a predicate for the top-level arrive_expect. If we want to hide this further, we can have a warp-group level op that lowers to a single-threaded context.

PiperOrigin-RevId: 716219730
2025-01-16 06:59:32 -08:00
Benjamin Chetioui
3366c92782 [Mosaic GPU][NFC] Simplify and clean up layout inference tests to use FuncOps.
PiperOrigin-RevId: 716216260
2025-01-16 06:48:57 -08:00
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split one dimension only; the dimension being split must be unsharded.

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merge into one dimension only; all dimensions being merged must be unsharded.

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Split into singleton dimensions i.e. adding extra dims of size 1

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merge singleton dimensions i.e. removing extra dims of size 1

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
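A plain-JAX sketch of the unambiguous cases listed above; the `@x`/`@y` sharding annotations only take effect under sharding-in-types, so they appear here as comments:

```python
import jax.numpy as jnp

x = jnp.zeros((4, 6, 8))           # think (4@x, 6@y, 8)
a = x.reshape(4, 6, 2, 2, 2)       # split the unsharded trailing dim
b = x.reshape(1, 4, 1, 6, 1, 8)    # add singleton dims of size 1
c = b.reshape(4, 6, 8)             # remove singleton dims
d = x.reshape(4, 6, 8)             # identity reshape
# Any reshape outside these cases errors and asks for an explicit out_sharding.
```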
Dimitar (Mitko) Asenov
ce03cf976e [MosaicGPU] Move gpu_address_space_to_nvptx inside utils.py and use it.
PiperOrigin-RevId: 716214822
2025-01-16 06:41:51 -08:00