364 Commits

Author SHA1 Message Date
Peter Hawkins
a3a2206d49 Fix compilation failure in lapack kernel under msan.
a_size wasn't defined, but it would only be caught under memory sanitizer.

PiperOrigin-RevId: 480176934
2022-10-10 14:24:59 -07:00
Peter Hawkins
2246887f7b Add input-output aliasing annotations for LAPACK calls on CPU.
PiperOrigin-RevId: 480156067
2022-10-10 12:57:29 -07:00
Peter Hawkins
22cd50535b Reapply: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version.

PiperOrigin-RevId: 480134780
2022-10-10 11:29:18 -07:00
Peter Hawkins
2693afa263 Revert: Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them.

PiperOrigin-RevId: 479670565
2022-10-07 14:36:24 -07:00
Peter Hawkins
93b839ace4 Use input-output aliasing for jaxlib GPU custom calls.
Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need.

PiperOrigin-RevId: 479642543
2022-10-07 12:22:04 -07:00
Artem Belevich
2de91d26b7 Handle FP8 types.
PiperOrigin-RevId: 479148993
2022-10-05 14:48:30 -07:00
Rohit Santhanam
b815ac9d8e [ROCm] Upgrade to ROCm 5.3 and associated enhancements 2022-10-01 04:45:26 -07:00
Jake VanderPlas
6cae54f82d Fix bazel build alias 2022-09-26 15:13:12 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
fd90f40c45 Merge pull request #12443 from cloudhan:fix-mlir-chlo-stablehlo-symbols
PiperOrigin-RevId: 475808753
2022-09-21 06:12:44 -07:00
Cloud Han
3fa2c933f4 Fix linker error due to chlo and stablehol symbols are not exported in mlir dll 2022-09-21 17:26:21 +08:00
Eugene Burmako
1338864c1f Change TensorFlow to depend on StableHLO instead of vendoring it
This makes handling TensorFlow's dependency on StableHLO consistent with handling other TensorFlow's dependencies. For example, LLVM goes into //third_party/llvm, and so should StableHLO.

Users of tensorflow/tensorflow (e.g. JAX) need to change Bazel builds, replacing `@org_tensorflow//tensorflow/compiler/xla/mlir_hlo/stablehlo` with `@stablehlo//`. Nothing else changes, e.g. C++ includes, C++ usage, Python bindings and Python usage all stay the same. Example: https://github.com/google/jax/pull/12174.

