1281 Commits

Author SHA1 Message Date
Ruturaj4
29a1cb766e [ROCM] add missing typename keyword to work with gcc 2024-09-23 14:42:01 -05:00
Jevin Jiang
6b93b35842 [Mosaic:TPU] Efficient relayout with internal scratch
We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling.

PiperOrigin-RevId: 676982957
2024-09-20 15:00:58 -07:00
Adam Paszke
99195ead83 [Mosaic TPU] Try reducing sublane tiling to support more vector.shape_casts
In particular, 32-bit values should now support all reshapes that do not modify the
last dimension.

PiperOrigin-RevId: 676855401
2024-09-20 08:36:22 -07:00
Dan Foreman-Mackey
bc80ecbbe4 Remove forward compatibility checks from cholesky_update lowering.
The forward compatibility window has ended and it should be safe to remove these checks.

PiperOrigin-RevId: 676853740
2024-09-20 08:32:25 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dan Foreman-Mackey
afaa3bf43c Port GPU kernels for SVD to the FFI.
Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel, and while it does support lowering without access to a GPU (no more descriptor!), it only supports dynamics shapes in the batch dimensions. There are two main technical challenges:

1. The main `gesvd` kernels in cuSolver/hipSolver only support matrices with shape `(m, n)` with `m >= n`. This means that we need to transpose the inputs and outputs as part of the lowering rule when `m < n`. (Note: we actually just use C layouts instead of Fortran layouts to implement this case.) While this could be handled in the kernel, this seemed like a lot of work for somewhat limited benefit, and it would probably have performance implications.

2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that we need logic in the lowering rule to handle transposition and slicing. This makes it hard to have the algorithm selection be a parameter to the kernel.

Another note: cuSolver has a 64-bit implementation of the SVD, and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCM support, and I have tested it manually. This was a feature request from https://github.com/jax-ml/jax/issues/23413.

PiperOrigin-RevId: 676839182
2024-09-20 07:34:50 -07:00
Jevin Jiang
47b177bd03 [Mosaic TPU][NFC] Remove FailureOr in getNativeVregOrVmaskTypeImpl
PiperOrigin-RevId: 676566796
2024-09-19 14:35:41 -07:00
Georg Stefan Schmid
d0338f5d13 [ffi] Support handler bundles in GPU plugin extension 2024-09-19 14:51:02 +00:00
Peter Hawkins
922e652c05 Replace plat-name with plat_name.
The former seems to elicit a deprecation warning from setuptools
recently.
2024-09-18 15:17:49 +00:00
jax authors
4e6f690724 Merge pull request #23653 from apaszke:torchsaic
PiperOrigin-RevId: 675967844
2024-09-18 06:35:15 -07:00
Adam Paszke
611ad63060 Add basic PyTorch integration for Mosaic GPU
We have already had most of the relevant pieces and we only needed
to connect them together. The most sensitive change is perhaps that
I needed to expose one more symbol from the XLA GPU plugin, but I don't
think it should be a problem.
2024-09-18 12:55:23 +00:00
Dan Foreman-Mackey
c61e49cd4a Simplify logic in jaxlib FFI_ASSIGN_OR_RETURN macro, and fix gcc build.
In https://github.com/google/jax/issues/23687, it was reported that recent jaxlib changes introduced issues when building from source using gcc, instead of the clang build that we test. I'm not 100% sure why the previous macro didn't work, but in investigating I found a version that seems to work on both clang and gcc with simpler logic.

PiperOrigin-RevId: 675641259
2024-09-17 11:23:25 -07:00
Jevin Jiang
8d93e101b9 [Mosaic TPU] Propagate the memory space change for memref bitcast and reshape.
PiperOrigin-RevId: 674067380
2024-09-12 17:14:41 -07:00
Jevin Jiang
178fb03050 [Mosaic TPU] Better error message when shape of memref bitcast is invalid.
PiperOrigin-RevId: 674062237
2024-09-12 16:56:50 -07:00
Dan Foreman-Mackey
a3bf75e442 Refactor gpusolver kernel definitions into separate build target.
There is a lot of boilerplate required for each new custom call to cuSolver / cuBLAS, and having both the FFI logic and the framework wrappers in the same library was getting unwieldy. This change adds a new "interface" target which just includes the shims to wrap cuSolver/BLAS functions, and then these are used from `solver_kernels_ffi` where the FFI logic lives.

