129 Commits

Author SHA1 Message Date
George Necula
3bcb8d6831 Remove DUCC FFT from jaxlib
JAX has stopped generating code that uses directly
the DUCC FFT custom calls.
The 6 months backwards compatibility window has also expired.

PiperOrigin-RevId: 638132572
2024-05-28 21:12:23 -07:00
Eugene Zhulenev
d5c7ccc774 [xla:python] Add support for registering custom call targets for all XLA execution stages and for XLA FFI traits
PiperOrigin-RevId: 636963591
2024-05-24 10:34:46 -07:00
Kyle Lucke
418b68828a Automated Code Change
PiperOrigin-RevId: 635818645
2024-05-21 08:40:34 -07:00
jax authors
c3cab2e3d3 Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3
PiperOrigin-RevId: 632579101
2024-05-10 12:56:10 -07:00
Peter Hawkins
6c425338d2 Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306
PiperOrigin-RevId: 632548226
2024-05-10 11:06:38 -07:00
jax authors
0267ed0ba9 Replace xla_extension symlink with genrule that makes xla_extension module accessible from jax._src.lib.
The runfiles of the original targets were lost when the symlinked files were used.

This change is needed for future Hermetic CUDA implementation. Bazel will download CUDA distributives in cache, and CUDA executables and libraries will be added in the runfiles of the targets. When `xla_extension` is simlinked, the content of the runfiles is lost. With `genrule` the content of the runfiles is preserved.

PiperOrigin-RevId: 632508121
2024-05-10 08:48:12 -07:00
Jieying Luo
a949ce772b Add get_device_ordinal to cuda plugin so that CUDA dependency can be removed from py_array (jaxlib).
py_array still has CUDA dependency as a fallback to keep jaxlib[cuda] working before the migration to CUDA plugin.

PiperOrigin-RevId: 629499893
2024-04-30 12:50:50 -07:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
David Dunleavy
aade591fdf Move tsl/python to xla/tsl/python
PiperOrigin-RevId: 620320903
2024-03-29 13:15:21 -07:00
Peter Hawkins
1a193ea189 Fix segfault in cuda_plugin_extension.
The nanobind switch for the GPU callback code means that we are now using the NumPy APIs rather than pybind11's clone of them. It is important to initialize the NumPy APIs before using them in each module.

PiperOrigin-RevId: 613036056
2024-03-05 18:31:50 -08:00
Peter Hawkins
feda85dff3 Replace references to xla/python/status_casters.h with xla/pjrt/status_casters.h, which its current home.
PiperOrigin-RevId: 612578488
2024-03-04 14:11:01 -08:00
David Dunleavy
be3e39ad3b Move tsl/cuda to xla/tsl/cuda
PiperOrigin-RevId: 610550833
2024-02-26 15:45:10 -08:00
Sergei Lebedev
881436240e Inlined triton.compat
We no longer need a compatibility layer, since Pallas does not use any Triton
IR building APIs.

PiperOrigin-RevId: 606948415
2024-02-14 05:23:15 -08:00
Sergei Lebedev
1e9f96a574 Include Triton files into the jaxlib wheel
This PR is based on #19368.
2024-01-16 15:28:12 +00:00
Peter Hawkins
858fd52ac0 Fix jaxlib wheel build after removal of mosaic python files.
PiperOrigin-RevId: 597536465
2024-01-11 06:21:07 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Peter Hawkins
32fb1b4034 Remove the ml_program MLIR dialect from jaxlib.
Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.

PiperOrigin-RevId: 587323808
2023-12-02 09:29:39 -08:00
Jieying Luo
0e24b90043 [PJRT C API] Register custom callback for xla_python_gpu_callback in plugin module.
PiperOrigin-RevId: 568671822
2023-09-26 15:54:10 -07:00
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
John QiangZhang
997b35e1d9 Improve the gpu lowering error message if users forget link the gpu library.
PiperOrigin-RevId: 564530960
2023-09-11 16:14:18 -07:00
Peter Hawkins
88408e13ee Remove stale references to //jaxlib:setup.cfg in Bazel build.
Fixes broken jaxlib wheel build.
2023-09-03 19:18:25 +00:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
Sharad Vikram
3d556b7a19 Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
PiperOrigin-RevId: 549801858
2023-07-20 18:28:51 -07:00
Peter Hawkins
f7eef2eda8 Use the upstream MLIR strip-debuginfo pass instead of hand-rolling our own.
(I had missed that the upstream pass exists!)

