381 Commits

Author SHA1 Message Date
Rahul Batra
31d8f62826 Sytrd solver and SytrdDescriptor should NOT be CUDA only 2022-11-11 22:41:51 +00:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
46368e4e73 [sparse] Update the guard of cusparse SpMM and SpMv algorithms to cusparse version 11.7.1 onwards.
PiperOrigin-RevId: 486051658
2022-11-03 21:39:52 -07:00
Tianjian Lu
ef0f64ec5c [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 485441349
2022-11-01 16:01:50 -07:00
Jake VanderPlas
06c1d8efb5 Rollback of:
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.

Still breaks CUDA 11.1

PiperOrigin-RevId: 485151807
2022-10-31 14:38:47 -07:00
Tianjian Lu
66e75edd0b [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Peter Hawkins
0814770601 Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Tianjian Lu
e219d55c36 Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1
PiperOrigin-RevId: 482897448
2022-10-21 15:06:17 -07:00
Tianjian Lu
7093142f61 [sparse] Update the default cuSparse matvec algorithm in jaxlib.
PiperOrigin-RevId: 482553550
2022-10-20 11:49:09 -07:00
Peter Hawkins
5617a02fa4 Remove JAX custom call implementation of batched triangular solve.
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.

PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
jax authors
c848efa11b Merge pull request #12808 from hawkinsp:py311
PiperOrigin-RevId: 481155690
2022-10-14 08:56:14 -07:00
Peter Hawkins
fb72c38e19 Add Python 3.11 as a compatible Python version. 2022-10-14 14:56:07 +00:00
Peter Hawkins
4988b3117d Drop absl-py as a jaxlib dependency.
absl-py is unused.
2022-10-14 13:57:26 +00:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
a3a2206d49 Fix compilation failure in lapack kernel under msan.
a_size wasn't defined, but it would only be caught under memory sanitizer.

PiperOrigin-RevId: 480176934
2022-10-10 14:24:59 -07:00
Peter Hawkins
2246887f7b Add input-output aliasing annotations for LAPACK calls on CPU.
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Peter Hawkins
22cd50535b Reapply: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.

PiperOrigin-RevId: 480134780
2022-10-10 11:29:18 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
Artem Belevich
2de91d26b7 Handle FP8 types.
PiperOrigin-RevId: 479148993
2022-10-05 14:48:30 -07:00
Rohit Santhanam
b815ac9d8e [ROCm] Upgrade to ROCm 5.3 and associated enhancements 2022-10-01 04:45:26 -07:00
Jake VanderPlas
6cae54f82d Fix bazel build alias 2022-09-26 15:13:12 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
fd90f40c45 Merge pull request #12443 from cloudhan:fix-mlir-chlo-stablehlo-symbols
PiperOrigin-RevId: 475808753
2022-09-21 06:12:44 -07:00
Cloud Han
3fa2c933f4 Fix linker error due to chlo and stablehol symbols are not exported in mlir dll 2022-09-21 17:26:21 +08:00
Eugene Burmako
1338864c1f Change TensorFlow to depend on StableHLO instead of vendoring it
This makes handling TensorFlow's dependency on StableHLO consistent with handling other TensorFlow's dependencies. For example, LLVM goes into //third_party/llvm, and so should StableHLO.

Users of tensorflow/tensorflow (e.g. JAX) need to change Bazel builds, replacing `@org_tensorflow//tensorflow/compiler/xla/mlir_hlo/stablehlo` with `@stablehlo//`. Nothing else changes, e.g. C++ includes, C++ usage, Python bindings and Python usage all stay the same. Example: https://github.com/google/jax/pull/12174.

Users of tensorflow/mlir-hlo are unaffected thanks to the awesome power of Copybara. There are minor changes in the StableHLO part of MLIR-HLO caused by the fact that the StableHLO repository and the vendored StableHLO inside tensorflow/tensorflow have diverged a little bit (e.g. Markdown formatting is slightly different between repositories because I didn't have the time to propagate these changes) and now they have been forced to converge, but these changes won't affect the behavior of neither CMake nor Bazel builds of MLIR-HLO.

Moving forward, contributions to StableHLO will only be possible through openxla/stablehlo. This is because tensorflow/tensorflow no longer vendors StableHLO. (tensorflow/mlir-hlo still does, but it's readonly).

PiperOrigin-RevId: 474360128
2022-09-14 12:25:55 -07:00
Jake VanderPlas
13a7034e6a Internal change
PiperOrigin-RevId: 474331907
2022-09-14 10:39:38 -07:00
Tianjian Lu
3243e23aa5 [sparse] Lower batch-mode bcoo_dot_genernal to cusparseSpMM.
PiperOrigin-RevId: 473777597
2022-09-12 10:09:41 -07:00
Skye Wanderman-Milne
031f0b1a10 Add missing Google-internal option to jax_test 2022-09-07 18:31:02 -07:00
jax authors
4847c0929e Init local variables
It's UB to use uninitialized values as arguments.

PiperOrigin-RevId: 471084634
2022-08-30 14:04:25 -07:00
David Dunleavy
a8aa774a57 Use tensorflow/compiler/xla/stream_executor instead of tensorflow/stream_executor
PiperOrigin-RevId: 470804752
2022-08-29 13:46:20 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Eugene Burmako
2186268ec7 Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.

This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.

C++:
  1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
  2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
  3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
  4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.

Python:
  5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
  6) Python imports don't change.
  7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
  8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:35:23 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Gordian Edenhofer
024ae47e79 Switch from pocketfft to ducc
All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
2022-08-26 13:36:25 +00:00
Rohit Santhanam
82adc6a1d0 [ROCm] Enhance hipsparse to reach parity with cusparse based on commit d37b711dd4. 2022-08-21 22:46:00 +00:00
Tianjian Lu
d37b711dd4 [sparse] Add batch count and batch stride to matrix descriptors.
PiperOrigin-RevId: 468760351
2022-08-19 12:26:17 -07:00
Peter Hawkins
0839958459 Be more selective about which MLIR pieces we build.
Reduces the size of the installed jaxlib by around 20MB.
2022-08-18 22:16:41 +00:00
jax authors
840c96692e Internal change
PiperOrigin-RevId: 468509799
2022-08-18 11:39:07 -07:00
jax authors
d933c8c427 Merge pull request #11978 from hawkinsp:import
PiperOrigin-RevId: 468478116
2022-08-18 09:34:57 -07:00
Deniz Oktay
d5de596d17 Sparse direct solver via QR factorization CUDA implementation.
PiperOrigin-RevId: 468467698
2022-08-18 08:46:25 -07:00
Peter Hawkins
1e241dcf16 Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
2022-08-18 15:22:49 +00:00
Peter Hawkins
3bb0030014 Revert: Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
Reason: Breaks JAX tests.
PiperOrigin-RevId: 468346430
2022-08-17 18:54:29 -07:00
Deniz Oktay
2bc3e39cd9 Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
PiperOrigin-RevId: 468303019
2022-08-17 15:10:27 -07:00
Peter Hawkins
5b0686f9ea Include ABI tag in jaxlib wheels.
Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl

This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl

i.e., we include the CPython ABI tag. This simply reflects the status
quo in the wheel name, and does not change what jaxlib needs.
2022-08-17 15:15:46 +00:00
Mehdi Amini
ae6e0e0950 Move tensorflow/core/platform/{default, google, windows} to tensorflow/tsl/platform/...
PiperOrigin-RevId: 468025286
2022-08-16 14:33:22 -07:00
Yash Katariya
8a1b4785de Use the same jaxlib package name for nightlies. The __version__ will still contain the dev version (with datetime string in it).
PiperOrigin-RevId: 466534455
2022-08-09 18:53:36 -07:00