65 Commits

Author SHA1 Message Date
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
Henning Becker
15a1985445 Update cuDNN to version 9.1.1 in JAX
PiperOrigin-RevId: 636956696
2024-05-24 10:10:21 -07:00
Sergei Lebedev
0a694a1b42 Bumped the minimum ml_dtypes version to 0.4.0 2024-05-23 21:51:00 +01:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Sergei Lebedev
8ccbebae4b Fixed Mosaic GPU build following #21029 2024-05-07 17:08:00 +01:00
Sergei Lebedev
442526869f Bundle MLIR .pyi files with jaxlib
This allows mypy and pyright to type check the code using MLIR Python APIs.
2024-05-01 19:37:26 +01:00
Adam Paszke
8e3f5b1018 Initial commit for Mosaic GPU
Moving this to JAX to make it easier to explore Pallas integration.

PiperOrigin-RevId: 625982382
2024-04-18 04:04:10 -07:00
Peter Hawkins
478cfa9944 Add an upper bound on JAX's CUDNN version constraint.
Major releases of CUDNN break ABI compatibility, so we cannot allow new major versions.

PiperOrigin-RevId: 620030416
2024-03-28 13:00:36 -07:00
jax authors
0be07e6aec Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5

PiperOrigin-RevId: 618911489
2024-03-25 11:46:39 -07:00
jax authors
910a31d7b7 Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e
PiperOrigin-RevId: 618249855
2024-03-22 12:10:53 -07:00
jax authors
bed4f65438 Remove support for CUDA 11.
Pin minimal required versions for CUDA to 12.1.

PiperOrigin-RevId: 618195554
2024-03-22 09:05:39 -07:00
Sergei Lebedev
1e9f96a574 Include Triton files into the jaxlib wheel
This PR is based on #19368.
2024-01-16 15:28:12 +00:00
Christian Sigg
c83fd971a0 Fix jax mlir python dependency build after 537b2aa264
PiperOrigin-RevId: 593370604
2023-12-23 21:02:29 -08:00
Peter Hawkins
84b58ec7f3 Increase minimum scipy version to 1.9.
Scipy 1.9 appears to fix some crashes on Mac ARM.

PiperOrigin-RevId: 571977068
2023-10-09 10:37:35 -07:00
jax authors
f64235acc8 Merge pull request #17453 from jakevdp:fix-version-string
PiperOrigin-RevId: 563466394
2023-09-07 10:06:32 -07:00
Jake VanderPlas
6f3f0d5e57 build: write appropriate version strings to build artifacts 2023-09-07 08:45:48 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
bfa113ba60 Remove references to Python 3.8.
Remove the old build scripts/Dockerfile, since they are unused and broken.

PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Rahul Batra
b0e541a730 [ROCm]: Updates for container and build script
-Updated dockerfile.ms
	-Updated build script to switch building against XLA repo
  	-Update CI script
	-Update jaxlib setup.py to add rocm version
2023-06-19 18:13:28 +00:00
Peter Hawkins
69cf67f252 Bump the minimum CUDNN version for CUDA 12 wheels to 8.9. 2023-05-26 10:04:34 -04:00
Peter Hawkins
2b7790290b Bump minimum CUDNN version in pip installation to 8.8.
There are known wrong output bugs observed in JAX for earlier versions, in particular related to RNNs.
2023-05-25 14:46:39 -04:00
Sharad Vikram
557ca52f10 Add cuda_pip extra for jaxlib
PiperOrigin-RevId: 534957585
2023-05-24 13:19:27 -07:00
Jake VanderPlas
59e6ed213e Use ml_dtypes definition for jnp.finfo 2023-05-04 10:40:44 -07:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
8bb90b5fbe [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
PiperOrigin-RevId: 518830467
2023-03-23 05:13:39 -07:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
jax authors
c848efa11b Merge pull request #12808 from hawkinsp:py311
PiperOrigin-RevId: 481155690
2022-10-14 08:56:14 -07:00
Peter Hawkins
fb72c38e19 Add Python 3.11 as a compatible Python version. 2022-10-14 14:56:07 +00:00
Peter Hawkins
4988b3117d Drop absl-py as a jaxlib dependency.
absl-py is unused.
2022-10-14 13:57:26 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
5b0686f9ea Include ABI tag in jaxlib wheels.
Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl

This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl

i.e., we include the CPython ABI tag. This simply reflects the status
quo in the wheel name, and does not change what jaxlib needs.
2022-08-17 15:15:46 +00:00
Yash Katariya
8a1b4785de Use the same jaxlib package name for nightlies. The __version__ will still contain the dev version (with datetime string in it).
PiperOrigin-RevId: 466534455
2022-08-09 18:53:36 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Parker Schuh
d8f0099f68 _mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
Peter Hawkins
a560a29e12 Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00
Peter Hawkins
d56601a896 Include mlir.transforms and mlir.passmanager in jaxlib BUILD.
These increase the binary size of jaxlib only a negligible amount, and allow running passes like canonicalization from Python.
2022-05-16 13:00:22 +00:00
Yash Katariya
46d034baab Add the nightly dev version to __version__ of jaxlib.
PiperOrigin-RevId: 448001375
2022-05-11 08:35:16 -07:00
Yash Katariya
dfb2caf31e Add nightly __version__ string if building jaxlib nightly
PiperOrigin-RevId: 447822974
2022-05-10 14:05:35 -07:00
Jeppe Klitgaard
a11f15e3ec feat: officially support Python 3.10 2022-05-07 13:43:12 +01:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00