PiperOrigin-RevId: 673832309
2024-09-12 07:11:36 -07:00
jax authors
79dabe530c Merge pull request #23462 from hawkinsp:mlir
PiperOrigin-RevId: 672188226
2024-09-07 21:04:56 -07:00
Dan Foreman-Mackey
7266e338c8 Update FFI target name for syrk operation to be consistent with other kernels.
PiperOrigin-RevId: 671870569
2024-09-06 13:21:38 -07:00
Dan Foreman-Mackey
1d12a9934c Port GPU kernel for symmetric eigendecomposition to GPU.
Of note, I moved the logic about which algorithm to use, and when to use the batched algorithm into the kernel in order to support shape polymorphism and export.

PiperOrigin-RevId: 671853879
2024-09-06 12:23:04 -07:00
jax authors
f97bfc85a3 Implement symmetric_product() to produce a symmetric matrix: C = alpha * X @ X.T + beta * C
PiperOrigin-RevId: 671845818
2024-09-06 11:58:20 -07:00
Sebastian Bodenstein
e3b8177af3 Internal change.
PiperOrigin-RevId: 671583042
2024-09-05 18:42:22 -07:00
Peter Hawkins
27e19239ca Fix triton capi_objects target to depend on MLIR CAPIIRObjects bazel
target.

"...Objects" targets should only depend on other "...Objects" targets in
MLIR land. Don't mix them.
2024-09-06 01:06:27 +00:00
Jevin Jiang
dba674153e [Mosaic TPU] Fix operands order in try canonicalize add of matmul.
PiperOrigin-RevId: 671437100
2024-09-05 11:06:57 -07:00
Adam Paszke
8feab68209 [Mosaic GPU] Remove the unnecessary scratch space operand
And clean up the C++ dispatch code. We don't use HBM scratch anymore
since we pass TMA descriptors as kernel arguments.

PiperOrigin-RevId: 671327420
2024-09-05 04:57:52 -07:00
Sergei Lebedev
f3b91b2042 Export PointerType and register_dialect from jaxlib.triton.dialect
The `... as ...` form tells the type checker that the name is exported.
See #7570.

PiperOrigin-RevId: 671318047
2024-09-05 04:15:32 -07:00
Paweł Paruzel
2082662bb1 Port Hessenberg Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 671283487
2024-09-05 01:59:32 -07:00
Jevin Jiang
c1d3c2db9f [Mosaic TPU] Fix mosaic alignment check in concatenate rule.
PiperOrigin-RevId: 670837792
2024-09-03 22:57:27 -07:00
Paweł Paruzel
414eb90f5b Activate Householder Product to XLA's FFI
PiperOrigin-RevId: 670196460
2024-09-02 06:19:01 -07:00
Peter Hawkins
1ab3119d43 Add some msan suppressions to the LAPACK symmetric eigendecomposition FFI call.
This fixes some msan false positives in our CI, since we do not msan-instrument Fortran code.

PiperOrigin-RevId: 669385248
2024-08-30 11:12:45 -07:00
Paweł Paruzel
4342c0c0f3 Determine LAPACK workspace during Householder Product Kernel runtime
Workspace dependency was removed, and the info parameter is ignored now.

PiperOrigin-RevId: 669246058
2024-08-30 02:06:16 -07:00
Sergei Lebedev
7dd9adba05 Fixed stack-use-after-scope in Mosaic GPU
PiperOrigin-RevId: 668958750
2024-08-29 09:07:58 -07:00
Jevin Jiang
a3cccd34e2 [Mosaic TPU] Print expected Mosaic version after finding unsupported version.
PiperOrigin-RevId: 668632116
2024-08-28 15:33:31 -07:00
Jevin Jiang
b01075054a [Mosaic TPU] Support memref bitcast.
If element bitwidth changes, the ratio of bitwidth is multiplied to the 2nd minormost dim size and the leading dim in tiling. For example, we can bitcast Memref<8x128xf32> with tiling (8, 128) to Memref<16x128xi16> with tiling (16, 128).

