352 Commits

Author SHA1 Message Date
Jake VanderPlas
13a7034e6a Internal change
PiperOrigin-RevId: 474331907
2022-09-14 10:39:38 -07:00
Tianjian Lu
3243e23aa5 [sparse] Lower batch-mode bcoo_dot_genernal to cusparseSpMM.
PiperOrigin-RevId: 473777597
2022-09-12 10:09:41 -07:00
Skye Wanderman-Milne
031f0b1a10 Add missing Google-internal option to jax_test 2022-09-07 18:31:02 -07:00
jax authors
4847c0929e Init local variables
It's UB to use uninitialized values as arguments.

PiperOrigin-RevId: 471084634
2022-08-30 14:04:25 -07:00
David Dunleavy
a8aa774a57 Use tensorflow/compiler/xla/stream_executor instead of tensorflow/stream_executor
PiperOrigin-RevId: 470804752
2022-08-29 13:46:20 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Eugene Burmako
2186268ec7 Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.

This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.

C++:
  1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
  2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
  3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
  4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.

Python:
  5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
  6) Python imports don't change.
  7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
  8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:35:23 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Gordian Edenhofer
024ae47e79 Switch from pocketfft to ducc
All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
2022-08-26 13:36:25 +00:00
Rohit Santhanam
82adc6a1d0 [ROCm] Enhance hipsparse to reach parity with cusparse based on commit d37b711dd4. 2022-08-21 22:46:00 +00:00
Tianjian Lu
d37b711dd4 [sparse] Add batch count and batch stride to matrix descriptors.
PiperOrigin-RevId: 468760351
2022-08-19 12:26:17 -07:00
Peter Hawkins
0839958459 Be more selective about which MLIR pieces we build.
Reduces the size of the installed jaxlib by around 20MB.
2022-08-18 22:16:41 +00:00
jax authors
840c96692e Internal change
PiperOrigin-RevId: 468509799
2022-08-18 11:39:07 -07:00
jax authors
d933c8c427 Merge pull request #11978 from hawkinsp:import
PiperOrigin-RevId: 468478116
2022-08-18 09:34:57 -07:00
Deniz Oktay
d5de596d17 Sparse direct solver via QR factorization CUDA implementation.
PiperOrigin-RevId: 468467698
2022-08-18 08:46:25 -07:00
Peter Hawkins
1e241dcf16 Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
2022-08-18 15:22:49 +00:00
Peter Hawkins
3bb0030014 Revert: Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
Reason: Breaks JAX tests.
PiperOrigin-RevId: 468346430
2022-08-17 18:54:29 -07:00
Deniz Oktay
2bc3e39cd9 Sparse direct solver using QR factorization from cuSOLVER. This is the jaxlib implementation. We will want to combine this with the sparse libraries already existing in JAX.
PiperOrigin-RevId: 468303019
2022-08-17 15:10:27 -07:00
Peter Hawkins
5b0686f9ea Include ABI tag in jaxlib wheels.
Currently JAX wheels end up with names like:
jaxlib-0.3.15-cp39-none-manylinux2014_x86_64.whl

This PR changes the wheel names to:
jaxlib-0.3.15-cp39-cp39-manylinux2014_x86_64.whl

i.e., we include the CPython ABI tag. This simply reflects the status
quo in the wheel name, and does not change what jaxlib needs.
2022-08-17 15:15:46 +00:00
Mehdi Amini
ae6e0e0950 Move tensorflow/core/platform/{default, google, windows} to tensorflow/tsl/platform/...
PiperOrigin-RevId: 468025286
2022-08-16 14:33:22 -07:00
Yash Katariya
8a1b4785de Use the same jaxlib package name for nightlies. The __version__ will still contain the dev version (with datetime string in it).
PiperOrigin-RevId: 466534455
2022-08-09 18:53:36 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Peter Hawkins
b865111996 Refactor BUILD files to avoid individually naming Python dependencies.
Add a parametric py_deps() macro for adding Python package dependencies for Bazel rules.

Fix build failure with dangling matplotlib reference.

PiperOrigin-RevId: 465562141
2022-08-05 07:49:20 -07:00
Yash Katariya
f0b6478b3e Plumb env through jax_test.
PiperOrigin-RevId: 465473378
2022-08-04 21:05:28 -07:00
Tianjian Lu
07da502323 [sparse] Enable batch mode of COO matmat from cusparse kernels.
PiperOrigin-RevId: 465405490
2022-08-04 14:30:02 -07:00
Mehdi Amini
5a6cb438e8 Move MHLO to XLA
As part of the OpenXLA project, we're splitting XLA outside of TensorFlow.
MHLO belongs to OpenXLA and we're relocating it nested under XLA to allow the
split. Some further directory layout change will likely happen over time.

PiperOrigin-RevId: 464126676
2022-07-29 11:54:51 -07:00
Yash Katariya
9a5af235da Delete sharded_jit
PiperOrigin-RevId: 464081692
2022-07-29 08:19:52 -07:00
Parker Schuh
d8f0099f68 _mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Benjamin Kramer
2c72858928 Integrate LLVM at llvm/llvm-project@8aff88fd3a
Updates LLVM usage to match
[8aff88fd3a5f](https://github.com/llvm/llvm-project/commit/8aff88fd3a5f)

PiperOrigin-RevId: 461889195
2022-07-19 08:31:24 -07:00
Tianjian Lu
b421e24bb0 [sparse] Update _validate_coo_mhlo in gpu_sparse.
PiperOrigin-RevId: 461111317
2022-07-14 20:35:09 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
a48f4e116e Change Bazel test rules to generate per-backend test suites. 2022-07-08 14:19:05 +00:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
Peter Hawkins
02534bdff1 Move MLIR dependencies onto //jaxlib rule instead of wheel build rule.
Change in preparation for allowing testing with Bazel.

PiperOrigin-RevId: 458460128
2022-07-01 06:54:31 -07:00
jax authors
35f5d0fa06 Merge pull request #11266 from ROCmSoftwarePlatform:rocm_fixes
PiperOrigin-RevId: 458043701
2022-06-29 12:59:58 -07:00
Peter Hawkins
1e171ccd10 Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
2022-06-29 12:51:01 -07:00
jax authors
eb0052bdf2 Merge pull request #11296 from rsuderman:AddMLProgram
PiperOrigin-RevId: 458013593
2022-06-29 10:57:16 -07:00
Rohit Santhanam
080cf47002 [ROCm] Fixes for compilation failures caused by compiler changes in ROCm Tensorflow fork. 2022-06-29 14:34:08 +00:00
Sharad Vikram
fcf65ac64e Bump minimum jaxlib version to 0.3.10 2022-06-28 15:39:21 -07:00
jax authors
10320cbcc2 Merge pull request #11294 from sharadmv:release
PiperOrigin-RevId: 457828893
2022-06-28 15:12:14 -07:00
Sharad Vikram
1daea700f2 Bump JAX/Jaxlib versions 2022-06-28 14:36:47 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Robert Suderman
499a4e733c Expose ml_program dialect for MLIR builder
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
2022-06-28 20:29:41 +00:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
Peter Hawkins
a560a29e12 Increase the minimum scipy version to 1.5.
We don't have a formal support policy for scipy versions, but 1.5 dates from around the same date as the oldest supported NumPy release NEP-29 would have us support (1.20).
2022-06-24 15:07:09 -04:00