1056 Commits

Author SHA1 Message Date
Jevin Jiang
c1d3c2db9f [Mosaic TPU] Fix mosaic alignment check in concatenate rule.
PiperOrigin-RevId: 670837792
2024-09-03 22:57:27 -07:00
Paweł Paruzel
414eb90f5b Activate Householder Product to XLA's FFI
PiperOrigin-RevId: 670196460
2024-09-02 06:19:01 -07:00
Peter Hawkins
1ab3119d43 Add some msan suppressions to the LAPACK symmetric eigendecomposition FFI call.
This fixes some msan false positives in our CI, since we do not msan-instrument Fortran code.

PiperOrigin-RevId: 669385248
2024-08-30 11:12:45 -07:00
Paweł Paruzel
4342c0c0f3 Determine LAPACK workspace during Householder Product Kernel runtime
Workspace dependency was removed, and the info parameter is ignored now.

PiperOrigin-RevId: 669246058
2024-08-30 02:06:16 -07:00
Sergei Lebedev
7dd9adba05 Fixed stack-use-after-scope in Mosaic GPU
PiperOrigin-RevId: 668958750
2024-08-29 09:07:58 -07:00
Jevin Jiang
a3cccd34e2 [Mosaic TPU] Print expected Mosaic version after finding unsupported version.
PiperOrigin-RevId: 668632116
2024-08-28 15:33:31 -07:00
Jevin Jiang
b01075054a [Mosaic TPU] Support memref bitcast.
If element bitwidth changes, the ratio of bitwidth is multiplied to the 2nd minormost dim size and the leading dim in tiling. For example, we can bitcast Memref<8x128xf32> with tiling (8, 128) to Memref<16x128xi16> with tiling (16, 128).

PiperOrigin-RevId: 668619683
2024-08-28 15:00:46 -07:00
Kevin Gleason
78d5b75b0d Trim StableHLO python binding dependencies
With proper CAPI in place these dependencies are no longer needed, llvm support needed for string ostream for string APIs.

PiperOrigin-RevId: 668476145
2024-08-28 09:01:15 -07:00
Paweł Paruzel
3c6103f2df Activate Eigenvalue Decompositions to XLA's FFI
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.

