328 Commits

Author SHA1 Message Date
jax authors
6c47dc51cb Merge pull request #12471 from ROCmSoftwarePlatform:rocm-dockerfile-update
PiperOrigin-RevId: 476387200
2022-09-23 09:16:38 -07:00
jax authors
254dc24a8b Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jason Furmanek
9a11b61829 [ROCM] Update Dockerfil.rocm to Ubuntu20 2022-09-22 14:29:30 -04:00
Sharad Vikram
2d8b228706 Add function to visualize Shardings 2022-09-19 13:27:08 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Eugene Burmako
2186268ec7 Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.

This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.

C++:
  1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
  2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
  3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
  4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.

Python:
  5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
  6) Python imports don't change.
  7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
  8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:35:23 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Gordian Edenhofer
024ae47e79 Switch from pocketfft to ducc
All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
2022-08-26 13:36:25 +00:00
Peter Hawkins
0839958459 Be more selective about which MLIR pieces we build.
Reduces the size of the installed jaxlib by around 20MB.
2022-08-18 22:16:41 +00:00
Jake
21f82c6c0d Use the pjrt plugin device client. 2022-08-17 14:34:07 +01:00
Peter Hawkins
03876bd702 build.py fixes.
* Add aarch64 as a known target_cpu value.
* Only pass --bazel_options to build actions since they can make "bazel
  shutdown" fail.
* Pass the bazel startup options to "bazel shutdown".

Issue https://github.com/google/jax/issues/7097
Fixes https://github.com/google/jax/issues/7639
2022-08-16 15:47:15 +00:00
Jake VanderPlas
f7731c8a29 Tests: require pillow>=9.1.0 & remove backward compatibility 2022-08-12 13:34:56 -07:00
jax authors
e81578a9fa Merge pull request #11780 from ROCmSoftwarePlatform:rocm_update_dockerfile
PiperOrigin-RevId: 466756858
2022-08-10 12:19:13 -07:00
Rohit Santhanam
1b3542427e [ROCm] Update Dockerfile.rocm. 2022-08-09 11:09:10 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Vlad Feinberg
269067e3e8 Make LOBPCG test plots compatible with bazel.
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value.

PiperOrigin-RevId: 465465731
2022-08-04 20:05:53 -07:00
Jake VanderPlas
c4169a0c76 make tests compatible with recent pillow versions 2022-07-22 13:09:52 -07:00
Parker Schuh
d8f0099f68 _mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
4443705e0f Add script for parallel accelerator testing under Bazel. 2022-07-06 10:58:04 -04:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Peter Hawkins
7c49864fdf Symlink xla_client and xla_extension into jaxlib rather than copying them into place in the wheel build.
Change in preparation for allowing JAX tests to run under Bazel.

Remove code to patch paths in xla_client.py in the wheel build script; the patch is no longer used.

PiperOrigin-RevId: 458522398
2022-07-01 12:31:42 -07:00
Peter Hawkins
e63765a7a6 Use symlink_files() to add version.py to jaxlib, rather than copying it in as part of the wheel assembly process.
Change in preparation for supporting running JAX tests under Bazel. This change allows the Bazel py_library() to see version.py.

Update symlink_files Bazel macro to a newer version.

PiperOrigin-RevId: 458481396
2022-07-01 09:07:03 -07:00
Peter Hawkins
02534bdff1 Move MLIR dependencies onto //jaxlib rule instead of wheel build rule.
Change in preparation for allowing testing with Bazel.

PiperOrigin-RevId: 458460128
2022-07-01 06:54:31 -07:00
Peter Hawkins
1e171ccd10 Unify jax and jaxlib versions.
Currently jax and jaxlib have separate version numbers in the JAX source
tree. It is tedious and confusing to bump both version numbers.

However, there is a simpler way to think of things: it is the source
tree that is versioned using a single version number, and jax/jaxlib
releases are made using that unified source version number.

PiperOrigin-RevId: 458041752
2022-06-29 12:51:01 -07:00
jax authors
eb0052bdf2 Merge pull request #11296 from rsuderman:AddMLProgram
PiperOrigin-RevId: 458013593
2022-06-29 10:57:16 -07:00
Rohit Santhanam
721602ef59 Upgrade ROCm build docker to ROCm version 5.2. 2022-06-28 21:43:07 -07:00
Peter Hawkins
47f2f091bc Reapply: Drop flatbuffers as a Python dependency of JAX.
The crashes on Mac were, as best we can tell, unrelated to this PR.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457819042
2022-06-28 14:25:14 -07:00
Robert Suderman
499a4e733c Expose ml_program dialect for MLIR builder
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
2022-06-28 20:29:41 +00:00
Peter Hawkins
5b576cb03e Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac.

