406 Commits

Author SHA1 Message Date
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Eugene Burmako
f337c00ed5 Remove *_mhlo compatibility shims from jaxlib
We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted.

PiperOrigin-RevId: 510820007
2023-02-19 09:03:14 -08:00
Peter Hawkins
f7734fd6a4 Limit visibility of Bazel target jax:global_device_array.
PiperOrigin-RevId: 510521459
2023-02-17 14:30:05 -08:00
Jake VanderPlas
936e4ae101 Add new argument to jax_test rule
PiperOrigin-RevId: 509952902
2023-02-15 15:45:47 -08:00
Jake VanderPlas
ddae1d00ea fix change to csr lowering rule 2023-02-13 08:39:05 -08:00
Jake VanderPlas
de8a77a3eb [sparse] implement BCSR.__matmul__ 2023-02-10 16:11:57 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
George Necula
7d452adfd3 Add support for dynamic shapes to GPU threefry2x32 custom call.
In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.

PiperOrigin-RevId: 497945778
2022-12-27 04:48:26 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Eugene Burmako
ee1ad39dd1 Port type inference for 6 ops from StableHLO to MHLO
Ops:
  1) AfterAllOp: https://github.com/openxla/stablehlo/pull/708.
  2) CreateTokenOp: https://github.com/openxla/stablehlo/pull/711.
  3) DynamicUpdateSliceOp: https://github.com/openxla/stablehlo/pull/686 and https://github.com/openxla/stablehlo/pull/757.
  4) OptimizationBarrierOp: https://github.com/openxla/stablehlo/pull/575.
  5) OutfeedOp: https://github.com/openxla/stablehlo/pull/713.
  6) SendOp: https://github.com/openxla/stablehlo/pull/580.

This PR prepares for migration from producing MHLO to producing StableHLO by
aligning type inference between dialects, so that switching from one to another
doesn't need changes to calls to Python builders.

PiperOrigin-RevId: 495404149
2022-12-14 13:38:26 -08:00
George Necula
ac7740513d Raise error for unsupported shape polymorphism for custom call and fallback lowering 2022-12-14 12:31:18 +01:00
Peter Hawkins
e835739eda Remove an unnecessary include/ from pybind11 include paths.
PiperOrigin-RevId: 492016679
2022-11-30 14:20:02 -08:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
Yash Katariya
3e5a5053f4 Run GPU presubmits via bazel test on the RBE cluster. This speeds up the build + testing significantly (upto 10x).
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.

Also make GPU presubmits blocking for JAX (re-enabled it).

PiperOrigin-RevId: 491647775
2022-11-29 08:45:58 -08:00
Qiao Zhang
c54bc90bf4 Fix cudnn_header OSS BUILD dep.
PiperOrigin-RevId: 491465703
2022-11-28 15:58:55 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
Yash Katariya
a4e8df76ab Use the remote_gpu tag which is inserted by TF's workspace2 when REMOTE_GPU_TESTING=1
PiperOrigin-RevId: 490553133
2022-11-23 11:50:50 -08:00
Yash Katariya
8e270575f8 Set tf_exec_properties on OSS tests to use TF's gpu pool in the RBE cluster.
PiperOrigin-RevId: 490542399
2022-11-23 11:00:53 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Parker Schuh
0324cac888 Remove unused potrf kernels.
PiperOrigin-RevId: 489322021
2022-11-17 15:22:13 -08:00
Jake VanderPlas
e7f4fe043e jaxlib: fix mlir_hlo build rule 2022-11-16 15:42:05 -08:00
Parker Schuh
7635df84f0 Remove custom potrf kernels in favor of native XLA cholesky support.
PiperOrigin-RevId: 488525158
2022-11-14 18:45:25 -08:00
Rahul Batra
31d8f62826 Sytrd solver and SytrdDescriptor should NOT be CUDA only 2022-11-11 22:41:51 +00:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
46368e4e73 [sparse] Update the guard of cusparse SpMM and SpMv algorithms to cusparse version 11.7.1 onwards.
PiperOrigin-RevId: 486051658
2022-11-03 21:39:52 -07:00
Tianjian Lu
ef0f64ec5c [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 485441349
2022-11-01 16:01:50 -07:00
Jake VanderPlas
06c1d8efb5 Rollback of:
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.

Still breaks CUDA 11.1

PiperOrigin-RevId: 485151807
2022-10-31 14:38:47 -07:00
Tianjian Lu
66e75edd0b [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Peter Hawkins
0814770601 Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Tianjian Lu
e219d55c36 Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1
PiperOrigin-RevId: 482897448
2022-10-21 15:06:17 -07:00
Tianjian Lu
7093142f61 [sparse] Update the default cuSparse matvec algorithm in jaxlib.
PiperOrigin-RevId: 482553550
2022-10-20 11:49:09 -07:00
Peter Hawkins
5617a02fa4 Remove JAX custom call implementation of batched triangular solve.
XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation.

PiperOrigin-RevId: 482031830
2022-10-18 15:04:14 -07:00
jax authors
c848efa11b Merge pull request #12808 from hawkinsp:py311
PiperOrigin-RevId: 481155690
2022-10-14 08:56:14 -07:00
Peter Hawkins
fb72c38e19 Add Python 3.11 as a compatible Python version. 2022-10-14 14:56:07 +00:00
Peter Hawkins
4988b3117d Drop absl-py as a jaxlib dependency.
absl-py is unused.
2022-10-14 13:57:26 +00:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00
Peter Hawkins
a3a2206d49 Fix compilation failure in lapack kernel under msan.
a_size wasn't defined, but it would only be caught under memory sanitizer.

PiperOrigin-RevId: 480176934
2022-10-10 14:24:59 -07:00
Peter Hawkins
2246887f7b Add input-output aliasing annotations for LAPACK calls on CPU.
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Peter Hawkins
22cd50535b Reapply: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.

PiperOrigin-RevId: 480134780
2022-10-10 11:29:18 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
Artem Belevich
2de91d26b7 Handle FP8 types.
PiperOrigin-RevId: 479148993
2022-10-05 14:48:30 -07:00
Rohit Santhanam
b815ac9d8e [ROCm] Upgrade to ROCm 5.3 and associated enhancements 2022-10-01 04:45:26 -07:00
Jake VanderPlas
6cae54f82d Fix bazel build alias 2022-09-26 15:13:12 -07:00