PiperOrigin-RevId: 668381949
2024-08-28 03:53:49 -07:00
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCM build.
Change in preparation for adding more presubmits for AMD ROCM.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Kevin Gleason
d72104de59 Use StableHLO filegroup for python APIs in jaxlib MLIR build.
PiperOrigin-RevId: 666450684
2024-08-22 12:36:39 -07:00
Adam Paszke
9c3f2dcefc [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest
XLA runtime creates a context per device, so we need to make sure that a kernel is loaded
separately on each device.

PiperOrigin-RevId: 666353098
2024-08-22 08:06:37 -07:00
Dan Foreman-Mackey
b56ed8eedd Port GPU kernel for Householder transformation to FFI.
PiperOrigin-RevId: 666305682
2024-08-22 05:23:09 -07:00
Paweł Paruzel
4786930a4c Determine LAPACK workspace during Eigenvalue Kernels runtime
PiperOrigin-RevId: 666285759
2024-08-22 04:09:34 -07:00
Paweł Paruzel
a72d46c549 Ignore LAPACK info parameter for QR Factorization
The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call.

PiperOrigin-RevId: 666241215
2024-08-22 01:38:38 -07:00
Krishna Haridasan
3713b966c2 Fix a potential segfault in triton kernel call caching
It is possible that a null pointer is inserted into the cache and not updated with a valid kernel call
in case there is an error later during initialization. This change updates the cache to store either
an error or a valid kernel call.

PiperOrigin-RevId: 666161091
2024-08-21 20:45:35 -07:00
Benjamin Kramer
0105254ab1 Unbreak Mosaic after 42944da5ba
PiperOrigin-RevId: 665973530
2024-08-21 11:59:09 -07:00
jax authors
6a2a96c3b8 Merge pull request #23166 from ROCm:ci_bazel_build
PiperOrigin-RevId: 665935164
2024-08-21 10:27:51 -07:00
Ruturaj4
c41d644886 [ROCm] Fix bazel build issue 2024-08-21 08:40:54 -05:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
jax authors
607ee3eea1 Merge pull request #23143 from vfdev-5:more-nogil-mlir-dialects
PiperOrigin-RevId: 665464355
2024-08-20 11:58:28 -07:00
vfdev-5
da77b710b8 Added py::mod_gil_not_used() to PYBIND11_MODULE for _triton_ext and _tpu_ext
Description:
- Added `py::mod_gil_not_used()` to `PYBIND11_MODULE` for `_triton_ext` and `_tpu_ext`.

Refs:
- https://py-free-threading.github.io/porting/#__tabbed_1_2

Context:
- https://github.com/google/jax/issues/23073
2024-08-20 15:08:36 +02:00
Dan Foreman-Mackey
bd90968a25 Port the GPU Cholesky update custom call to the FFI.
PiperOrigin-RevId: 665319689
2024-08-20 05:46:03 -07:00
Dan Foreman-Mackey
71a93d0c87 Port QR factorization GPU kernel to FFI.
The biggest change here is that we now ignore the `info` parameter that is returned by `getrf`. In the previous implementation, we would return an error in the batched implementation, or set the relevant matrix entries to NaN in the non-batched version if `info != 0`. But, since info is only used for shape checking (see LAPACK, cuBLAS and cuSolver docs), I argue that we will never see `info != 0`, because we're including all the shape checks in the kernel already.

PiperOrigin-RevId: 665307128
2024-08-20 05:07:04 -07:00
vfdev-5
6546c4810b Added PyUnstable_Module_SetGIL to PyInit_cpu_feature_guard 2024-08-20 02:32:37 +02:00
vfdev-5
b1b3ea276b Added py::mod_gil_not_used() to PYBIND11_MODULE register_jax_dialects 2024-08-20 00:03:56 +02:00
Dan Foreman-Mackey
30d54ec6ff Refactor FFI shape inference functions to include dimension check.
Previously we always had two steps when extracting the batch size: (1) check the buffer has enough dimensions, (2) get the shape. And, in a few cases, this first check was missing. Now these steps are combined into one function that returns a StatusOr.

As part of this, I needed to fix our implementation of the `ASSIGN_OR_RETURN` macro to properly handle parentheses.

PiperOrigin-RevId: 664803225
2024-08-19 07:41:28 -07:00
Dan Foreman-Mackey
dad2f576ac Add support for shape polymorphism in ffi_lowering and move lu_pivots_to_permutation lowering out of jaxlib.
The lowering logic for all jaxlib custom calls are currently split between JAX and jaxlib for reasons that are harder to justify now that the compiled calls are split between jaxlib and the relevant plugins. As part of my project to update these calls and simplify the lowering logic, it makes sense to consolidate the lowering rules in JAX instead of jaxlib since the logic is now the same for both GPU and CPU. This update tackles a simple kernel as a test case for what this would look like.

Since the full lowering rule is now implemented in JAX, we can take advantage of the MLIR helpers that are included there, including `jex.ffi.ffi_lowering`, which I needed to update to support shape polymorphism.

Of note: I think it is safe (in a compatibility sense) to delete the lowering code from jaxlib, but it does mean that it won't be possible to lower this operation when `jax.__version__ < jaxlib.__version__`. I think this is okay given our compatibility guarantees, but I'd love a sanity check on that!

Another note, this doesn't actually change the lowered HLO for this op, so we don't need to worry about export compatibility.

PiperOrigin-RevId: 664680250
2024-08-19 01:05:31 -07:00
Dan Foreman-Mackey
b6306e3953 Remove synchronization from GPU LU decomposition kernel by adding an async batch pointers builder.
In the batched LU decomposition in cuBLAS, the output buffer is required to be a pointer of pointers to the appropriate batch matrices. Previously this reshaping was done on the host and then copied to the device, requiring a synchronization, but it seems straightforward to instead implement a tiny CUDA kernel to do this work. This definitely isn't a bottleneck or a high priority change, but this seemed like a reasonable time to fix a longstanding TODO.

PiperOrigin-RevId: 663686539
2024-08-16 04:37:09 -07:00
Paweł Paruzel
acacf8884e Determine LAPACK workspace during QR Factorization Kernel runtime
PiperOrigin-RevId: 663641199
2024-08-16 01:20:50 -07:00
Tomás Longeri
020513f300 [Mosaic] Update serde to handle upstream MLIR changes
For changes from
5f26497da7

PiperOrigin-RevId: 663020509
2024-08-14 12:48:29 -07:00
Dan Foreman-Mackey
ad1bd38790 Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.

PiperOrigin-RevId: 662945116
2024-08-14 09:20:40 -07:00
jax authors
807dcb5a06 Integrate LLVM at llvm/llvm-project@c8b5d30f70
Updates LLVM usage to match
[c8b5d30f7077](https://github.com/llvm/llvm-project/commit/c8b5d30f7077)

PiperOrigin-RevId: 662906261
2024-08-14 07:09:53 -07:00
Jevin Jiang
8f23392a8c [Mosaic:TPU] Refactor relayout helper functions to take ctx instead of only target shape.
PiperOrigin-RevId: 662672417
2024-08-13 15:22:46 -07:00
Jevin Jiang
2dea3d6a0c [Mosaic:TPU] Add shuffled load and store.
we also emulate shuffled store using (store + shuffled load + store) for previous generations.

PiperOrigin-RevId: 662657663
2024-08-13 14:41:16 -07:00
Paweł Paruzel
354293da48 Activate Singular Value Decomposition to XLA's FFI
PiperOrigin-RevId: 662436635
2024-08-13 02:41:57 -07:00
Paweł Paruzel
5fc992e5e1 Determine LAPACK workspaces during SVD kernel runtime
The SVD kernel implementation used to require workspace shapes to be determined prior to the custom call on the JAX's side. The new FFI kernels need not demand these shapes to be specified anymore. They are evaluated during kernel runtime.

PiperOrigin-RevId: 662413273
2024-08-13 01:17:44 -07:00
Dan Foreman-Mackey
3c014a4c27 Add support for shape polymorphism with lu_pivots_to_permutation.
This is needed to land support for shape polymorphism with LU decomposition more generally. Most of this change just involves adding the appropriate tests, but I've also updated the "generic" implementation which is used for lowering on CPU to support a dynamic trailing dimension in the input (the `fori_loop` will conditionally lower to a `scan` or `while_loop` as necessary). This change doesn't affect the differentiability (this op doesn't support AD) and the behavior won't change when static shapes are used.

PiperOrigin-RevId: 662024940
2024-08-12 03:39:54 -07:00
jax authors
be4d52b814 Merge pull request #22667 from ROCm:rocm-jax-triton-add-get_arch_detail
PiperOrigin-RevId: 662007143
2024-08-12 02:30:49 -07:00
Rahul Batra
4b7c198a1c [ROCm]: Add get_arch_details for triton kernel call 2024-08-12 06:16:27 +00:00
Tomás Longeri
77afe251e7 [Mosaic TPU][Python] Check validity of VectorLayout on init
PiperOrigin-RevId: 661226283
2024-08-09 05:28:00 -07:00
Tomás Longeri
e57a7e3f05 [Mosaic] Column shift relayouts for non-native tilings and packed types, except for (1, n) and packed
PiperOrigin-RevId: 661091012
2024-08-08 20:14:08 -07:00
Jieying Luo
751b5742fd Deprecate using build_cuda_plugin_from_source flag and rely on jaxlib_build config.
If jaxlib needs to be built from source, cuda plugin will be built from source as well.

PiperOrigin-RevId: 660926791
2024-08-08 11:58:13 -07:00
Dan Foreman-Mackey
11d9c2de2c Update GPU implementation of lu_pivots_to_permutation to infer the permutation size directly from the input dimensions, instead of using an input parameter.
I have left an `Attrs` annotation on the FFI binding to support backwards compatibility (this accepts, but ignores, and input `permuatation_size` parameter), but I'm not sure we strictly need that since this op doesn't support exporting anyways.

In anticipation of supporting shape polymorphism I added dimension checks to the kernel to match the ones in the abstract eval.

PiperOrigin-RevId: 660831000
2024-08-08 07:35:47 -07:00
Adam Paszke
04a753ad02 [Mosaic TPU] Improve an error message in case someone tries to extract a non-32-bit scalar.
PiperOrigin-RevId: 660826696
2024-08-08 07:22:10 -07:00
Adam Paszke
42fe45f34b [Mosaic TPU] Add support for removal of implicit 2nd minor for all 32-bit tilings
PiperOrigin-RevId: 660724215
2024-08-08 01:00:32 -07:00
jax authors
de02988e94 Merge pull request #22909 from ROCm:ci_fix_solver_paths
PiperOrigin-RevId: 660515208
2024-08-07 13:26:17 -07:00
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00