419 Commits

Author SHA1 Message Date
Sharad Vikram
3c3fa042e3 Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Peter Hawkins
b62f114524 Add support for using pip-installed CUDA wheels.
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Anish Tondwalkar
8081031c90 [jaxlib] fix build w/ depenency on stablehlo_serialization
PiperOrigin-RevId: 519120624
2023-03-24 05:42:38 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
8bb90b5fbe [XLA:Python] Change JAX and the XLA Python extension to get NumPy bfloat16/float8 types from ml_dtypes.
PiperOrigin-RevId: 518830467
2023-03-23 05:13:39 -07:00
Yash Katariya
88584290aa Remove GDA tests from JAX since GDA is deprecated. There are jax.Array tests for all the corresponding GDA tests
PiperOrigin-RevId: 516881635
2023-03-15 11:34:57 -07:00
Peter Hawkins
ab45383038 Fix build breakage from OpenXLA switch.
PiperOrigin-RevId: 516325478
2023-03-13 14:37:35 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
e4b154b660 Split basearray into separate Bazel module.
Move the definition of ArrayLike into basearray to avoid a cyclic dependency between array.py and basearray.

PiperOrigin-RevId: 516264828
2023-03-13 11:14:41 -07:00
Peter Hawkins
d58be3d4df Split source_info_util into its own Bazel target.
PiperOrigin-RevId: 515646269
2023-03-10 08:41:06 -08:00
Peter Hawkins
0e05a7987f Split some submodules out of //jax under Bazel.
Add separate BUILD targets
* :version - for version.py
* _src/lib - wrapping the jaxlib shims.
* :util - for util.py
* :config - for config.py

PiperOrigin-RevId: 515307923
2023-03-09 05:27:34 -08:00
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Eugene Burmako
f337c00ed5 Remove *_mhlo compatibility shims from jaxlib
We introduced these shims when migrating from MHLO to StableHLO, and they helped accommodate the version skew between jaxlib and JAX across different environments. Now that a sufficient amount of time has passed, these shims are no longer used anywhere and can be deleted.

PiperOrigin-RevId: 510820007
2023-02-19 09:03:14 -08:00
Peter Hawkins
f7734fd6a4 Limit visibility of Bazel target jax:global_device_array.
PiperOrigin-RevId: 510521459
2023-02-17 14:30:05 -08:00
Jake VanderPlas
936e4ae101 Add new argument to jax_test rule
PiperOrigin-RevId: 509952902
2023-02-15 15:45:47 -08:00
Jake VanderPlas
ddae1d00ea fix change to csr lowering rule 2023-02-13 08:39:05 -08:00
Jake VanderPlas
de8a77a3eb [sparse] implement BCSR.__matmul__ 2023-02-10 16:11:57 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
George Necula
7d452adfd3 Add support for dynamic shapes to GPU threefry2x32 custom call.
In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the
value n=-1, and the actual desired output length will be passed as
an additional operand. If the shape is static then the length will be
passed as part of the descriptor.

PiperOrigin-RevId: 497945778
2022-12-27 04:48:26 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Eugene Burmako
ee1ad39dd1 Port type inference for 6 ops from StableHLO to MHLO
Ops:
  1) AfterAllOp: https://github.com/openxla/stablehlo/pull/708.
  2) CreateTokenOp: https://github.com/openxla/stablehlo/pull/711.
  3) DynamicUpdateSliceOp: https://github.com/openxla/stablehlo/pull/686 and https://github.com/openxla/stablehlo/pull/757.
  4) OptimizationBarrierOp: https://github.com/openxla/stablehlo/pull/575.
  5) OutfeedOp: https://github.com/openxla/stablehlo/pull/713.
  6) SendOp: https://github.com/openxla/stablehlo/pull/580.

This PR prepares for migration from producing MHLO to producing StableHLO by
aligning type inference between dialects, so that switching from one to another
doesn't need changes to calls to Python builders.

PiperOrigin-RevId: 495404149
2022-12-14 13:38:26 -08:00
George Necula
ac7740513d Raise error for unsupported shape polymorphism for custom call and fallback lowering 2022-12-14 12:31:18 +01:00
Peter Hawkins
e835739eda Remove an unnecessary include/ from pybind11 include paths.
PiperOrigin-RevId: 492016679
2022-11-30 14:20:02 -08:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
Yash Katariya
3e5a5053f4 Run GPU presubmits via bazel test on the RBE cluster. This speeds up the build + testing significantly (upto 10x).
But run the continuous builds by building on RBE and testing locally so as to run the multiaccelerator tests too. Locally we have 4 GPUs available.

Also make GPU presubmits blocking for JAX (re-enabled it).

PiperOrigin-RevId: 491647775
2022-11-29 08:45:58 -08:00
Qiao Zhang
c54bc90bf4 Fix cudnn_header OSS BUILD dep.
PiperOrigin-RevId: 491465703
2022-11-28 15:58:55 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
Yash Katariya
a4e8df76ab Use the remote_gpu tag which is inserted by TF's workspace2 when REMOTE_GPU_TESTING=1
PiperOrigin-RevId: 490553133
2022-11-23 11:50:50 -08:00
Yash Katariya
8e270575f8 Set tf_exec_properties on OSS tests to use TF's gpu pool in the RBE cluster.
PiperOrigin-RevId: 490542399
2022-11-23 11:00:53 -08:00
jax authors
d1fbdbc1cf Rollback of "Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module."
PiperOrigin-RevId: 490499003
2022-11-23 07:48:05 -08:00
Qiao Zhang
78963b6020 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
2022-11-22 18:53:29 -08:00
Parker Schuh
0324cac888 Remove unused potrf kernels.
PiperOrigin-RevId: 489322021
2022-11-17 15:22:13 -08:00
Jake VanderPlas
e7f4fe043e jaxlib: fix mlir_hlo build rule 2022-11-16 15:42:05 -08:00
Parker Schuh
7635df84f0 Remove custom potrf kernels in favor of native XLA cholesky support.
PiperOrigin-RevId: 488525158
2022-11-14 18:45:25 -08:00
Rahul Batra
31d8f62826 Sytrd solver and SytrdDescriptor should NOT be CUDA only 2022-11-11 22:41:51 +00:00
Peter Hawkins
352b042fe9 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction.
Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility.

Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper.

PiperOrigin-RevId: 487621469
2022-11-10 13:16:21 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Tianjian Lu
46368e4e73 [sparse] Update the guard of cusparse SpMM and SpMv algorithms to cusparse version 11.7.1 onwards.
PiperOrigin-RevId: 486051658
2022-11-03 21:39:52 -07:00
Tianjian Lu
ef0f64ec5c [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 485441349
2022-11-01 16:01:50 -07:00
Jake VanderPlas
06c1d8efb5 Rollback of:
[sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.

Still breaks CUDA 11.1

PiperOrigin-RevId: 485151807
2022-10-31 14:38:47 -07:00
Tianjian Lu
66e75edd0b [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib.
PiperOrigin-RevId: 484351696
2022-10-27 14:34:44 -07:00
Peter Hawkins
0814770601 Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge.
PiperOrigin-RevId: 484026031
2022-10-26 11:40:14 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Tianjian Lu
e219d55c36 Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1
PiperOrigin-RevId: 482897448
2022-10-21 15:06:17 -07:00
Tianjian Lu
7093142f61 [sparse] Update the default cuSparse matvec algorithm in jaxlib.
PiperOrigin-RevId: 482553550
2022-10-20 11:49:09 -07:00