Original description:
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457559793
2022-06-27 13:56:32 -07:00
Peter Hawkins
efefeac450 Drop flatbuffers as a Python dependency of JAX.
Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API.

Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work.

PiperOrigin-RevId: 457460347
2022-06-27 06:14:07 -07:00
jax authors
a90bde2c54 Merge pull request #11231 from hawkinsp:remotetpu
PiperOrigin-RevId: 457005076
2022-06-24 07:13:16 -07:00
Peter Hawkins
22304eeb2e Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
2022-06-23 21:14:52 +00:00
Jake VanderPlas
617df70135 Unpin numpy to ensure most recent version is tested 2022-06-23 12:23:14 -07:00
Yash Katariya
1908da33af Only initialize GPU backends if they are not already initialized
PiperOrigin-RevId: 456664792
2022-06-22 19:39:52 -07:00
Peter Hawkins
69bda69fb6 Bump minimum Mac OS version to 10.14 (Mojave).
It turns out that the support for C++17 is partial in 10.12, and in particular absl::optional and std::optional are not the same thing under 10.12. Increment to 10.14 which is the lowest version that builds successfully with absl::optional == std::optional.

See: 89cdaed655/absl/base/config.h (L528)
Strictly speaking, we could allow 10.13, but not without updating ABSL in the TF repository to incorporate c86347d4ce which fixes the version detection test to permit 10.13 as well.
2022-06-01 20:32:22 -04:00
Nicholas Junge
7f7358cb58 Fix macOS builds by updating the minimum OSX version target
This commit bumps the minimum macOS version target to 10.12, from 10.9
previously. The reason is that a newer LLVM version, which the XLA
compiler depended on, used `std::shared_mutex`, which is only available
starting at macOS 10.12.

The fix here consists of setting the minimum version target flag that
bazel consumes for the build to 10.12.
2022-05-31 15:14:48 +02:00
jax authors
8c694fd008 Merge pull request #10752 from cloudhan:fix-windows-copy
PiperOrigin-RevId: 449585033
2022-05-18 14:58:16 -07:00
Cloud Han
49d0c9c891 Refactor build_wheel.py
- hide `r.Rlocation`s
- get rid of platform dependent *.so files copy logic
2022-05-18 20:49:56 +08:00
Cloud Han
d0a9f29db4 Use pyext for all *.so files
Otherwise, `copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cpu_feature_guard.so"))`
results an `AttributeError: 'NoneType' object has no attribute 'endswith'`
It seems that *.so file is not symlinked into sandbox anymore due to
pybind11 rules update.
2022-05-18 19:54:28 +08:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
d56601a896 Include mlir.transforms and mlir.passmanager in jaxlib BUILD.
These increase the binary size of jaxlib only a negligible amount, and allow running passes like canonicalization from Python.
2022-05-16 13:00:22 +00:00
Yash Katariya
46d034baab Add the nightly dev version to __version__ of jaxlib.
PiperOrigin-RevId: 448001375
2022-05-11 08:35:16 -07:00
Peter Hawkins
883cf2b1e9 Refactor custom call building code in jaxlib to use a helper function.
Refactoring only, no functional changes intended.

This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.

Fixes https://github.com/google/jax/issues/10474

PiperOrigin-RevId: 447076192
2022-05-06 14:51:24 -07:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00
Peter Hawkins
562e27d72d Merge remaining CUDA and ROCM Python code.
Completes work started in https://github.com/google/jax/pull/10556

PiperOrigin-RevId: 447005344
2022-05-06 09:35:01 -07:00
Peter Hawkins
83dee8f81e Make jaxlib extension libraries Bazel deps of //jaxlib.
Previously we depended on various .so files directly so they were pulled into the jaxlib wheel build, but it seems to work to add the libraries in question to //jaxlib and depend on that in the usual way.

It appears if a py_library() is used as a data-dependency of another rule, Bazel includes any transitive C++ extension deps, and that's what we want.

PiperOrigin-RevId: 446802592
2022-05-05 13:30:06 -07:00
Peter Hawkins
4618f9ce03 Consolidate hip_prng and cuda_prng.
The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.

This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.

PiperOrigin-RevId: 446761784
2022-05-05 10:55:29 -07:00
Peter Hawkins
9fb9e12169 Don't include PTX for older GPU generations.
See: https://github.com/tensorflow/tensorflow/pull/55613

For a CUDA build at head with the default compute capabilities, reduces wheel size from 141MB to 112MB.

Don't redundantly specify default compute capabilities in .bazelrc and in
build.py.
2022-05-02 20:27:37 +00:00