PiperOrigin-RevId: 668619683
2024-08-28 15:00:46 -07:00
Kevin Gleason
78d5b75b0d Trim StableHLO python binding dependencies
With proper CAPI in place these dependencies are no longer needed, llvm support needed for string ostream for string APIs.

PiperOrigin-RevId: 668476145
2024-08-28 09:01:15 -07:00
Paweł Paruzel
3c6103f2df Activate Eigenvalue Decompositions to XLA's FFI
Two eigenvalue decomposition methods. One is intended for non-symmetric matrices - GEEV (General Eigenvalue Solver) - and the other for Symmetric or Hermitian matrices - SYEVD/HEEVD.

PiperOrigin-RevId: 668381949
2024-08-28 03:53:49 -07:00
Peter Hawkins
45b871950e Fix a number of minor problems in the ROCM build.
Change in preparation for adding more presubmits for AMD ROCM.

PiperOrigin-RevId: 667766343
2024-08-26 17:04:01 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Kevin Gleason
d72104de59 Use StableHLO filegroup for python APIs in jaxlib MLIR build.
PiperOrigin-RevId: 666450684
2024-08-22 12:36:39 -07:00
Adam Paszke
9c3f2dcefc [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest
XLA runtime creates a context per device, so we need to make sure that a kernel is loaded
separately on each device.

PiperOrigin-RevId: 666353098
2024-08-22 08:06:37 -07:00
Dan Foreman-Mackey
b56ed8eedd Port GPU kernel for Householder transformation to FFI.
PiperOrigin-RevId: 666305682
2024-08-22 05:23:09 -07:00
Paweł Paruzel
4786930a4c Determine LAPACK workspace during Eigenvalue Kernels runtime
PiperOrigin-RevId: 666285759
2024-08-22 04:09:34 -07:00
Paweł Paruzel
a72d46c549 Ignore LAPACK info parameter for QR Factorization
The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call.

PiperOrigin-RevId: 666241215
2024-08-22 01:38:38 -07:00
Krishna Haridasan
3713b966c2 Fix a potential segfault in triton kernel call caching
It is possible that a null pointer is inserted into the cache and not updated with a valid kernel call
in case there is an error later during initialization. This change updates the cache to store either
an error or a valid kernel call.

PiperOrigin-RevId: 666161091
2024-08-21 20:45:35 -07:00
Benjamin Kramer
0105254ab1 Unbreak Mosaic after 42944da5ba
PiperOrigin-RevId: 665973530
2024-08-21 11:59:09 -07:00
jax authors
6a2a96c3b8 Merge pull request #23166 from ROCm:ci_bazel_build
PiperOrigin-RevId: 665935164
2024-08-21 10:27:51 -07:00
Ruturaj4
c41d644886 [ROCm] Fix bazel build issue 2024-08-21 08:40:54 -05:00
Dan Foreman-Mackey
e51848ea3d Activate GPU kernel for LU decomposition.
This adds support for shape polymorphism and export for this custom call, and adds the appropriate tests.

One of the biggest changes here is to move all the lowing logic for the getrf call into jax (lax/linalg.py) instead of in jaxlib (gpu_solver.py and lapack.py) since the lowering code is now identical for CPU and GPU (the only difference is the handler names).

PiperOrigin-RevId: 665829252
2024-08-21 05:08:41 -07:00
jax authors
607ee3eea1 Merge pull request #23143 from vfdev-5:more-nogil-mlir-dialects
PiperOrigin-RevId: 665464355
2024-08-20 11:58:28 -07:00
vfdev-5
da77b710b8 Added py::mod_gil_not_used() to PYBIND11_MODULE for _triton_ext and _tpu_ext
Description:
- Added `py::mod_gil_not_used()` to `PYBIND11_MODULE` for `_triton_ext` and `_tpu_ext`.

Refs:
- https://py-free-threading.github.io/porting/#__tabbed_1_2

Context:
- https://github.com/google/jax/issues/23073
2024-08-20 15:08:36 +02:00
Dan Foreman-Mackey
bd90968a25 Port the GPU Cholesky update custom call to the FFI.
PiperOrigin-RevId: 665319689
2024-08-20 05:46:03 -07:00