490 Commits

Author SHA1 Message Date
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
Peter Hawkins
41c90a838e Add missing stablehlo dialect files to jaxlib build.
Unbreaks the build.
2022-11-09 13:37:49 +00:00
Yash Katariya
cf6b5097d0 Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Rohit Santhanam
680096cf04 [ROCm] Fix ROCm dockerfile to remove 2020-resolver opt in.
The 2020-resolver opt in has been removed in pip 22.3 since it has now become the default.
2022-10-19 18:11:04 +00:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Chao Chen
8c13142ae6 fixed build instructions typo 2022-10-11 07:30:49 -07:00
Jason Furmanek
34f6646050 Add default setting for TENSORFLOW_ROCM_COMMIT 2022-10-07 19:57:53 +00:00
jax authors
1f3048bcf2 Merge pull request #12573 from ROCmSoftwarePlatform:rocm-ci-update
PiperOrigin-RevId: 479369638
2022-10-06 11:33:41 -07:00
Rohit Santhanam
b815ac9d8e [ROCm] Upgrade to ROCm 5.3 and associated enhancements 2022-10-01 04:45:26 -07:00
Jason Furmanek
0783f8982d [ROCM] Add TENSORFLOW_ROCM_COMMIT parameter to ROCM ci build 2022-09-29 12:40:40 +00:00
jax authors
6c47dc51cb Merge pull request #12471 from ROCmSoftwarePlatform:rocm-dockerfile-update
PiperOrigin-RevId: 476387200
2022-09-23 09:16:38 -07:00
jax authors
254dc24a8b Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jason Furmanek
9a11b61829 [ROCM] Update Dockerfil.rocm to Ubuntu20 2022-09-22 14:29:30 -04:00
Sharad Vikram
2d8b228706 Add function to visualize Shardings 2022-09-19 13:27:08 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Eugene Burmako
2186268ec7 Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.

This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.

C++:
  1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
  2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
  3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
  4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.

Python:
  5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
  6) Python imports don't change.
  7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
  8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:35:23 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Gordian Edenhofer
024ae47e79 Switch from pocketfft to ducc
All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
2022-08-26 13:36:25 +00:00
Peter Hawkins
0839958459 Be more selective about which MLIR pieces we build.
Reduces the size of the installed jaxlib by around 20MB.
2022-08-18 22:16:41 +00:00
Jake
21f82c6c0d Use the pjrt plugin device client. 2022-08-17 14:34:07 +01:00
Peter Hawkins
03876bd702 build.py fixes.
* Add aarch64 as a known target_cpu value.
* Only pass --bazel_options to build actions since they can make "bazel
  shutdown" fail.
* Pass the bazel startup options to "bazel shutdown".

Issue https://github.com/google/jax/issues/7097
Fixes https://github.com/google/jax/issues/7639
2022-08-16 15:47:15 +00:00
Jake VanderPlas
f7731c8a29 Tests: require pillow>=9.1.0 & remove backward compatibility 2022-08-12 13:34:56 -07:00
jax authors
e81578a9fa Merge pull request #11780 from ROCmSoftwarePlatform:rocm_update_dockerfile
PiperOrigin-RevId: 466756858
2022-08-10 12:19:13 -07:00
Rohit Santhanam
1b3542427e [ROCm] Update Dockerfile.rocm. 2022-08-09 11:09:10 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Vlad Feinberg
269067e3e8 Make LOBPCG test plots compatible with bazel.
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value.

PiperOrigin-RevId: 465465731
2022-08-04 20:05:53 -07:00
Jake VanderPlas
c4169a0c76 make tests compatible with recent pillow versions 2022-07-22 13:09:52 -07:00
Parker Schuh
d8f0099f68 _mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
4443705e0f Add script for parallel accelerator testing under Bazel. 2022-07-06 10:58:04 -04:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
Peter Hawkins
02534bdff1 Move MLIR dependencies onto //jaxlib rule instead of wheel build rule.
Change in preparation for allowing testing with Bazel.

PiperOrigin-RevId: 458460128
2022-07-01 06:54:31 -07:00
Peter Hawkins
1e171ccd10 Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
2022-06-29 12:51:01 -07:00
jax authors
eb0052bdf2 Merge pull request #11296 from rsuderman:AddMLProgram
PiperOrigin-RevId: 458013593
2022-06-29 10:57:16 -07:00
Rohit Santhanam
721602ef59 Upgrade ROCm build docker to ROCm version 5.2. 2022-06-28 21:43:07 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Robert Suderman
499a4e733c Expose ml_program dialect for MLIR builder
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
2022-06-28 20:29:41 +00:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
jax authors
a90bde2c54 Merge pull request #11231 from hawkinsp:remotetpu
PiperOrigin-RevId: 457005076
2022-06-24 07:13:16 -07:00
Peter Hawkins
22304eeb2e Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
Jake VanderPlas
617df70135 Unpin numpy to ensure most recent version is tested 2022-06-23 12:23:14 -07:00
Yash Katariya
1908da33af Only initialize GPU backends if they are not already initialized
PiperOrigin-RevId: 456664792
2022-06-22 19:39:52 -07:00
Peter Hawkins
69bda69fb6 Bump minimum Mac OS version to 10.14 (Mojave).
It turns out that the support for C++17 is partial in 10.12, and in particular absl::optional and std::optional are not the same thing under 10.12. Increment to 10.14 which is the lowest version that builds successfully with absl::optional == std::optional.

See: 89cdaed655/absl/base/config.h (L528)
Strictly speaking, we could allow 10.13, but not without updating ABSL in the TF repository to incorporate c86347d4ce which fixes the version detection test to permit 10.13 as well.
2022-06-01 20:32:22 -04:00
Nicholas Junge
7f7358cb58 Fix macOS builds by updating the minimum OSX version target
This commit bumps the minimum macOS version target to 10.12, from 10.9
previously. The reason is that a newer LLVM version, which the XLA
compiler depended on, used `std::shared_mutex`, which is only available
starting at macOS 10.12.

The fix here consists of setting the minimum version target flag that
bazel consumes for the build to 10.12.
2022-05-31 15:14:48 +02:00