64 Commits

Author SHA1 Message Date
Bart Chrzaszcz
75e22f2ccd #sdy Run inlined mesh lifter pass at the end of JAX lowering.
PiperOrigin-RevId: 685728692
2024-10-14 09:13:12 -07:00
Kevin Gleason
78d5b75b0d Trim StableHLO python binding dependencies
With proper CAPI in place these dependencies are no longer needed, llvm support needed for string ostream for string APIs.

PiperOrigin-RevId: 668476145
2024-08-28 09:01:15 -07:00
Kevin Gleason
d72104de59 Use StableHLO filegroup for python APIs in jaxlib MLIR build.
PiperOrigin-RevId: 666450684
2024-08-22 12:36:39 -07:00
vfdev-5
da77b710b8 Added py::mod_gil_not_used() to PYBIND11_MODULE for _triton_ext and _tpu_ext
Description:
- Added `py::mod_gil_not_used()` to `PYBIND11_MODULE` for `_triton_ext` and `_tpu_ext`.

Refs:
- https://py-free-threading.github.io/porting/#__tabbed_1_2

Context:
- https://github.com/google/jax/issues/23073
2024-08-20 15:08:36 +02:00
vfdev-5
b1b3ea276b Added py::mod_gil_not_used() to PYBIND11_MODULE register_jax_dialects 2024-08-20 00:03:56 +02:00
Tomás Longeri
77afe251e7 [Mosaic TPU][Python] Check validity of VectorLayout on init
PiperOrigin-RevId: 661226283
2024-08-09 05:28:00 -07:00
Jevin Jiang
59e944dadf [XLA:Mosaic] Pass rewrite ctx of apply-vector-layout pass to relayout function.
We will implement a more efficient relayout according to the configs in rewrite ctx, such as `hardware_generation`, `max_sublanes_in_scratch` and so on. So it makes sense to change the relayout interface to take ctx (including python bindings). Now we can define rewrite ctx in `apply_vector_layout_test` as well. It makes it easier to test some advanced stuff (eg., mxu_shape change, max_sublanes_in_scratch change for rotate and relayout).

PiperOrigin-RevId: 655350013
2024-07-23 16:50:45 -07:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
Tomás Longeri
21bf3d196d [Mosaic][Python] Define __repr__ for VectorLayout
Loosely follows the example MLIR's bindings for Attribute

PiperOrigin-RevId: 646270865
2024-06-24 17:18:15 -07:00
Tomás Longeri
097806a033 [Mosaic][Python] Include error message in exceptions
PiperOrigin-RevId: 646259787
2024-06-24 16:36:26 -07:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
jax authors
5f702674f7 Merge pull request #21103 from superbobry:mosaic-gpu-fix
PiperOrigin-RevId: 631521771
2024-05-07 13:11:43 -07:00
Sergei Lebedev
8ccbebae4b Fixed Mosaic GPU build following #21029 2024-05-07 17:08:00 +01:00
jax authors
e691c19bb2 Merge pull request #21029 from superbobry:jaxlib-mlir-pyi
PiperOrigin-RevId: 629836927
2024-05-01 14:22:21 -07:00
Sergei Lebedev
442526869f Bundle MLIR .pyi files with jaxlib
This allows mypy and pyright to type check the code using MLIR Python APIs.
2024-05-01 19:37:26 +01:00
Adam Paszke
4051ac2a2f [Mosaic GPU] Only call kernel initializer from inside a custom call
XLA:GPU custom call design is far from ideal, as there's apparently no way to figure
out the CUDA context that will be used to run an HLO module before the custom call is
first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid
handle errors due to the load and launch happening in different CUDA contexts...

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
2024-04-30 07:10:05 -07:00
Adam Paszke
9b0319512a [Mosaic GPU] Use a custom TMA descriptor initialization method
The one bundled with the default MLIR runtime was convenient, but it is also
impractical. It allocates memory (which can deadlock due to NCCL), does a
synchronous host-to-device copy and then leaks the descriptor after the kernel...

With this change, we use our own runtime function to create all the descriptors.
What's more, we pack them all into a single buffer so that a single asynchronous
copy is sufficient. Finally, we use a scratch output to allocate the scratch buffer,
letting us lean on XLA:GPU for memory management.