Users of tensorflow/mlir-hlo are unaffected thanks to the awesome power of Copybara. There are minor changes in the StableHLO part of MLIR-HLO caused by the fact that the StableHLO repository and the vendored StableHLO inside tensorflow/tensorflow have diverged a little bit (e.g. Markdown formatting is slightly different between repositories because I didn't have the time to propagate these changes) and now they have been forced to converge, but these changes won't affect the behavior of neither CMake nor Bazel builds of MLIR-HLO.

Moving forward, contributions to StableHLO will only be possible through openxla/stablehlo. This is because tensorflow/tensorflow no longer vendors StableHLO. (tensorflow/mlir-hlo still does, but it's readonly).

PiperOrigin-RevId: 474360128
2022-09-14 12:25:55 -07:00
Jake VanderPlas
13a7034e6a Internal change
PiperOrigin-RevId: 474331907
2022-09-14 10:39:38 -07:00
Tianjian Lu
3243e23aa5 [sparse] Lower batch-mode bcoo_dot_genernal to cusparseSpMM.
PiperOrigin-RevId: 473777597
2022-09-12 10:09:41 -07:00
Skye Wanderman-Milne
031f0b1a10 Add missing Google-internal option to jax_test 2022-09-07 18:31:02 -07:00
jax authors
4847c0929e Init local variables
It's UB to use uninitialized values as arguments.

PiperOrigin-RevId: 471084634
2022-08-30 14:04:25 -07:00
David Dunleavy
a8aa774a57 Use tensorflow/compiler/xla/stream_executor instead of tensorflow/stream_executor
PiperOrigin-RevId: 470804752
2022-08-29 13:46:20 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Eugene Burmako
2186268ec7 Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.

This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.

C++:
  1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
  2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
  3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
  4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.

Python:
  5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
  6) Python imports don't change.
  7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
  8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:35:23 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Gordian Edenhofer
024ae47e79 Switch from pocketfft to ducc
All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
2022-08-26 13:36:25 +00:00
Rohit Santhanam
82adc6a1d0 [ROCm] Enhance hipsparse to reach parity with cusparse based on commit d37b711dd4. 2022-08-21 22:46:00 +00:00
Tianjian Lu
d37b711dd4 [sparse] Add batch count and batch stride to matrix descriptors.
PiperOrigin-RevId: 468760351
2022-08-19 12:26:17 -07:00
Peter Hawkins
0839958459 Be more selective about which MLIR pieces we build.
Reduces the size of the installed jaxlib by around 20MB.
2022-08-18 22:16:41 +00:00
jax authors
840c96692e Internal change
PiperOrigin-RevId: 468509799
2022-08-18 11:39:07 -07:00
jax authors
d933c8c427 Merge pull request #11978 from hawkinsp:import
PiperOrigin-RevId: 468478116
2022-08-18 09:34:57 -07:00
Deniz Oktay
d5de596d17 Sparse direct solver via QR factorization CUDA implementation.
PiperOrigin-RevId: 468467698
2022-08-18 08:46:25 -07:00
Peter Hawkins
1e241dcf16 Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
2022-08-18 15:22:49 +00:00
Peter Hawkins
3bb0030014 Revert: Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
Reason: Breaks JAX tests.
PiperOrigin-RevId: 468346430
2022-08-17 18:54:29 -07:00
Deniz Oktay
2bc3e39cd9 Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
PiperOrigin-RevId: 468303019
2022-08-17 15:10:27 -07:00
Peter Hawkins
5b0686f9ea Include ABI tag in jaxlib wheels.
Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl

This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl

i.e., we include the CPython ABI tag. This simply reflects the status
quo in the wheel name, and does not change what jaxlib needs.
2022-08-17 15:15:46 +00:00
Mehdi Amini
ae6e0e0950 Move tensorflow/core/platform/{default, google, windows} to tensorflow/tsl/platform/...
PiperOrigin-RevId: 468025286
2022-08-16 14:33:22 -07:00
Yash Katariya
8a1b4785de Use the same jaxlib package name for nightlies. The __version__ will still contain the dev version (with datetime string in it).
PiperOrigin-RevId: 466534455
2022-08-09 18:53:36 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Yash Katariya
f0b6478b3e Plumb env through jax_test.
PiperOrigin-RevId: 465473378
2022-08-04 21:05:28 -07:00
Tianjian Lu
07da502323 [sparse] Enable batch mode of COO matmat from cusparse kernels.
PiperOrigin-RevId: 465405490
2022-08-04 14:30:02 -07:00
Mehdi Amini
5a6cb438e8 Move MHLO to XLA
As part of the OpenXLA project, we're splitting XLA outside of TensorFlow.
MHLO belongs to OpenXLA and we're relocating it nested under XLA to allow the
split. Some further directory layout change will likely happen over time.

PiperOrigin-RevId: 464126676
2022-07-29 11:54:51 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
Parker Schuh
d8f0099f68 _mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Benjamin Kramer
2c72858928 Integrate LLVM at llvm/llvm-project@8aff88fd3a
Updates LLVM usage to match
[8aff88fd3a5f](https://github.com/llvm/llvm-project/commit/8aff88fd3a5f)

PiperOrigin-RevId: 461889195
2022-07-19 08:31:24 -07:00
Tianjian Lu
b421e24bb0 [sparse] Update _validate_coo_mhlo in gpu_sparse.
PiperOrigin-RevId: 461111317
2022-07-14 20:35:09 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
a48f4e116e Change Bazel test rules to generate per-backend test suites. 2022-07-08 14:19:05 +00:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
Peter Hawkins
02534bdff1 Move MLIR dependencies onto //jaxlib rule instead of wheel build rule.
Change in preparation for allowing testing with Bazel.

PiperOrigin-RevId: 458460128
2022-07-01 06:54:31 -07:00