1281 Commits

Author SHA1 Message Date
jax authors
de68018473 [NFC][Mosaic TPU] Clarify layout comment block
PiperOrigin-RevId: 690977672
2024-10-29 05:20:08 -07:00
jax authors
12d26053e3 [TPU][Mosaic] Add support for a no-op reshape where sublane_tiling = 1 and the res_tiled and src_tiled shapes both fill a full vreg (1024)
PiperOrigin-RevId: 690796348
2024-10-28 16:57:51 -07:00
Adam Paszke
36c56fa19b [Pallas:MGPU] Fix flaky debug_print tests
Turns out that waiting for the kernel to finish it not enough, since the
prints also need to be processed by the CUDA runtime. Using a test-only
function that synchronizes all the devices seems to suffice.

PiperOrigin-RevId: 690624999
2024-10-28 08:42:02 -07:00
Sergei Lebedev
04bdd07f66 [mosaic_gpu] mgpu.FragmentedArray now supports //
This is needed to compute grid index from the iteration step counter in `emit_pipeline`.

PiperOrigin-RevId: 690608581
2024-10-28 07:52:22 -07:00
Jevin Jiang
2a671e25a7 [Mosaic TPU] Remove extra check
PiperOrigin-RevId: 689852989
2024-10-25 11:22:17 -07:00
Tzu-Wei Sung
4972f84c94 [Mosaic] Use max sublane offset per shuffled load to decide whether to avoid bank conflict.
PiperOrigin-RevId: 689809024
2024-10-25 09:09:14 -07:00
jax authors
63c1699ed0 Fix a use-after-free bug in third_party/py/jax/jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc
The backing array of the initializer_list is destroyed at the end of the full expression.

PiperOrigin-RevId: 689783482
2024-10-25 07:40:12 -07:00
Kanglan Tang
af28595909 Add a jax_wheel Bazel rule to build jax pip packages
PiperOrigin-RevId: 689514531
2024-10-24 14:20:46 -07:00
Adam Paszke
6634f5a348 [Mosaic GPU] Use absl::StrCat instead std::string::operator+
Repeated string addition is apparently a bit of an anti-pattern. Not that it matters
much in this place, but why not do it properly.

