952 Commits

Author SHA1 Message Date
Eugene Zhulenev
d49a0c5a63 [jax] Remove dead code from JAX custom calls defined as FFI handlers
PiperOrigin-RevId: 651025363
2024-07-10 08:11:12 -07:00
Eugene Zhulenev
1e03917c43 [xla:ffi] Use lazy decoding for Buffer<dtype,rank>
name                old cpu/op   new cpu/op   delta
BM_AnyBufferArgX1   11.0ns ± 3%  11.2ns ±10%   +1.76%  (p=0.000 n=67+69)
BM_AnyBufferArgX4   12.4ns ± 3%  12.4ns ± 4%   -0.31%  (p=0.006 n=69+69)
BM_BufferArgX1      12.5ns ± 1%  11.1ns ± 4%  -11.20%  (p=0.000 n=62+76)
BM_BufferArgX4      19.1ns ± 1%  14.4ns ± 4%  -24.84%  (p=0.000 n=64+73)
BM_BufferArgX8      36.0ns ± 5%  20.3ns ± 4%  -43.59%  (p=0.000 n=79+75)
BM_TupleOfI32Attrs  66.4ns ± 1%  66.4ns ± 2%   -0.03%  (p=0.000 n=66+72)

PiperOrigin-RevId: 650691450
2024-07-09 11:07:25 -07:00
Justin Fu
0cb82cea65 [Pallas] Add better reduction support.
Adds lowering rules for reduce_all, reduce_any, reduce_min, and reductions to scalars.

PiperOrigin-RevId: 650689871
2024-07-09 11:03:17 -07:00
Paweł Paruzel
4e1a66ea21 Avoid throwing exceptions in LAPACK kernel code
PiperOrigin-RevId: 650569943
2024-07-09 03:57:50 -07:00
jax authors
0da9b69285 Use default tiling in scratch buffers if XLA enables it
PiperOrigin-RevId: 650493683
2024-07-08 22:49:10 -07:00
Tomás Longeri
5c7c29bc6e [Mosaic] Remove restriction of offsets falling in first tile of vreg, start rolling out op support for it, starting with vector.extract_strided_slice
VectorLayout offsets are now allowed to fall anywhere within the vreg slice. This way, tiling is still applied after offsets and offsets are still applied after implicit dimensions.
Note that offsets outside of the vreg slice would mean a vreg full of padding, which is why we disallow them.

PiperOrigin-RevId: 650408597
2024-07-08 16:23:10 -07:00
Paweł Paruzel
532be68461 Port Singular Value Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 650212574
2024-07-08 05:19:53 -07:00
jax authors
2561ba5d37 Introduce a canonicalize pass, rewrite all contractions as matmuls (vector::ContractionOp as tpu::MatMulOp), remove special handling for contraction op in other passes.
PiperOrigin-RevId: 649205635
2024-07-03 14:49:35 -07:00
Vadym Matsishevskyi
f089ecc47a Fix gpu_jax_head_jaxlib_pypi_latest job after migrating to plugin structure for jaxlib dependency
PiperOrigin-RevId: 648863763
2024-07-02 15:35:32 -07:00
Jevin Jiang
484d09f4af [Pallas][Mosaic] Relax dynamic index on 2nd minor dim in load/store.
We support any dynamic index on 2nd minor dim in either of the cases:
1. The minormost dim size of a unsliced memref matches VREG lane count.
2. Load/store one row on the second minormost dim, which triggers implicit strided load/store.

Note: For the default cases which can not skip the alignment check, we still use dynamic slice + static load/store solution to reduce scalar core work. We should figure out a way to optimize this in all cases.
PiperOrigin-RevId: 648771794
2024-07-02 10:52:11 -07:00
jax authors
cad751fc6d Merge pull request #22233 from ROCm:ci_remove_mosaic_dep
PiperOrigin-RevId: 648751831
2024-07-02 09:58:09 -07:00
jax authors
fcaeea4876 Merge pull request #22232 from ROCm:ci_typed_xla_ffi
PiperOrigin-RevId: 648746459
2024-07-02 09:38:10 -07:00
jax authors
65e237ed9d Merge pull request #22231 from ROCm:ci_absl_status
PiperOrigin-RevId: 648744927
2024-07-02 09:34:44 -07:00
Adam Paszke
265a54da31 [Mosaic GPU] Pass in TMA descriptors through kernel parameters
As we've established (sigh) we can't pass in TMA descriptors through global memory.
The current workaround was to use constant memory instead, but this raises a number of
potential concurrency issues. So, instead, we use the freshly added support for grid_constant
parameters in upstream LLVM to pass the descriptors as kernel arguments. This seems to work
fine and should in fact have lower overheads than both previous methods.