Fixes https://github.com/google/jax/issues/16649

PiperOrigin-RevId: 548192839
2023-07-14 12:24:59 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
0dbd467cea Add a C++ implementation of safe map.
Before (argument names reversed, oops, fixed in code):

```
name                                 time/op
safe_map/num_args:0/arg_lengths:1    1.43µs ± 1%
safe_map/num_args:1/arg_lengths:1    1.61µs ± 1%
safe_map/num_args:2/arg_lengths:1    1.72µs ± 0%
safe_map/num_args:5/arg_lengths:1    2.14µs ± 1%
safe_map/num_args:10/arg_lengths:1   2.87µs ± 1%
safe_map/num_args:100/arg_lengths:1  15.6µs ± 1%
safe_map/num_args:0/arg_lengths:2    1.65µs ± 0%
safe_map/num_args:1/arg_lengths:2    1.83µs ± 1%
safe_map/num_args:2/arg_lengths:2    1.97µs ± 1%
safe_map/num_args:5/arg_lengths:2    2.41µs ± 1%
safe_map/num_args:10/arg_lengths:2   3.22µs ± 2%
safe_map/num_args:100/arg_lengths:2  17.0µs ± 2%
safe_map/num_args:0/arg_lengths:3    1.83µs ± 1%
safe_map/num_args:1/arg_lengths:3    2.02µs ± 1%
safe_map/num_args:2/arg_lengths:3    2.16µs ± 1%
safe_map/num_args:5/arg_lengths:3    2.63µs ± 1%
safe_map/num_args:10/arg_lengths:3   3.48µs ± 1%
safe_map/num_args:100/arg_lengths:3  18.1µs ± 1%
```

After:
```
name                                 time/op
safe_map/num_args:0/arg_lengths:1     409ns ± 1%
safe_map/num_args:1/arg_lengths:1     602ns ± 5%
safe_map/num_args:2/arg_lengths:1     777ns ± 4%
safe_map/num_args:5/arg_lengths:1    1.21µs ± 3%
safe_map/num_args:10/arg_lengths:1   1.93µs ± 2%
safe_map/num_args:100/arg_lengths:1  14.7µs ± 0%
safe_map/num_args:0/arg_lengths:2     451ns ± 1%
safe_map/num_args:1/arg_lengths:2     652ns ± 0%
safe_map/num_args:2/arg_lengths:2     850ns ± 4%
safe_map/num_args:5/arg_lengths:2    1.32µs ± 3%
safe_map/num_args:10/arg_lengths:2   2.11µs ± 2%
safe_map/num_args:100/arg_lengths:2  16.0µs ± 1%
safe_map/num_args:0/arg_lengths:3     496ns ± 1%
safe_map/num_args:1/arg_lengths:3     718ns ± 5%
safe_map/num_args:2/arg_lengths:3     919ns ± 4%
safe_map/num_args:5/arg_lengths:3    1.43µs ± 2%
safe_map/num_args:10/arg_lengths:3   2.30µs ± 2%
safe_map/num_args:100/arg_lengths:3  17.3µs ± 1%
```
PiperOrigin-RevId: 523263207
2023-04-10 18:09:56 -07:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Peter Hawkins
ab45383038 Fix build breakage from OpenXLA switch.
PiperOrigin-RevId: 516325478
2023-03-13 14:37:35 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Eugene Burmako
2186268ec7 Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.

This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.

C++:
  1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
  2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
  3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
  4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.

Python:
  5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
  6) Python imports don't change.
  7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
  8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:35:23 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Benjamin Kramer
2c72858928 Integrate LLVM at llvm/llvm-project@8aff88fd3a
Updates LLVM usage to match
[8aff88fd3a5f](https://github.com/llvm/llvm-project/commit/8aff88fd3a5f)

PiperOrigin-RevId: 461889195
2022-07-19 08:31:24 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
Peter Hawkins
02534bdff1 Move MLIR dependencies onto //jaxlib rule instead of wheel build rule.
Change in preparation for allowing testing with Bazel.

PiperOrigin-RevId: 458460128
2022-07-01 06:54:31 -07:00
Peter Hawkins
1e171ccd10 Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
2022-06-29 12:51:01 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00