Peter Hawkins
96ba290faf
Jax 0.3.5 and jaxlib 0.3.5 release.
2022-04-06 23:56:41 +00:00
Peter Hawkins
bc658e7456
[MHLO] Add direct MHLO lowerings for most linear algebra kernels.
...
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Peter Hawkins
3bfa6af2c8
[MHLO] Add MHLO lowering for PRNG kernels.
...
PiperOrigin-RevId: 439919104
2022-04-06 13:23:01 -07:00
Aden Grue
8884ce5b98
Migrate 'jaxlib' CPU custom-calls to the status-returning API
...
PiperOrigin-RevId: 438165260
2022-03-29 17:14:14 -07:00
Skye Wanderman-Milne
d7087abce6
Bump jax and jaxlib versions for 0.3.2 release
...
Also add CPU pjit to changelog
2022-03-16 14:31:00 -07:00
Jake VanderPlas
765d11d50c
Fix ROCM BUILD rule
...
Fixes https://github.com/google/jax/issues/9864 ; replaces https://github.com/google/jax/issues/9870
PiperOrigin-RevId: 434554684
2022-03-14 13:39:22 -07:00
Skye Wanderman-Milne
5c8c4d487a
Update jaxlib version to 0.3.2 to match jax
2022-03-11 01:06:43 +00:00
Peter Hawkins
7d02949d24
[JAX:GPU] Implement the full_matrices=False case of SVD without generating the full matrices and then slicing.
...
PiperOrigin-RevId: 432425681
2022-03-04 05:55:36 -08:00
jax authors
cf9a900d78
Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
...
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
Jake VanderPlas
616df55ad4
Gate ROCM targets in BUILD file.
...
PiperOrigin-RevId: 432216579
2022-03-03 09:50:14 -08:00
jax authors
6c45969fe4
Integrate LLVM at llvm/llvm-project@eb27da7dec
...
Updates LLVM usage to match
[eb27da7dec67](https://github.com/llvm/llvm-project/commit/eb27da7dec67 )
PiperOrigin-RevId: 432199388
2022-03-03 08:24:39 -08:00
Jake VanderPlas
3403054b33
Fix typo: JAX_CUSPARSE_11030 -> JAX_CUSPARSE_11300
...
This is a silly typo, but it's been annoying me for months
PiperOrigin-RevId: 432078590
2022-03-02 18:32:21 -08:00
Reza Rahimi
a0d9d81f92
Update JAX to use new math libraries in ROCm-5.0.
2022-03-01 20:02:15 +00:00
Yash Katariya
2162868ed9
Update values after release
...
PiperOrigin-RevId: 427910510
2022-02-10 20:32:53 -08:00
Yash Katariya
1ad3551ec9
Release jax and jaxlib 0.3.0 as per the new release process.
...
PiperOrigin-RevId: 427809845
2022-02-10 11:59:13 -08:00
Peter Hawkins
6791446bb1
Update development jaxlib version to 0.1.77, update jaxlib version in setup.py to 0.1.76.
...
Changelog entry for jaxlib 0.1.77 was already added in a previous PR.
PiperOrigin-RevId: 424872047
2022-01-28 08:10:58 -08:00
jax authors
727823eaae
Make lapack symbols strong in lapack_kernels
...
The lapack_kernels target has dual use as a jax kernel for the lapack
functions obtained via SciPy when running in a Python context (via pybind),
and as a jax kernel for lapack functions linked directly for use in a
pure C++ context.
The prior solution to this problem was to define the lapack symbols with
the weak attribute to make the linking with lapack optional (not sure why
exactly, since SciPy uses the exact same lapack library). However, this
causes C++ applications to silently forgo the linking with lapack and
simply leave those symbols as null pointers. Whether that happens or
not seems to be dependent on link order and dependency layering. In
short, this solution does not work half of the time, for seemingly
arbitrary reasons.
This is fixed here by adding a separate shim library that lists out
the lapack symbols as strong symbols and initializes the internal
function pointers of the kernels. Linking with this new library pulls
in the correct dependencies reliably. On the python side (with SciPy),
you simply link only with the basic lapack_kernels target.
PiperOrigin-RevId: 424208059
2022-01-25 16:30:30 -08:00
Peter Hawkins
548b9446ef
Suppress memorysanitizer for code that calls LAPACK kernels.
...
PiperOrigin-RevId: 420325456
2022-01-07 10:50:29 -08:00
Peter Hawkins
04369a3588
Drop support for NumPy 1.18.
...
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.
The next NumPy deprecation will be 1.19 on Jun 21, 2022.
PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00
Peter Hawkins
eafaafd624
Add some initial filecheck tests for JAX->MHLO lowering.
...
The coverage of this test suite is not complete, but it's a start.
PiperOrigin-RevId: 415560462
2021-12-10 10:59:24 -08:00
Yash Katariya
1b5630eed6
Update jaxlib version number to 0.1.76
...
PiperOrigin-RevId: 415050863
2021-12-08 11:14:12 -08:00
Cloud Han
317edcdacd
fix mlir capi dll building and linking
2021-11-25 00:07:25 +08:00
Jake VanderPlas
a93c99d7be
[sparse] specify operand layouts in cusparse.py
...
Why? This can fix issues when inputs have non-standard layouts
PiperOrigin-RevId: 411110145
2021-11-19 11:47:38 -08:00
Peter Hawkins
7902ddaca2
Update jaxlib versions.
2021-11-17 11:46:41 -05:00
Jake VanderPlas
11094fa372
[sparse] add dtype assertions to several cusparse wrappers
2021-11-15 14:48:23 -08:00
jax authors
3ac3ec9a83
Merge pull request #8524 from reza-amd:patch-1
...
PiperOrigin-RevId: 409496068
2021-11-12 14:04:53 -08:00
Peter Hawkins
ce7ae6bd76
Make MLIR bindings build work under Bazel.
...
Tested on Linux and Mac, but not Windows.
2021-11-12 12:16:32 -05:00
Reza Rahimi
e511b280d8
fix for translation to cudaDataType
2021-11-11 17:02:59 -08:00
Peter Hawkins
11f6c535ae
Add MLIR:Python bindings to jaxlib build.
...
PiperOrigin-RevId: 407657331
2021-11-04 13:29:58 -07:00
Tianjian Lu
4814d75768
Update coo_matvec and coo_matmat comments.
2021-11-02 14:14:00 -07:00
Yash Katariya
4d8bce1b85
Add a default cuda installation path and more explicit installation paths for CUDA jaxlib.
...
```
# Installs Cuda 11 with Cudnn 8.2
$ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html
$ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
$ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
PiperOrigin-RevId: 404134291
2021-10-18 19:56:22 -07:00
Yash Katariya
93fe3ab492
Replace _
with -
because wheel.py normalizes it to .
...
PiperOrigin-RevId: 404049619
2021-10-18 13:47:43 -07:00
jax authors
8a261f04d5
Merge pull request #8261 from hawkinsp:real
...
PiperOrigin-RevId: 404013628
2021-10-18 11:28:12 -07:00
Peter Hawkins
8c3b212dd6
Improve real type conversion in a couple more places.
2021-10-18 13:50:11 -04:00
Peter Hawkins
051375976a
Remove unused backward compatibility code in cusolver.py.
...
Simplify implementation of _real_type in passing.
2021-10-18 13:27:10 -04:00
Yash Katariya
e6e81ba885
Add Cuda 11.4 with cudnn 8.2 and cudnn 8.0.5 release builds
...
PiperOrigin-RevId: 403661187
2021-10-16 16:13:43 -07:00
Yash Katariya
ac0796048f
Move cuda .py files to :gpu_support so that if :gpu_support is not present, then internal jaxlib will act like a CPU jaxlib even if --config=cuda
is specified.
...
PiperOrigin-RevId: 403170945
2021-10-14 13:20:01 -07:00
Skye Wanderman-Milne
0072c32546
Update CHANGELOG and verson numbers for jaxlib 0.1.72 release
2021-10-12 17:37:29 -07:00
jax authors
90fdfbe8c1
Merge pull request #8033 from SaturdayGenfo:schur-lapack-wrapper
...
PiperOrigin-RevId: 399706351
2021-09-29 09:37:49 -07:00
Leello Tadesse Dadi
f9a246ac19
schur lapack wrapper
2021-09-29 14:29:52 +02:00
Peter Hawkins
2c2f4033cc
Move contents of jax.lib to jax._src.lib.
...
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.
PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
94f97b920f
Refactor JAX CPU kernels to make them usable from C++.
...
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.
This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.
When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.
Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.
PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00
Peter Hawkins
f1f0aefc59
Further split up cuda_prng and cuda_lu_pivot kernels to avoid exposing ABSL code to NVCC.
...
With this change, we are careful not to include any ABSL-including .cc files in cuda_library rules.
PiperOrigin-RevId: 394544751
2021-09-02 14:35:46 -07:00
Peter Hawkins
f004bcb7b8
[JAX] Refactor JAX custom kernels to split kernel implementations from Python bindings.
...
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.
The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.
There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.
PiperOrigin-RevId: 394460343
2021-09-02 07:53:09 -07:00
yashkatariya
be824a792e
Update files after new jaxlib release 0.1.71
2021-09-01 10:43:20 -07:00
Aden Grue
6b00b44807
Move all Abseil dependencies out of jaxlib CUDA libraries
...
These were breaking the build with CUDA 10.2
PiperOrigin-RevId: 391875083
2021-08-19 16:55:37 -07:00
Reza Rahimi
16a110e4ff
fix custom_call_status for rocm
2021-08-16 04:39:47 +00:00
Jake VanderPlas
f8081a9a52
[sparse] fix GPU translation rule for coo/csr matmat
2021-08-10 10:13:29 -07:00
Qiao Zhang
a93eaf3c9e
Use absl::Status::message() instead of error_message().
...
PiperOrigin-RevId: 389810033
2021-08-09 23:44:36 -07:00
Qiao Zhang
2afba31f71
Fix return value for MakeBatchPointers.
...
PiperOrigin-RevId: 389795975
2021-08-09 21:44:46 -07:00