PiperOrigin-RevId: 648744363
2024-07-02 09:30:52 -07:00
Ruturaj4
332435e028 [ROCM] make mosaic dependency cuda specific 2024-07-02 11:05:42 -05:00
Ruturaj4
58b658cfb8 [ROCM] add typed XLA FFI support in rocm specific code 2024-07-02 11:04:43 -05:00
Ruturaj4
4b936233a6 [ROCM] Replace xla::Status with absl::Status in rocm_plugin_extension.cc 2024-07-02 10:59:53 -05:00
George Necula
2f808e9da9 Fix error in custom call registration for some FFI functions
We are getting the following errors:
```
Duplicate FFI handler registration for cu_threefry2x32_ffi on a platform CUDA
Duplicate FFI handler registration for cu_lu_pivots_to_permutation on a platform CUDA
```

It seems that with the ffi registration mechanism based on `XLA_FFI_REGISTER_HANDLER` it is not possible anymore to
register a call target twice.

The fix here is to rollback the changes in https://github.com/google/jax/pull/22178
and disable the changes from https://github.com/google/jax/pull/20997.

PiperOrigin-RevId: 647993991
2024-06-29 12:18:34 -07:00
Dan Foreman-Mackey
9b33df6438 Update C++ registration of cu_lu_pivots_to_permutation to use XLA_FFI_REGISTER_HANDLER
PiperOrigin-RevId: 647734115
2024-06-28 10:53:33 -07:00
George Necula
cbe524298c Ported threefry2x32 for GPU to the typed XLA FFI
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.

For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.

PiperOrigin-RevId: 647657084
2024-06-28 06:24:44 -07:00
Tomás Longeri
3a21c81eac [Mosaic] Fix bug in VectorLayout::generalizes after cl/647395486
PiperOrigin-RevId: 647603239
2024-06-28 02:14:04 -07:00
Tomás Longeri
10e598a3fc [Mosaic] In VectorLayout::generalizes, for (1, n) tiling, we can always squeeze out a 2nd minor dimension
PiperOrigin-RevId: 647395486
2024-06-27 11:52:02 -07:00
Dan Foreman-Mackey
98b87540a7 Avoid throwing exceptions in LAPACK CPU kernels.
When an FFI kernel is executed, there isn't any global try/except block (I think!) so it's probably a good idea to avoid throwing.
Instead, it should be safer to handle mapping failures to ffi::Error manually.

PiperOrigin-RevId: 647348889
2024-06-27 09:41:07 -07:00
Dan Foreman-Mackey
9ae1c56c44 Update lu_pivots_to_permutation to use FFI dimensions on GPU.
The XLA FFI interface provides metadata about buffer dimensions, so quantities
like batch dimensions can be evaluated on the backend, instead of passed as
attributes. This change has the added benefit of allowing this FFI call to
support "vectorized" vmap and dynamic shapes.

PiperOrigin-RevId: 647343656
2024-06-27 09:27:15 -07:00
Christos Perivolaropoulos
ea49194926 [msoaic_gpu] Control dumping mlir with MOSAIC_GPU_DUMP_MLIR_PASSES
PiperOrigin-RevId: 647341364
2024-06-27 09:17:52 -07:00
jax authors
00528b9858 libdevice.10.bc is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
2024-06-27 08:35:59 -07:00
jax authors
9df105c18f Pass the assigned layout to infer_memref_layout for correct memref
layout.

PiperOrigin-RevId: 647323218
2024-06-27 08:16:00 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
jax authors
96cf5d53c8 Merge pull request #21916 from ROCm:ci_pjrt
PiperOrigin-RevId: 646793145
2024-06-26 02:43:21 -07:00
Tomás Longeri
94c5d0d747 [Mosaic][apply-vector-layout] Fix possible segfault in arith.extsi/arith.extf after cl/644495447
This only happens for layout pairs that are never inferred.

