909 Commits

Author SHA1 Message Date
Jevin Jiang
ed4958cb3e [XLA:Mosaic] Add internal scratch VMEM
- Make internal scratch size configurable.
- Pass the number of max sublanes allowed in scratch to apply-vector-layout pass.
- Create a helper function to fetch internal scratch VMEM address.

PiperOrigin-RevId: 644184896
2024-06-17 17:31:31 -07:00
Kyle Lucke
ebdafea9c8 Stop using xla/status.h, xla:status, and xla::Status now that xla::Status is just an alias for an absl::Status
PiperOrigin-RevId: 644063768
2024-06-17 10:51:55 -07:00
jax authors
f86cd6de56 Rewrite vector.multi_dim_reduction with bf16 source/accumulator/output into
a multi_dim_reduction with f32 source/accumulator/output, where the source
and accumulator are extended and the result is truncated. This addresses the 'only
32-bit reductions supported' error.

PiperOrigin-RevId: 644062786
2024-06-17 10:51:24 -07:00
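
The extend / reduce / truncate pattern described in the commit above can be sketched in plain Python. This is only an illustration of the numerics, not the Mosaic rewrite itself: the helper names (`to_f32`, `to_bf16`, `reduce_sum_bf16_naive`, `reduce_sum_via_f32`) are hypothetical, bf16 is emulated by rounding float32 bit patterns, and NaN/inf inputs are not handled.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (f64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (ties to even).

    Emulation only: keeps the top 16 bits of the float32 pattern.
    NaN and infinity are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even bias
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def reduce_sum_bf16_naive(values, acc=0.0):
    """Sum with a bf16 accumulator: every partial sum is rounded to bf16."""
    acc = to_bf16(acc)
    for v in values:
        acc = to_bf16(acc + v)
    return acc


def reduce_sum_via_f32(values, acc=0.0):
    """Extend source and accumulator to f32, reduce, then truncate once."""
    wide_acc = to_f32(acc)                       # extend the accumulator
    for v in values:
        wide_acc = to_f32(wide_acc + to_f32(v))  # extend the source, add in f32
    return to_bf16(wide_acc)                     # truncate the result to bf16


values = [256.0] + [1.0] * 256
print(reduce_sum_bf16_naive(values))  # 256.0: each +1 is lost to bf16 rounding
print(reduce_sum_via_f32(values))     # 512.0
```

Near 256, adjacent bf16 values are 2 apart, so a bf16 accumulator drops every `+1`. The widened accumulator keeps them and rounds only once at the end.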
Adam Paszke
4ea73bf787 Use constant memory to pass in TMA descriptors to the kernel
This works around another buggy part of the PTX documentation. While PTX
explicitly says that TMA descriptors can be in global memory, the C++
programming guide heavily discourages this, because it can lead to
incorrect results. That matches what we've sometimes observed: a cache
coherency issue unless a TMA fence is explicitly inserted at the
beginning of the kernel.

Note that this approach has a big downside of making the kernel unsafe
for concurrent use. I don't think that XLA:GPU will ever dispatch it
concurrently so I didn't insert any extra synchronization for now, but
we should seriously consider it. My hope at the moment is that we'll
be able to start passing in TMA descs as kernel args soon (pending
upstreaming LLVM changes...) and we won't have to deal with this again.

For the programming guide, see: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#using-tma-to-transfer-multi-dimensional-arrays

PiperOrigin-RevId: 643972675
2024-06-17 05:31:26 -07:00
Peter Hawkins
b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package, use it in jax's setup.py. This allows us to specify cuda extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00
Paweł Paruzel
3d39b6e752 Port Cholesky Factorization to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 642954763
2024-06-13 05:44:36 -07:00
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
Jieying Luo
ad9f35ae53 [PJRT:PLUGIN] Support both string and bytes as the input type of function name for register_custom_call_target in jax-cuda-plugin.
PiperOrigin-RevId: 642639867
2024-06-12 09:30:57 -07:00
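
Accepting both `str` and `bytes` for a name usually comes down to normalizing once at the API boundary. A minimal sketch follows, assuming the name is normalized to `bytes` with UTF-8. The helper `normalize_target_name` is hypothetical and is not the plugin's actual code, and the commit does not say which direction the conversion goes.

```python
from typing import Union


def normalize_target_name(name: Union[str, bytes]) -> bytes:
    """Return the custom-call target name as bytes.

    Hypothetical helper: assumes UTF-8 and a bytes-based downstream API.
    """
    if isinstance(name, str):
        return name.encode("utf-8")
    if isinstance(name, bytes):
        return name
    raise TypeError(f"name must be str or bytes, got {type(name).__name__}")


print(normalize_target_name("my_kernel"))
print(normalize_target_name(b"my_kernel"))
```

Rejecting other types explicitly keeps a bad argument from surfacing later as an opaque error in the native layer.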
jax authors
a0e5e0f411 Integrate LLVM at llvm/llvm-project@c012e487b7
Updates LLVM usage to match
[c012e487b724](https://github.com/llvm/llvm-project/commit/c012e487b724)

PiperOrigin-RevId: 642581785
2024-06-12 05:11:10 -07:00
Tomás Longeri
3e1e98992c [Mosaic] Handle adding singleton minor dimension that was already implicit for non-32-bit types, and do not force native tiling
Also fix an extra comma in apply_vector_layout_test that kept tripping up the autoformatter

PiperOrigin-RevId: 642454594
2024-06-11 18:10:26 -07:00
jax authors
d20b9e324f Integrate LLVM at llvm/llvm-project@8c5d9c79b9
Updates LLVM usage to match
[8c5d9c79b96e](https://github.com/llvm/llvm-project/commit/8c5d9c79b96e)

PiperOrigin-RevId: 642352474
2024-06-11 12:24:43 -07:00
Jevin Jiang
5b38549810 [XLA:Mosaic] No need to assume a multiple of tile if tile dim size is 1.
PiperOrigin-RevId: 642301822
2024-06-11 09:53:13 -07:00
Adam Paszke
1256ceb266 [Mosaic GPU] Rearrange the pass pipeline (again)
PiperOrigin-RevId: 642256145
2024-06-11 06:59:50 -07:00
jax authors
71c19b779d Rewrite vector.contraction with bf16 accumulator and output into a
contraction with f32 accumulator and output, where the accumulator is
extended and the output truncated. For targets that do not support bf16
matmul, the lhs and rhs are extended to f32.

PiperOrigin-RevId: 642051952
2024-06-10 16:02:46 -07:00
Jevin Jiang
53daa0c742 [XLA:Mosaic] Fix infer layout for nested loop.
- We should recursively clear layouts and any assume_layout ops if we want to override layouts in a block.
- Refactor the logic of assume layouts for block arguments to a helper function.
- Add tests for nested fori loop and while loop.

PiperOrigin-RevId: 641973011
2024-06-10 11:49:01 -07:00
Adam Paszke
0739d520b1 [Mosaic GPU] Don't always run with llvm::DebugFlag enabled
This slipped past during code review.

PiperOrigin-RevId: 641899993
2024-06-10 07:50:26 -07:00
Thomas Köppe
cd93b46df4 Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.
PiperOrigin-RevId: 641879836
2024-06-10 06:21:16 -07:00
Adam Paszke
3b4039c850 [Mosaic GPU] Load LLVM lowering interfaces for all dialects
Apparently we were missing interface registration code for LLVM lowering,
which the gpu-to-llvm pass gracefully ignores unless compiled with debug
assertions enabled. But, simply adding the assertions in fact makes the
pass _too powerful_ and makes it lower _all dialects to LLVM_, which is not
what we want. That's why I've replaced it with a minimal version that is
only responsible for handling the GPU dialect, making the lowering similar
to the one prior to extra registrations.

PiperOrigin-RevId: 641874183
2024-06-10 05:55:01 -07:00
Sergei Lebedev
136289e914 Added filelock to py_deps
This should unblock #21394, which uses filelock in the compilation cache.

PiperOrigin-RevId: 641338150
2024-06-07 13:16:33 -07:00
Paweł Paruzel
5fcd50b7fa Refactor kernel function assignment
PiperOrigin-RevId: 641255192
2024-06-07 08:20:31 -07:00
jax authors
f51af87fc5 fp8 matmul in pallas
PiperOrigin-RevId: 641254832
2024-06-07 08:17:06 -07:00
Tomás Longeri
a65d3ae0da [Mosaic] Expand vector.shape_cast support for sublane (un)folding no-ops
- Support non-zero minor offsets without having to relayout (they're still a no-op).
- Remove restriction on tiling which now allows 1D packed types to work.

PiperOrigin-RevId: 640967375
2024-06-06 11:35:19 -07:00
Tomás Longeri
20d9aac6be [Mosaic] Remove some restrictions for vector.shape_cast in infer-vector-layout and apply-vector-layout
- On infer-vector-layout remove some restrictions related to batch dimensions. Reshaping them doesn't matter as long as they don't combine with tiled dimensions.
- On apply-vector-layout, simplify handling of cases where the implicit tiled dimensions don't change while removing some unnecessary restrictions.
  - Don't require native tiling or natural topology for this.

PiperOrigin-RevId: 640837740
2024-06-06 03:26:43 -07:00
jax authors
0cbd0a023d Merge pull request #21494 from dfm:mac-arm-x86
PiperOrigin-RevId: 640613602
2024-06-05 12:36:00 -07:00
Dan Foreman-Mackey
0bf6700e3f Expose XLA FFI headers to bazel build and re-enable tests
This re-enables the tests removed in https://github.com/google/jax/pull/21563
and adds support for exposing the XLA FFI headers in the
`jax.extend.ffi.include_dir` directory during a bazel build. While it's
unlikely that these will be useful for most bazel users, it is good to provide
a consistent interface with the wheel build and to be able to test this feature.

PiperOrigin-RevId: 640194961
2024-06-04 10:14:43 -07:00
Adam Paszke
6a1fcc6cb2 [Mosaic TPU] Normalize inferred layouts to supported ones in matmul rule
Previously the rule would complain if the layouts were unsupported, but that's not
the right way to handle that situation. With this change, we simply pick a supported
configuration instead (and expect relayout to handle it).

PiperOrigin-RevId: 640190248
2024-06-04 10:03:00 -07:00
jax authors
c3ac4b55da Merge pull request #21632 from andportnoy:aportnoy/jax-test-data
PiperOrigin-RevId: 640168061
2024-06-04 08:50:09 -07:00
Andrey Portnoy
15dccd458c Add data argument to jax_test Bazel rule, forward to py_test
2024-06-04 11:17:30 -04:00
Tomás Longeri
8a1445a038 [Mosaic] Document tpu.create_subelement_mask
PiperOrigin-RevId: 639898224
2024-06-03 13:45:09 -07:00
Tomás Longeri
e620acfa17 [Mosaic] Remove "support" for MAXNUMF vector reductions
Pallas hasn't been using it to lower since cl/604849222, which is before we had serde, so we won't create a serde rule for this. That CL also tried to remove the support, but it had to be restored in cl/605436599 because it was breaking serialized kernels.

PiperOrigin-RevId: 639896981
2024-06-03 13:40:31 -07:00
jax authors
1edb94ec46 [XLA][Mosaic] Add support for fp8 matmuls in TPUv5+
Needed a little more backfill for TPU load

PiperOrigin-RevId: 639206243
2024-05-31 17:51:25 -07:00
Jevin Jiang
389bf93abf [XLA:Mosaic] Fix infer/apply vector layout rule for terminators (scf::yieldOp, scf::conditionOp).
We should infer the layout for each terminator inside its own region and find a compatible layout for the final result if the result is based on terminators from multiple regions, as with scf::ifOp, scf::whileOp, and scf::forOp. If no compatible layout is found, we fall back to a normalized layout. Finally, we also need to ensure the layouts in the input, terminator, and output are consistent across loops.

PiperOrigin-RevId: 639122434
2024-05-31 12:47:33 -07:00
Adam Paszke
d01496a09a [Mosaic GPU] Restore the PTX/PTXAS/SASS dump flags
They're very useful while prototyping the kernels.

PiperOrigin-RevId: 639027506
2024-05-31 07:27:36 -07:00
Sergei Lebedev
d2a39bc61b Updated the layer norm implementation in Mosaic GPU tests
jnp.var now needs lax.gt_p, which we don't currently support.

PiperOrigin-RevId: 639011383
2024-05-31 06:11:48 -07:00
Sergei Lebedev
8729952d82 Added a missing return to MosaicGPUCustomCall
PiperOrigin-RevId: 638627696
2024-05-30 06:13:01 -07:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling yet another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
Dan Foreman-Mackey
3cd77ee45e Add a comment about x86 builds of Python to cpu_feature_guard
macOS users frequently encounter a run-time error related to unsupported
AVX instructions (#21491, most recently). This is typically caused by
running an x86 Python build on ARM hardware, and it requires
re-installation of Python. This PR adds this suggestion to the error
message when encountered on macOS.
2024-05-29 15:25:44 -04:00
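
The situation described above (an x86 Python build running on ARM hardware) can also be checked from Python. The sketch below is a separate diagnostic, not the code added to cpu_feature_guard. The function names and the hint text are hypothetical. It relies on the macOS `sysctl.proc_translated` key, which reports 1 for a process running under Rosetta translation.

```python
import platform
import subprocess


def is_translated() -> bool:
    """Best-effort check for Rosetta translation; False when unavailable."""
    try:
        result = subprocess.run(
            ["sysctl", "-n", "sysctl.proc_translated"],
            capture_output=True, text=True, check=False,
        )
    except OSError:  # sysctl binary missing, e.g. on non-macOS systems
        return False
    return result.stdout.strip() == "1"


def x86_on_arm_hint(system: str, machine: str, translated: bool) -> str:
    """Return a suggestion string, or "" when the setup looks fine."""
    if system == "Darwin" and machine == "x86_64" and translated:
        return ("This looks like an x86 build of Python running on ARM "
                "hardware; consider installing a native ARM build of Python.")
    return ""


print(x86_on_arm_hint(platform.system(), platform.machine(), is_translated()))
```

Keeping the decision in a pure function of `(system, machine, translated)` makes it testable on any platform.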
Tomás Longeri
8f8b976421 [Mosaic] Packed loads and stores with 1D tiling should use (1, 128 * packing)
There are multiple representations for 1D tiling in vector layouts and we need to choose one consistently.

PiperOrigin-RevId: 638331061
2024-05-29 10:25:07 -07:00
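
To make the `(1, 128 * packing)` shape concrete, here is a small sketch that computes it for a few element widths. It assumes that packing is the number of elements sharing one 32-bit slot (`32 // bitwidth`) and that there are 128 lanes. Both are assumptions made for illustration, and the function names are hypothetical.

```python
def packing(bitwidth: int) -> int:
    """Elements packed into one 32-bit slot (assumed: 32 // bitwidth)."""
    if bitwidth not in (4, 8, 16, 32):
        raise ValueError(f"unsupported bitwidth: {bitwidth}")
    return 32 // bitwidth


def one_d_tiling(bitwidth: int, lanes: int = 128) -> tuple:
    """The 1D tiling representation named in the commit: (1, lanes * packing)."""
    return (1, lanes * packing(bitwidth))


for bits in (32, 16, 8):
    print(bits, one_d_tiling(bits))
```

Under these assumptions a 32-bit type gets `(1, 128)`, a 16-bit type `(1, 256)`, and an 8-bit type `(1, 512)`.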
Tomás Longeri
a07c7816ab [Mosaic] Fix bug in VectorLayout::generalizes introduced in cl/636250759
PiperOrigin-RevId: 638253907
2024-05-29 05:48:19 -07:00
George Necula
3bcb8d6831 Remove DUCC FFT from jaxlib
JAX has stopped generating code that directly uses
the DUCC FFT custom calls.
The 6-month backwards compatibility window has also expired.

PiperOrigin-RevId: 638132572
2024-05-28 21:12:23 -07:00
Michael Levesque-Dion
43f51d73ce Clean up version switches from dense array migration
PiperOrigin-RevId: 637955865
2024-05-28 10:58:51 -07:00
Tomás Longeri
8b95853609 [Mosaic] Add relayout for (1, 128 * packing) -> (packing, 128).
PiperOrigin-RevId: 637951690
2024-05-28 10:47:41 -07:00
Tomás Longeri
97f9a5e80e [Mosaic] Expand vector.shape_cast no-op detection for expanding/shrinking lane shape casts
- Remove restriction on sublane tiling being 1 or a multiple of 8 on the expanded shape.
- Support packed types.

PiperOrigin-RevId: 637777493
2024-05-27 22:32:08 -07:00
Tomás Longeri
3fb9acf01a [Mosaic] Expand support of vector.broadcast
- Enable it for minor or second-minor implicit dims for the non-no-op case.
- Don't allow output offsets for broadcasted dimensions to be non-replicated. Make sure to assign them as replicated in infer-vector-layout for all cases.
- Don't fail when both tiled dimensions are logically broadcasted but only one of them requires actual broadcasting (before, it would hit the unimplemented sublane + lane broadcast case).

PiperOrigin-RevId: 637772134
2024-05-27 21:59:37 -07:00
Justin Fu
683ca2cd40 [Pallas][Mosaic] Add lowering rules for PRNG ops.
PiperOrigin-RevId: 636999151
2024-05-24 12:22:15 -07:00
Eugene Zhulenev
d5c7ccc774 [xla:python] Add support for registering custom call targets for all XLA execution stages and for XLA FFI traits
PiperOrigin-RevId: 636963591
2024-05-24 10:34:46 -07:00
Henning Becker
15a1985445 Update cuDNN to version 9.1.1 in JAX
PiperOrigin-RevId: 636956696
2024-05-24 10:10:21 -07:00
Sergei Lebedev
0a694a1b42 Bumped the minimum ml_dtypes version to 0.4.0
2024-05-23 21:51:00 +01:00
Sergei Lebedev
daa81e6fb5 Added support for printing scalar values in Pallas TPU kernels
The implementation uses the new tpu.log operation in the Mosaic TPU dialect.

Note that

* the logging only happens if --xla_tpu_enable_log_recorder is set;
* only scalars can be printed;
* placeholders only accept i32 arguments at the moment.

PiperOrigin-RevId: 636585852
2024-05-23 10:02:00 -07:00
Adam Paszke
63a13f516d [Mosaic TPU] Add support for tpu.iota over untiled dimensions
PiperOrigin-RevId: 636567090
2024-05-23 08:56:54 -07:00