PiperOrigin-RevId: 628430358
2024-04-26 09:40:47 -07:00
Adam Paszke
ded9272a5b [Mosaic GPU] Implement a simple profilng tool using CUDA events
The other JAX profiling tools are a little heavyweight when we only care about
timing a single kernel programatically.

Also adapt wgmma.py to match failures triggered by upstream MLIR changes.

PiperOrigin-RevId: 628096973
2024-04-25 09:18:39 -07:00
Adam Paszke
340b9e3739 Update GPU and NVGPU MLIR bindings to match upstream MLIR changes
Upstream MLIR Python bindings now require two more extension libraries
to work properly. The dialects fail to import without this change.
2024-04-25 11:41:19 +00:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
David Dunleavy
cd2b91c398 Update references to TSL config_settings to their new home in XLA
PiperOrigin-RevId: 623249851
2024-04-09 12:36:10 -07:00
Sandeep Dasgupta
6ffd55c405 Fixing StableHLO python dependencies on stablehlo:reference_api
PiperOrigin-RevId: 618294054
2024-03-22 14:52:06 -07:00
Peter Hawkins
ef40b85c8b Don't build the Triton MLIR dialect on Windows
This dialect doesn't build on Windows, but we don't support GPUs on Windows anyway, so we can simply exclude it from the build.

CI failures look like this:
```
C:\npm\prefix\bazel.CMD run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=C:\a\jax\jax\jax\dist --jaxlib_git_hash=5f19f7712b485493ac141c44eea3b3eb1ffdfb59 --cpu=AMD64
b"external/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(70): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2672: 'mlir::OpState::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(110): error C2783: 'enable_if<llvm::function_traits<decay<FnT>::type,std::is_class<T>::value>::num_args==1,RetT>::type mlir::OpState::walk(FnT &&)': could not deduce template argument for 'RetT'\r\n        with\r\n        [\r\n            T=decay<FnT>::type\r\n        ]\r\nexternal/llvm-project/mlir/include\\mlir/IR/OpDefinition.h(165): note: see declaration of 'mlir::OpState::walk'\r\nexternal/llvm-project/mlir/include\\mlir/IR/PatternMatch.h(357): error C2872: 'detail': ambiguous symbol\r\nexternal/llvm-project/mlir/include\\mlir/Rewrite/FrozenRewritePatternSet.h(15): note: could be 'mlir::detail'\r\nbazel-out/x64_windows-opt/bin/external/triton/include\\triton/Dialect/Triton/IR/Ops.h.inc(5826): note: or       'mlir::triton::detail'\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(712): note: see reference to class template instantiation 'mlir::OpRewritePattern<mlir::scf::ForOp>' being compiled\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2672: 'mlir::Block::walk': no matching overloaded function found\r\nexternal/triton/lib/Dialect/TritonGPU/Transforms/Utility.cpp(741): error C2783: 'RetT mlir::Block::walk(FnT &&)': could not deduce template argument for 'ArgT'\r\nexternal/llvm-project/mlir/include\\mlir/IR/Block.h(289): note: see declaration of 'mlir::Block::walk'\r\n"
    output = subprocess.check_output(cmd)
```
PiperOrigin-RevId: 609153322
2024-02-21 16:02:54 -08:00
Sergei Lebedev
6a7d1dceff Added ir.Value-based versions of load and store in triton.compat
PiperOrigin-RevId: 606597830
2024-02-13 06:13:31 -08:00
Sergei Lebedev
5e2e609a9b _triton_ext no longer links in MLIR C APIs
I re-used the same trick we do for the TPU dialect. Specifically, _triton_ext no longer depends on :triton_dialect_capi. Instead

* we include Triton dialect C bindings into :jaxlib_mlir_capi_objects
* and _triton_ext depends on :jaxlib_mlir_capi_objects and a header-only cc_library providing Triton dialect C bindings

This is a fork of #19680 with a few internal-only fixes.

PiperOrigin-RevId: 604929377
2024-02-07 03:39:29 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Christian Sigg
c83fd971a0 Fix jax mlir python dependency build after 537b2aa264
PiperOrigin-RevId: 593370604
2023-12-23 21:02:29 -08:00
Peter Hawkins
560187334a Add register_jax_dialects to jaxlib wheel.
Fixes build breakage.
2023-12-06 19:07:04 +00:00
Peter Hawkins
d95084dbc8 Use an explicit MLIR dialect registration, rather than _site_initialize_0.
Remove some special case handling of the SCF dialect, use upstream utilities instead.

