Now that db11efab3b has landed, we're free to split up xla_extension without creating binary size problems or having to be quite so careful about cross-module dependencies. Here weakref_lru_cache has absolutely nothing to do with XLA.
There's no reason weakref_lru_cache is in the same Python extension as everything else.
PiperOrigin-RevId: 745271825
This change copies targets into jaxlib, and a subsequent change will delete them from XLA. We separate these into two phases because we cannot atomically change both JAX and XLA.
Future changes will migrate more of the C++ pieces of XLA:Python.
PiperOrigin-RevId: 739158120
jaxlib no longer includes any lowering logic, so we don't need this module anymore. Users would be better served by the APIs in JAX core like `jax.ffi` or `jax.interpreters.mlir`.
This module isn't covered by JAX's compatibility policy, so no formal deprecation period is required, but there are enough users that we should keep this warning for at least one full release cycle.
PiperOrigin-RevId: 738728721
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.
See #25196.
A few notes
* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
compilation is done via a new PjRt extension, which uses its own compilation
pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
deprecated and will be removed after 6 months as per
[compatibility guarantees] [*]
[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees
PiperOrigin-RevId: 722773884
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.
PiperOrigin-RevId: 719280601
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.
PiperOrigin-RevId: 716596848
This requires that the file providing the bindings has the same name as the
dialect it defines, since dialect search looks for a module path of the form
`<prefix>.<dialect namespace>`.
PiperOrigin-RevId: 693241875
`cuda_nvcc`, when installed e.g. via `pip` in a `venv` comes out as a
namespace package. The previous logic found the `cuda_nvcc` import but
failed because `cuda_nvcc.__file__ is None`.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.
PiperOrigin-RevId: 687003464
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.
Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations
The following test:
```py
def test_sdy_lowering(self):
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
@partial(jax.jit, out_shardings=s)
def f(x):
return x * 2
print(f.lower(arr).as_text())
```
outputs:
```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
sdy.mesh @mesh = <"x"=4, "y"=2>
func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
%c = stablehlo.constant dense<2> : tensor<i64>
%0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
%1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
return %1 : tensor<8x2xi64>
}
}
```
Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.
PiperOrigin-RevId: 655127611
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.
PiperOrigin-RevId: 647328834
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.
PiperOrigin-RevId: 638132572
The stock MLIR pipeline was a good way to get the prototype off the ground, but
its default passes can be problematic. In particular, the gpu.launch is compiled
into a sequence of instructions that load the kernel onto the GPU, run the kernel
and immediately unload it again. This has the correct semantics, but loading the
kernel is both expensive and forces a synchronization point, which leads to performance
issues.
To resolve this, I implemented a new MLIR pass that finds the gpu.launch ops and splits
each function that has it into two functions: one that preloads the kernel onto the
GPU, and another one that consumes the handle produced by the previous one. We call
the first function at compile-time, while only the second one is used at run-time.
There are other overheads in MLIR's implementation of kernel launch, but I will
fix those later.
PiperOrigin-RevId: 627670773
IIRC we used to import these bindings in lib/__init__.py which is imported
as part of the top-level jax package. So, it did make sense to delay the
check until we actually need the bindings.
However, we have since moved the bindings to lib/triton.py and thus we could
move the check there.
PiperOrigin-RevId: 607196039
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.
PiperOrigin-RevId: 587323808