52 Commits

Author SHA1 Message Date
Peter Hawkins
84b58ec7f3 Increase minimum scipy version to 1.9.
Scipy 1.9 appears to fix some crashes on Mac ARM.

PiperOrigin-RevId: 571977068
2023-10-09 10:37:35 -07:00
jax authors
f64235acc8 Merge pull request #17453 from jakevdp:fix-version-string
PiperOrigin-RevId: 563466394
2023-09-07 10:06:32 -07:00
Jake VanderPlas
6f3f0d5e57 build: write appropriate version strings to build artifacts 2023-09-07 08:45:48 -07:00
Sharad Vikram
3baa6e7a89 Enable building jaxlib w/ Mosaic
PiperOrigin-RevId: 551159246
2023-07-26 03:59:30 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Peter Hawkins
bfa113ba60 Remove references to Python 3.8.
Remove the old build scripts/Dockerfile, since they are unused and broken.

PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Rahul Batra
b0e541a730 [ROCm]: Updates for container and build script
-Updated dockerfile.ms
	-Updated build script to switch building against XLA repo
  	-Update CI script
	-Update jaxlib setup.py to add rocm version
2023-06-19 18:13:28 +00:00
Peter Hawkins
69cf67f252 Bump the minimum CUDNN version for CUDA 12 wheels to 8.9. 2023-05-26 10:04:34 -04:00
Peter Hawkins
2b7790290b Bump minimum CUDNN version in pip installation to 8.8.
There are known wrong output bugs observed in JAX for earlier versions, in particular related to RNNs.
2023-05-25 14:46:39 -04:00
Sharad Vikram
557ca52f10 Add cuda_pip extra for jaxlib
PiperOrigin-RevId: 534957585
2023-05-24 13:19:27 -07:00
Jake VanderPlas
59e6ed213e Use ml_dtypes definition for jnp.finfo 2023-05-04 10:40:44 -07:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
8bb90b5fbe [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
PiperOrigin-RevId: 518830467
2023-03-23 05:13:39 -07:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
jax authors
c848efa11b Merge pull request #12808 from hawkinsp:py311
PiperOrigin-RevId: 481155690
2022-10-14 08:56:14 -07:00
Peter Hawkins
fb72c38e19 Add Python 3.11 as a compatible Python version. 2022-10-14 14:56:07 +00:00
Peter Hawkins
4988b3117d Drop absl-py as a jaxlib dependency.
absl-py is unused.
2022-10-14 13:57:26 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
5b0686f9ea Include ABI tag in jaxlib wheels.
Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl

This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl

i.e., we include the CPython ABI tag. This simply reflects the status
quo in the wheel name, and does not change what jaxlib needs.
2022-08-17 15:15:46 +00:00
Yash Katariya
8a1b4785de Use the same jaxlib package name for nightlies. The __version__ will still contain the dev version (with datetime string in it).
PiperOrigin-RevId: 466534455
2022-08-09 18:53:36 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Parker Schuh
d8f0099f68 _mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
Peter Hawkins
a560a29e12 Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00
Peter Hawkins
d56601a896 Include mlir.transforms and mlir.passmanager in jaxlib BUILD.
These increase the binary size of jaxlib only a negligible amount, and allow running passes like canonicalization from Python.
2022-05-16 13:00:22 +00:00
Yash Katariya
46d034baab Add the nightly dev version to __version__ of jaxlib.
PiperOrigin-RevId: 448001375
2022-05-11 08:35:16 -07:00
Yash Katariya
dfb2caf31e Add nightly __version__ string if building jaxlib nightly
PiperOrigin-RevId: 447822974
2022-05-10 14:05:35 -07:00
Jeppe Klitgaard
a11f15e3ec feat: officially support Python 3.10 2022-05-07 13:43:12 +01:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00
Peter Hawkins
04369a3588 Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00
Peter Hawkins
ce7ae6bd76 Make MLIR bindings build work under Bazel.
Tested on Linux and Mac, but not Windows.
2021-11-12 12:16:32 -05:00
Peter Hawkins
11f6c535ae Add MLIR:Python bindings to jaxlib build.
PiperOrigin-RevId: 407657331
2021-11-04 13:29:58 -07:00
Yash Katariya
93fe3ab492 Replace _ with - because wheel.py normalizes it to .
PiperOrigin-RevId: 404049619
2021-10-18 13:47:43 -07:00
Yash Katariya
e6e81ba885 Add Cuda 11.4 with cudnn 8.2 and cudnn 8.0.5 release builds
PiperOrigin-RevId: 403661187
2021-10-16 16:13:43 -07:00
Sergei Lebedev
2a994bdb02 Type stubs for jaxlib.xla_extension no longer use -stubs suffix
PEP-561 does not specify whether subpackages of a non-stub-only-package
could use the -stubs suffix. setuptools seems to allow that, yet mypy fails
to resolve the subpackage with a -stubs suffix.

This commit makes jaxlib.xla_extension a ~normal package with a toplevel
__init__.pyi.
2021-08-02 14:31:11 +01:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
Peter Hawkins
01d6e32c7f Add version constraints to flatbuffers versions.
Require 1.12 or newer, because we've only tested 1.12 and 2.0.
Require less than 3.0, because flatbuffers uses semantic versioning and version 3.0 would mean an incompatible change has been made.
2021-05-10 20:58:22 -04:00
Peter Hawkins
c983d3c660 Bundle libdevice.10.bc with jaxlib wheels.
libdevice.10.bc is a redistributable part of the CUDA SDK.

This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
2021-04-29 10:26:03 -04:00
Sergei Lebedev
225ffc30d8 Re-exported tensorflow...xla_extension type stubs in jaxlib
The type stubs allow using precise types for XLA primitives instead
of aliasing them to Any.

This commit does not change any type annotations within JAX. That will
be done in a followup. I have manually verified that type stubs are
discoverable by mypy once the new jaxlib is installed by type "checking"

    from jaxlib import xla_extension as xe
    d: xe._Dtype
2021-04-06 14:51:45 +01:00
Peter Hawkins
13f3819054 Update README.md for jaxlib 0.1.60.
Bump jaxlib version to 0.1.61 and update changelog.

Change jaxlib numpy version limit to >=1.16 for next release. Releases older than 1.16 are deprecated per NEP 00029. Reenable NumPy 1.20.

Bump minimum jaxlib version to 0.1.60.
2021-02-03 20:44:01 -05:00