PiperOrigin-RevId: 588433245
2023-12-06 08:19:55 -08:00
Peter Hawkins
32fb1b4034 Remove the ml_program MLIR dialect from jaxlib.
Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.

PiperOrigin-RevId: 587323808
2023-12-02 09:29:39 -08:00
Peter Hawkins
50c7223ed1 Fix Windows build failure.
The TPU extension didn't build because the MLIR Python binding code requires pybind11 to be included first on Windows, per 9584f58344/mlir/include/mlir-c/Bindings/Python/Interop.h (L24)

PiperOrigin-RevId: 587049246
2023-12-01 10:31:53 -08:00
Adam Paszke
ffbd632fb6 Add type annotations to avoid initializer list issues on macOS
Also remove the vector-avoiding specialization. For some reason
is_same<ssize_t, int64_t> evaluates to true on macOS, but then
the compiler complains that int64_t is a long long, while
ssize_t is only a long.
2023-11-27 18:02:50 +00:00
Tomás Longeri
f35ddc8c68 Fix bad cast in tpu_ext.cc
The argument to the cast is of type ssize_t. Mismatch between int64_t and ssize_t happens in Mac and causes build to fail:
`error: const_cast from 'const pybind11::ssize_t *' (aka 'const long *') to 'int64_t *' (aka 'long long *') is not allowed`

PiperOrigin-RevId: 584457599
2023-11-21 16:23:27 -08:00
Tomás Longeri
f12216908d [Mosaic] In Python bindings, fix getDefaultInsertionPoint and change import path for mlir.ir
PiperOrigin-RevId: 584258833
2023-11-21 02:10:20 -08:00
Tomás Longeri
f602fbe997 [Mosaic] Python bindings for VectorLayout, VRegDataBounds, assemble, disassemble, relayout and apply_layout_op
PiperOrigin-RevId: 584220887
2023-11-20 22:44:40 -08:00
Tomás Longeri
c186928a3e [Mosaic] Don't link CAPIIR into _tpu_ext, link into jaxlib_mlir_capi_shared_library instead
PiperOrigin-RevId: 579881376
2023-11-06 10:14:43 -08:00
Skye Wanderman-Milne
a03d6e6613 Move _tpu_ext.cc to jaxlib/mlir/_mlir_libs and set RPATH correctly
_tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in
jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in
its RPATH or similar on other platforms.

We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it
can use the same linkopts as other mlir targets that depend on
libjaxlib_mlir_capi.so. In particular, we want this to work correctly
across platforms, and it's not clear if Windows supports RPATH-like
functionality beyond the current directory.

PiperOrigin-RevId: 551372130
2023-07-26 18:25:17 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
Peter Hawkins
f7eef2eda8 Use the upstream MLIR strip-debuginfo pass instead of hand-rolling our own.
(I had missed that the upstream pass exists!)

Fixes https://github.com/google/jax/issues/16649

PiperOrigin-RevId: 548192839
2023-07-14 12:24:59 -07:00
Peter Hawkins
fed159a9eb Avoid duplicate symbol definition.
Fixes https://github.com/google/jax/issues/16525
2023-07-12 08:03:21 -04:00
Eugene Burmako
8696bef218 Integrate StableHLO at openxla/stablehlo@14691ce
Manual changes:
  * stablehlo/integrations/python/mlir/dialects/stablehlo.py: to keep around get_earliest_forward_compatible_version while it's still used in JAX.

PiperOrigin-RevId: 533140501
2023-05-18 08:42:26 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Eugene Burmako
b8dfb97e57 Integrate StableHLO at openxla/stablehlo@7a93924
PiperOrigin-RevId: 521293524
2023-04-02 11:14:01 -07:00
Anish Tondwalkar
8081031c90 [jaxlib] fix build w/ depenency on stablehlo_serialization
PiperOrigin-RevId: 519120624
2023-03-24 05:42:38 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Jake VanderPlas
e7f4fe043e jaxlib: fix mlir_hlo build rule 2022-11-16 15:42:05 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
fd90f40c45 Merge pull request #12443 from cloudhan:fix-mlir-chlo-stablehlo-symbols
PiperOrigin-RevId: 475808753
2022-09-21 06:12:44 -07:00
Cloud Han
3fa2c933f4 Fix linker error due to chlo and stablehol symbols are not exported in mlir dll 2022-09-21 17:26:21 +08:00