130 Commits

Author SHA1 Message Date
Sergei Lebedev
2d8a20c413 Do not load Triton bindings eagerly in jax/lib/__init__.py
Triton is only used by Pallas, so it makes sense to delay loading until Pallas
is imported.

PiperOrigin-RevId: 598131836
2024-01-13 03:01:02 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Sergei Lebedev
ba10775eda Added a compatibility overlay for Triton Python APIs
Follow up changes will gradually re-implement these APIs using the MLIR
builders added in google/jax#19159.

PiperOrigin-RevId: 597023799
2024-01-09 13:13:56 -08:00
Sergei Lebedev
e6c890171b Generate Python bindings for the Triton MLIR dialect
The bindings are not yet included in the jaxlib wheel. I will do that in a
follow up PR.

PiperOrigin-RevId: 595174466
2024-01-02 11:55:05 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Peter Hawkins
d95084dbc8 Use an explicit MLIR dialect registration, rather than _site_initialize_0.
Remove some special case handling of the SCF dialect, use upstream utilities instead.

PiperOrigin-RevId: 588433245
2023-12-06 08:19:55 -08:00
Peter Hawkins
720ff42cbf [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib.
Cleanup only, NFC intended.

PiperOrigin-RevId: 588074047
2023-12-05 08:05:17 -08:00
Peter Hawkins
32fb1b4034 Remove the ml_program MLIR dialect from jaxlib.
Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something.

PiperOrigin-RevId: 587323808
2023-12-02 09:29:39 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jieying Luo
d6c5910105 [PJRT C API] Move cuda_plugin_extension from jaxlib to jax-cuda-plugin (the package for cuda kernels).
PiperOrigin-RevId: 583406466
2023-11-17 09:11:46 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Peter Hawkins
9404518201 [CUDA] Add code to jax initialization that verifies that the CUDA libraries that are found are at least as new as the versions against which JAX was built.
This is intended to flag cases where the wrong CUDA libraries are used, either because:
* the user self-installed CUDA and that installation is too old, or
* the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version.

PiperOrigin-RevId: 568910422
2023-09-27 11:28:40 -07:00
Jieying Luo
0e24b90043 [PJRT C API] Register custom callback for xla_python_gpu_callback in plugin module.
PiperOrigin-RevId: 568671822
2023-09-26 15:54:10 -07:00
Jieying Luo
385cfc86e6 Use ModuleNotFoundError when importing cuda_plugin_extension module to be more specific.
Therefore other ImportError will not be silenced.

PiperOrigin-RevId: 568645824
2023-09-26 14:14:48 -07:00
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
Yash Katariya
b8eccb13f0 Remove the date check from jaxlib and jax version checks since it causes problem when jaxlib runs ahead of jax in CI (depending on timezones).
PiperOrigin-RevId: 563614108
2023-09-07 19:52:35 -07:00
Yash Katariya
2e132433a3 Check that jaxlib's nightly date is not greater than jax's nightly date in check_jaxlib_version
PiperOrigin-RevId: 558191583
2023-08-18 10:54:52 -07:00
Sharad Vikram
d872812a35 [Pallas] Upstream pallas to JAX
PiperOrigin-RevId: 552963029
2023-08-01 16:43:13 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
Sharad Vikram
3d556b7a19 Add Mosaic to Jaxlib and expose bindings in jax.experimental.mosaic
PiperOrigin-RevId: 549801858
2023-07-20 18:28:51 -07:00
Peter Hawkins
f7eef2eda8 Use the upstream MLIR strip-debuginfo pass instead of hand-rolling our own.
(I had missed that the upstream pass exists!)

Fixes https://github.com/google/jax/issues/16649

PiperOrigin-RevId: 548192839
2023-07-14 12:24:59 -07:00
Jake VanderPlas
b9c7b9bb4f Remove obsolete jaxlib version checks 2023-07-12 11:53:55 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
a861b31e3e Remove redundant stablehlo import.
The duplicate import confuses pytype.

PiperOrigin-RevId: 540707118
2023-06-15 15:30:50 -07:00
Jake VanderPlas
c3d3c19f0d cleanup old jaxlib version check 2023-06-06 21:51:58 -07:00
Sharad Vikram
4fb834b351 Use jaxlib version guard for triton instead of xla_extension_version
PiperOrigin-RevId: 534974834
2023-05-24 14:06:45 -07:00
Sharad Vikram
bf8ed6a543 Move triton_kernel_call_lib to jaxlib
PiperOrigin-RevId: 534934592
2023-05-24 12:11:21 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
0dbd467cea Add a C++ implementation of safe map.
Before (argument names reversed, oops, fixed in code):

```
name                                 time/op
safe_map/num_args:0/arg_lengths:1    1.43µs ± 1%
safe_map/num_args:1/arg_lengths:1    1.61µs ± 1%
safe_map/num_args:2/arg_lengths:1    1.72µs ± 0%
safe_map/num_args:5/arg_lengths:1    2.14µs ± 1%
safe_map/num_args:10/arg_lengths:1   2.87µs ± 1%
safe_map/num_args:100/arg_lengths:1  15.6µs ± 1%
safe_map/num_args:0/arg_lengths:2    1.65µs ± 0%
safe_map/num_args:1/arg_lengths:2    1.83µs ± 1%
safe_map/num_args:2/arg_lengths:2    1.97µs ± 1%
safe_map/num_args:5/arg_lengths:2    2.41µs ± 1%
safe_map/num_args:10/arg_lengths:2   3.22µs ± 2%
safe_map/num_args:100/arg_lengths:2  17.0µs ± 2%
safe_map/num_args:0/arg_lengths:3    1.83µs ± 1%
safe_map/num_args:1/arg_lengths:3    2.02µs ± 1%
safe_map/num_args:2/arg_lengths:3    2.16µs ± 1%
safe_map/num_args:5/arg_lengths:3    2.63µs ± 1%
safe_map/num_args:10/arg_lengths:3   3.48µs ± 1%
safe_map/num_args:100/arg_lengths:3  18.1µs ± 1%
```

After:
```
name                                 time/op
safe_map/num_args:0/arg_lengths:1     409ns ± 1%
safe_map/num_args:1/arg_lengths:1     602ns ± 5%
safe_map/num_args:2/arg_lengths:1     777ns ± 4%
safe_map/num_args:5/arg_lengths:1    1.21µs ± 3%
safe_map/num_args:10/arg_lengths:1   1.93µs ± 2%
safe_map/num_args:100/arg_lengths:1  14.7µs ± 0%
safe_map/num_args:0/arg_lengths:2     451ns ± 1%
safe_map/num_args:1/arg_lengths:2     652ns ± 0%
safe_map/num_args:2/arg_lengths:2     850ns ± 4%
safe_map/num_args:5/arg_lengths:2    1.32µs ± 3%
safe_map/num_args:10/arg_lengths:2   2.11µs ± 2%
safe_map/num_args:100/arg_lengths:2  16.0µs ± 1%
safe_map/num_args:0/arg_lengths:3     496ns ± 1%
safe_map/num_args:1/arg_lengths:3     718ns ± 5%
safe_map/num_args:2/arg_lengths:3     919ns ± 4%
safe_map/num_args:5/arg_lengths:3    1.43µs ± 2%
safe_map/num_args:10/arg_lengths:3   2.30µs ± 2%
safe_map/num_args:100/arg_lengths:3  17.3µs ± 1%
```
PiperOrigin-RevId: 523263207
2023-04-10 18:09:56 -07:00
Peter Hawkins
23451dc764 Merge pull request #15303 from jakevdp:lax-asarray
PiperOrigin-RevId: 520717999
2023-03-30 20:11:11 +00:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Peter Hawkins
b62f114524 Add support for using pip-installed CUDA wheels.
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Jake VanderPlas
36d2179a85 Add xla garbage collection to gc.callback 2023-03-10 10:49:25 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
d84ac2240c Remove use_stablehlo as minimum mlir_api_version >= 43
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Yash Katariya
8fd70cccb4 Delete PyShardedBuffer since it was only used for GDAs and GDA is deprecated.
PiperOrigin-RevId: 511307891
2023-02-21 14:34:23 -08:00
Skye Wanderman-Milne
7aa7e158f8 Modify JaxArrayTest.test_defragment to work on any numbers of devices
Also skip it when the PJRT C API is enabled, since the C API only supports auto defrag.

PiperOrigin-RevId: 509933635
2023-02-15 14:36:03 -08:00
Skye Wanderman-Milne
e54858522c Add back loading TPU plugin for older jaxlib versions.
This was removed in 668b82d529.

PiperOrigin-RevId: 508777939
2023-02-10 16:16:20 -08:00
Jieying Luo
668b82d529 [PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
Loading TPU PJRT plugin is moved to make_tpu_client.

This change is based on https://github.com/google/jax/pull/14011.

PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Jake VanderPlas
94af71a24c CI: fix mypy jaxlib version 2023-01-24 06:57:23 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Jieying Luo
1132098c90 [PJRT:C] Separate loading PJRT plugin from creating PJRT client.
- Add xla_client.maybe_load_pjrt_plugins which maybe load PJRT plugins from a hardcoded set.
- Call xla_client.maybe_load_pjrt_plugins in xla_bridge beforer initializing backends.
- Add binding of python method load_pjrt_plugin to LoadPjrtPlugin which does dlopen and dlsym.
- Remove loading PJRT plugin from tpu_initializer_helper.cc.
- Add an extra call to LoadPjrtPlugin when getting the PJRT_Api* to be backward compatible.

PiperOrigin-RevId: 493381393
2022-12-06 12:29:38 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00