PiperOrigin-RevId: 689416587
2024-10-24 09:49:51 -07:00
Andrey Portnoy
14e0f0e7fa [Mosaic GPU] Query SM and PTX ISA dynamically using driver and LLVM
Originally proposed in #24021. Slightly rewritter to make testing with internal LLVM toolchains better.

Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`).
Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`).
Then use LLVM to map the SM string to a PTX ISA version that supports the SM.

Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
PiperOrigin-RevId: 689286774
2024-10-24 01:46:29 -07:00
Jevin Jiang
b8bacda2d9 [Mosaic TPU] Use native vector tiling to load and store with untiled memref.
PiperOrigin-RevId: 689142734
2024-10-23 16:22:16 -07:00
jax authors
48bddc6f6c Adds arith.select to the op patters in order to canonicalize non 32 bit selects.
PiperOrigin-RevId: 687635492
2024-10-19 09:09:06 -07:00
Benjamin Chetioui
ade480ff05 Add a dialect for Mosaic GPU.
PiperOrigin-RevId: 687325692
2024-10-18 09:11:31 -07:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we do use some heuristics to based on the matrix sizes to select the algorithm that is used, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But, I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
jax authors
6c2649fdf2 Rewrite mosaic concat to support operand shapes that do not align with native shapes, Expand tests to cover multi operand, batch dim concat, etc.
PiperOrigin-RevId: 687003778
2024-10-17 12:24:51 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
jax authors
9027fb38fe Fix segfault
PiperOrigin-RevId: 686821923
2024-10-17 01:52:44 -07:00
Jevin Jiang
a47b755619 [Mosaic TPU] Support native int4 @ int4
PiperOrigin-RevId: 686179715
2024-10-15 11:35:23 -07:00
Yash Katariya
824ccd7183 [Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.
Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Bart Chrzaszcz
75e22f2ccd #sdy Run inlined mesh lifter pass at the end of JAX lowering.
PiperOrigin-RevId: 685728692
2024-10-14 09:13:12 -07:00
jax authors
57ef7a4a59 Merge pull request #24274 from ROCm:ci_linalg_fix
PiperOrigin-RevId: 685717437
2024-10-14 08:33:33 -07:00
Paweł Paruzel
23fdb91252 Port Schur Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685689593
2024-10-14 06:46:42 -07:00
Paweł Paruzel
ec68d420fe Port Tridiagonal Reduction to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 685679646
2024-10-14 06:02:59 -07:00
Ruturaj4
ee223d4004 [ROCm] jaxlib linalg fix 2024-10-13 20:25:18 -05:00
jax authors
e4629f6a4c Merge pull request #24232 from ROCm:ci_rv_clang_clean
PiperOrigin-RevId: 684891301
2024-10-11 11:00:55 -07:00
Ruturaj4
89cd375c85 [JAX] bazel build rocm changes 2024-10-10 18:00:15 -05:00
Ruturaj4
33bcd0cb7a [ROCm] Bring up clang support for JAX+XLA
* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
2024-10-10 16:31:26 -05:00
Dan Foreman-Mackey
6625a2b3ed Update Eigh kernel on GPU to use 64-bit interface when it is available.
Part of https://github.com/jax-ml/jax/issues/23413

PiperOrigin-RevId: 684546802
2024-10-10 12:59:37 -07:00
Peter Hawkins
cf5f15773a Remove dead ducc_fft code.
I guess this was omitted when we switched over to using stablehlo.fft since XLA now calls DUCC itself.

PiperOrigin-RevId: 684437739
2024-10-10 07:33:54 -07:00
jax authors
81a95f78b9 [Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.
PiperOrigin-RevId: 684392184
2024-10-10 04:28:36 -07:00
Jevin Jiang
f52b016de1 [Mosaic TPU] Change getLayout to force offset to 0 when inferring input has offset out of the first tile.
PiperOrigin-RevId: 684145987
2024-10-09 13:11:49 -07:00
Jevin Jiang
f96c5661ac [Mosaic TPU][NFC] Refactor tpu matmul rule.
* Separate MXU size to MXU contracting size and MXU non-contracting size.
* Rename tile to group for MXU shaped tiling since tile is overused in Mosaic.

PiperOrigin-RevId: 684116306
2024-10-09 11:45:25 -07:00
jax authors
9748e2ab1a [JAX] Fix error message for matmul operand shape check.
PiperOrigin-RevId: 683778484
2024-10-08 15:07:20 -07:00
Eric Salo
713e909ba0 cleanup: remove api_version from BUILD files
PiperOrigin-RevId: 683658237
2024-10-08 09:44:15 -07:00
Peter Hawkins
145304a0e0 Remove reference to outfeed_receiver.pyi, which was deleted.
PiperOrigin-RevId: 683195999
2024-10-07 08:37:14 -07:00
Dan Foreman-Mackey
67f24df740 Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
2024-10-04 12:38:26 -07:00
Dan Foreman-Mackey
c0240764bc Activate FFI implementation of the QR decomposition.
As part of this change, I've added support and tests for shape polymorphism and export on CPU and GPU.

The FFI kernels have been available in jaxlib for over 3 weeks already and they are included with the latest release of jaxlib on PyPI so we don't need to worry about the forward compatibility checks. With this in mind, I also removed the old lowering rules, but kept the backwards compatibility tests for now.

PiperOrigin-RevId: 682312752
2024-10-04 07:27:11 -07:00
Paweł Paruzel
6e9a53690c Activate Hessenberg Decomposition to XLA's FFI
Additionally, created a missing backward compatibility test for the old LAPACK kernels of Hessenberg Decomposition.

PiperOrigin-RevId: 681047625
2024-10-01 09:20:06 -07:00
Adam Paszke
f62941d126 [Mosaic TPU] The previous change does not actually force the input offsets read by the rules, but simply disables all the checks. Reverting so that we at least regain the checks until we have a proper fix.
Reverts 4a596aee1e8920f5b51d5bd573df976390bbd437

PiperOrigin-RevId: 680925509
2024-10-01 02:23:52 -07:00
Jevin Jiang
4a596aee1e [Mosaic TPU] Force offset to 0 when inferring input has offset out of the first tile.
We still have this temporary check in apply vector layout, but in infer vector layout, instead of throwing error, we should just reset offset to zero. Because some ops which has relaxed this restriction might be passed as input for un-relaxed ops and cause failure.

PiperOrigin-RevId: 680706301
2024-09-30 13:52:48 -07:00
Jevin Jiang
7e2f487ada [Mosaic TPU] Canonicalize arith.select's condition to vector if other types are vector.
This fixes the failure in elementwise rule of apply vector layout pass.

If the condition scalar is static, it will be simplified to corresponding vector from true value and false value by MLIR.

If the condition scalar is dynamic, we want to use vselect over scf.if anyway. Because latter creates a inner region.

PiperOrigin-RevId: 680674560
2024-09-30 12:26:44 -07:00
Dan Foreman-Mackey
1a1e16abcc Remove forward compatibility checks from lowering of LU decomposition.
The forward compatibility window for these checks has passed so it is now safe to remove them.

PiperOrigin-RevId: 680565099
2024-09-30 07:23:56 -07:00
Peter Hawkins
5a1d0a6c26 Include the sdy MLIR dialect in jaxlib.
We're seeing test failures from tests assuming that this dialect exists. But given we plan to enable it at some point, we may as well just include it in the build.

The size impact is small (around 400K uncompressed).

PiperOrigin-RevId: 679608092
2024-09-27 08:53:31 -07:00
Peter Hawkins
26632fd344 Replace disable_backends with enable_backends on jax_multiplatform_test.
Most users of disable_backends were actually using it to enable only a single backend. So things are simpler if we negate the sense of the option to say that. Change disable_configs to enable_configs, with a default `None` value meaning "everything is enabled".

We change the relationship between enable_backends, disable_configs, enable_configs to be the following:
* `enable_backends` selects a set of initial test configurations to enable, based off backend only.
* `disable_configs` then prunes that set of test configurations, removing elements from the set.
* `enable_configs` then adds additional configurations to the set.

Fix code in jax/experimental/mosaic/gpu/examples not to depend on a Google-internal GPU support target.

PiperOrigin-RevId: 679563155
2024-09-27 06:15:31 -07:00
Justin Fu
9f4e8d0039 [XLA:Mosaic][Pallas] Enable vector.ExtractOp for non-zero indices.
PiperOrigin-RevId: 679283281
2024-09-26 13:57:45 -07:00
Jevin Jiang
e4ca4f5a57 Roll back cl/678765762 [Mosaic TPU] Support bitcast without forcing retiling.
Reverts 37641dd4fade625563321b7e1e87165df23cf4a8

PiperOrigin-RevId: 678881199
2024-09-25 16:02:58 -07:00
Jevin Jiang
37641dd4fa [Mosaic TPU] Support bitcast without forcing retiling.
PiperOrigin-RevId: 678765762
2024-09-25 10:57:09 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Jevin Jiang
407dc774f7 [Mosaic TPU] Support all cases for extui.
PiperOrigin-RevId: 678331795
2024-09-24 11:35:03 -07:00
jax authors
2c85465ebe Merge pull request #23806 from gspschmid:gschmid/ffi-ext-bundle
PiperOrigin-RevId: 678273475
2024-09-24 09:05:20 -07:00