170 Commits

Author SHA1 Message Date
Peter Hawkins
474dcd409d Remove code to support jaxlib < v0.6.
New minimum jaxlib_extension_version is 330.

PiperOrigin-RevId: 748853497
2025-04-17 16:44:41 -07:00
Peter Hawkins
b4629c230c Split weakref_lru_cache into its own extension.
Now that db11efab3b has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.

There's no reason weakref_lru_cache is in the same Python extension as everything else.

PiperOrigin-RevId: 745271825
2025-04-08 13:34:47 -07:00
Vladimir Belitskiy
5370ac2ec5 Remove the try/except for Shardy imports.
Shardy has been been included in JAX for a while now.

PiperOrigin-RevId: 742778405
2025-04-01 11:33:44 -07:00
Parker Schuh
be1f649b51 Expose jax._src.lib.ifrt_version which tracks the version of
third_party/tensorflow code inside jax.

PiperOrigin-RevId: 740957982
2025-03-26 17:31:08 -07:00
Parker Schuh
6033592a95 Rename xla_extension_version to jaxlib_extension_version to reflect its new
scope.

PiperOrigin-RevId: 740944270
2025-03-26 16:36:34 -07:00
Peter Hawkins
55e408471c [JAX] [XLA:Python] Migrate xla_extension and its type stubs into jaxlib.
Future changes will migrate many of its dependent modules.

PiperOrigin-RevId: 739361786
2025-03-21 18:52:54 -07:00
Peter Hawkins
a93035f625 Migrate xla_client and its Python tests out of XLA into JAX.
This change copies targets into jaxlib, and a subsequent change will delete them from XLA. We separate these into two phases because we cannot atomically change both JAX and XLA.

Future changes will migrate more of the C++ pieces of XLA:Python.

PiperOrigin-RevId: 739158120
2025-03-21 06:26:15 -07:00
Dan Foreman-Mackey
e2b6859e7d Deprecate the jaxlib.hlo_helpers submodule.
jaxlib no longer includes any lowering logic, so we don't need this module anymore. Users would be better served by the APIs in JAX core like `jax.ffi` or `jax.interpreters.mlir`.

This module isn't covered by JAX's compatibility policy, so no formal deprecation period is required, but there are enough users that we should keep this warning for at least one full release cycle.

PiperOrigin-RevId: 738728721
2025-03-20 02:52:28 -07:00
Benjamin Chetioui
837418c652 [Mosaic GPU] Remove old jaxlib version guards.
PiperOrigin-RevId: 726071956
2025-02-12 08:49:40 -08:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
Sergei Lebedev
9ee7123c39 [mosaic_gpu] Fixed mosaic_gpu-serde pass registration
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719280601
2025-01-24 06:35:54 -08:00
Sergei Lebedev
4c363766f8 [mosaic_gpu] Removed debug prints in jax._src.lib.mosaic_gpu
PiperOrigin-RevId: 717952850
2025-01-21 09:36:12 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Benjamin Chetioui
63e59c5fd7 [Mosaic GPU] Ensure that the dialect module can be loaded successfully.
This requires that the file providing the bindings has the same name as the
dialect it defines, since dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.