PiperOrigin-RevId: 646303509
2024-06-24 19:54:08 -07:00
Tomás Longeri
21bf3d196d [Mosaic][Python] Define __repr__ for VectorLayout
Loosely follows the example MLIR's bindings for Attribute

PiperOrigin-RevId: 646270865
2024-06-24 17:18:15 -07:00
Tomás Longeri
097806a033 [Mosaic][Python] Include error message in exceptions
PiperOrigin-RevId: 646259787
2024-06-24 16:36:26 -07:00
Peter Hawkins
7f24837eef Update minimum NumPy version to v1.24. 2024-06-21 15:17:17 -07:00
Tomás Longeri
a730f6bfd3 [Mosaic][infer-vector-layout] Allow non-32-bit types for vector.extract_strided_slice
PiperOrigin-RevId: 645481424
2024-06-21 13:17:37 -07:00
Kyle Lucke
84d748f43c Stop using xla/statusor.h now that it just contains an alias for absl::Status.
In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free.

PiperOrigin-RevId: 645169743
2024-06-20 15:09:40 -07:00
Kyle Lucke
80de35514c Replace xla::Status with absl::Status in cuda_plugin_extension.cc.
PiperOrigin-RevId: 645144526
2024-06-20 13:51:03 -07:00
Paweł Paruzel
63aab133f1 Port LU Decomposition to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 644845277
2024-06-19 17:31:25 -07:00
Chris Jones
de8fd3b00d [mosaic:gpu] Fix MLIR canonicalization pass region-simplify option.
`region-simplify` now has `normal` and `aggressive` modes (using `normal` for now).

PiperOrigin-RevId: 644724434
2024-06-19 06:02:11 -07:00
Jevin Jiang
cac1791f7c [XLA:Mosaic] Support dynamic roll
We will choose the best solution based on the size of internal scratch memory.
- Sol 1: Convert dynamic roll to Log(N) static ops
- Sol 2: Static Store + Dynamic Load with internal scratch

PiperOrigin-RevId: 644509328
2024-06-18 14:18:56 -07:00
Jevin Jiang
c180b86bbd [XLA:Mosaic] Fix ext rule with large native tile.
PiperOrigin-RevId: 644495447
2024-06-18 13:34:40 -07:00
Ruturaj4
a00d030248 [ROCM] nits and fixes 2024-06-18 20:21:23 +00:00
Jevin Jiang
ed4958cb3e [XLA:Mosaic] Add internal scratch VMEM
- Make internal scratch size configurable.
- Pass the number of max sublanes allowed in scratch to apply-vector-layout pass.
- Create a helper function to fetch internal scratch VMEM address.

PiperOrigin-RevId: 644184896
2024-06-17 17:31:31 -07:00
Kyle Lucke
ebdafea9c8 Stop using xla/status.h, xla:status, and xla::Status now that xla::Status is just an alias for an absl::Status
PiperOrigin-RevId: 644063768
2024-06-17 10:51:55 -07:00
jax authors
f86cd6de56 Rewrite vector.multi_dim_reduction with bf16 source/accumulator/output into
a multi_dim_reduction with f32 source/accumulator/output, where the source
and accumulator are extended and the result is truncated. This addressed 'only
32-bit reductions supported' error.

PiperOrigin-RevId: 644062786
2024-06-17 10:51:24 -07:00
Ruturaj4
99c2b7b4e9 [ROCm] Bring-up pjrt support 2024-06-17 16:49:22 +00:00
Adam Paszke
4ea73bf787 Use constant memory to pass in TMA descriptors to the kernel
To work around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discurages this, because it can lead to
incorrrect results. Which is also what we've sometimes observed as
a cache coherency issue unless a TMA fence is explicitly inserted at the
beginning of the kernel.

Note that this approach has a big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
upstreaming LLVM changes...) and we won't have to deal with this again.

For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays

PiperOrigin-RevId: 643972675
2024-06-17 05:31:26 -07:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
Paweł Paruzel
3d39b6e752 Port Cholesky Factorization to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 642954763
2024-06-13 05:44:36 -07:00
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00