PiperOrigin-RevId: 693241875
2024-11-05 00:47:21 -08:00
Kristian Hartikainen
9df719f83f Fix _cuda_path for case when cuda_nvcc is a namespace package
`cuda_nvcc`, when installed e.g. via `pip` in a `venv` comes out as a
namespace package. The previous logic found the `cuda_nvcc` import but
failed because `cuda_nvcc.__file__ is None`.
2024-11-04 18:06:55 +02:00
Benjamin Chetioui
c708a04c6e [Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
Also start moving the existing C++ tests to Python.

PiperOrigin-RevId: 691729887
2024-10-31 02:47:30 -07:00
Dan Foreman-Mackey
ad1d864b05 Fix lint at head 2024-10-26 07:41:44 -04:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
jax authors
8473391467 Merge pull request #24139 from hartikainen:fix-cuda_path
PiperOrigin-RevId: 683272496
2024-10-07 12:02:29 -07:00
Kristian Hartikainen
1ea8e3c29d Update _cuda_path
- Remove jax-relative module path test
- Use `$CUDA_ROOT` environment variable if available
- Use `cuda_nvcc` module's path if installed
2024-10-07 20:32:05 +03:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Sergei Lebedev
8cb3596136 Partially rolling forward #22998
Reverts 322d0c2f31e92e68a531f95a53c3f040d6a76bdf

PiperOrigin-RevId: 670173462
2024-09-02 04:44:47 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Feng Wang
322d0c2f31 Rollback the change "Import from `mlir.dialects` lazily"
Reverts a755f1db837c464f6aa3d3111a1bc40b5ebdd37d

PiperOrigin-RevId: 663324497
2024-08-15 09:00:47 -07:00
Sergei Lebedev
a755f1db83 Import from `mlir.dialects` lazily
These imports jointly account for ~0.3s of import time internally.

PiperOrigin-RevId: 662588167
2024-08-13 11:22:41 -07:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
jax authors
00528b9858 libdevice.10.bc is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
2024-06-27 08:35:59 -07:00
Peter Hawkins
971ab0fba2 Make CuDNN SDPA API work with JAX with a CUDA plugin configuration. 2024-06-06 12:09:19 -04:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00
George Necula
3bcb8d6831 Remove DUCC FFT from jaxlib
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.

PiperOrigin-RevId: 638132572
2024-05-28 21:12:23 -07:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
27c932a3a9 Do not import from lowering in tests/pallas/pallas_test.py
This ensures that the test is importable even with a non-GPU jaxlib, which
does not have Triton dialect bindings.

PiperOrigin-RevId: 632603225
2024-05-10 14:25:17 -07:00
Adam Paszke
5a2d7a2df4 Switch Mosaic GPU to a custom pass pipeline and improve the lowering of GPU launch
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.

To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.

There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.

PiperOrigin-RevId: 627670773
2024-04-24 03:27:45 -07:00
Paul Wohlhart
6b85557cc1 Use xla_client.Device in jax.numpy.
PiperOrigin-RevId: 627507470
2024-04-23 14:32:08 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Sergei Lebedev
d434ab55d7 Handle TypeError due to | in type annotations in Triton MLIR bindings
Unfortunately, the only fix is to upgrade the jaxlib.

PiperOrigin-RevId: 609305403
2024-02-22 02:59:46 -08:00
Sergei Lebedev
b4c8b0e4fb Check if the Triton dialect bindings are available in lib/triton.py
IIRC we used to import these bindings in lib/__init__.py which is imported
as part of the top-level jax package. So, it did make sense to delay the
check until we actually need the bindings.

However, we have since moved the bindings to lib/triton.py and thus we could
move the check there.

PiperOrigin-RevId: 607196039
2024-02-14 20:49:08 -08:00
Sergei Lebedev
881436240e Inlined triton.compat
We no longer need a compatibility layer, since Pallas does not use any Triton
IR building APIs.

PiperOrigin-RevId: 606948415
2024-02-14 05:23:15 -08:00
Sergei Lebedev
2d8a20c413 Do not load Triton bindings eagerly in jax/lib/__init__.py
Triton is only used by Pallas, so it makes sense to delay loading until Pallas
is imported.

PiperOrigin-RevId: 598131836
2024-01-13 03:01:02 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Sergei Lebedev
ba10775eda Added a compatibility overlay for Triton Python APIs
Follow up changes will gradually re-implement these APIs using the MLIR
builders added in google/jax#19159.

PiperOrigin-RevId: 597023799
2024-01-09 13:13:56 -08:00
Sergei Lebedev
e6c890171b Generate Python bindings for the Triton MLIR dialect
The bindings are not yet included in the jaxlib wheel. I will do that in a
follow up PR.

PiperOrigin-RevId: 595174466
2024-01-02 11:55:05 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
d95084dbc8 Use an explicit MLIR dialect registration, rather than _site_initialize_0.
Remove some special case handling of the SCF dialect, use upstream utilities instead.

PiperOrigin-RevId: 588433245
2023-12-06 08:19:55 -08:00
Peter Hawkins
720ff42cbf [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib.
Cleanup only, NFC intended.

PiperOrigin-RevId: 588074047
2023-12-05 08:05:17 -08:00
Peter Hawkins
32fb1b4034 Remove the ml_program MLIR dialect from jaxlib.
Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.

PiperOrigin-RevId: 587323808
2023-12-02 09:29:39 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jieying Luo
d6c5910105 [PJRT C API] Move cuda_plugin_extension from jaxlib to jax-cuda-plugin (the package for cuda kernels).
PiperOrigin-RevId: 583406466
2023-11-17 09:11:46 -08:00