357 Commits

Author SHA1 Message Date
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Yash Katariya
b38e85b3a4 Package utils.cc properly in jaxlib so that if jaxlib nightly is installed and then used, jaxlib_utils can be accessed.
PiperOrigin-RevId: 523374835
2023-04-11 05:38:30 -07:00
Rahul Batra
13e45c8953 [ROCm]: Run pmap test on specific number of GPUs 2023-03-30 18:34:47 +00:00
jax authors
6715736583 Merge pull request #15205 from yhtang:editable-jaxlib-build
PiperOrigin-RevId: 519704474
2023-03-27 06:33:31 -07:00
Yu-Hang 'Maxin' Tang
caaa0a2669 add build option to create editable jaxlib
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
2023-03-24 21:25:26 +00:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
jax authors
edff87eb07 Merge pull request #13613 from ROCmSoftwarePlatform:rocm_rt_build
PiperOrigin-RevId: 510440289
2023-02-17 08:40:28 -08:00
Chao
0dde7a0fb1
Update Dockerfile.ms
update to ROCm5.4
2023-02-17 14:33:33 +00:00
Stella Laurenzo
c1e13bdf3f A few developer workflow enhancements for working with jaxlib.
It seems to me that jaxlib development must be mostly happening on CI, because some basics are pretty essential. Here are a few things I've been typing/carrying for a while in my flow:

* Add .bazelrc.user to .gitignore so it doesn't accidentally get checked in.
* Add configs for 'debug_symbols' and 'debug' that make some things minimally workable under a debugger (or to get backtraces, etc).
* Add `--force-reinstall` to the copy/paste command to update a built jaxlib wheel (without this, if you are iterating, it fairly quietly does nothing).
2023-02-10 21:03:21 -08:00
Rahul Batra
023226e181 [ROCm]: Move dockerfile to ROCm5.4 2023-02-09 20:08:35 +00:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Chao Chen
a2c9fc02e4 jax-rocm runtime/ci dockerfile multistages 2022-12-12 07:45:12 -08:00
Jake VanderPlas
e7f53479e2 Some cleanups related to dropping Python 3.7 2022-11-29 15:54:49 -08:00
Qiao Zhang
4d1c4bc761 Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 491445515
2022-11-28 14:31:48 -08:00
Peter Hawkins
41c90a838e Add missing stablehlo dialect files to jaxlib build.
Unbreaks the build.
2022-11-09 13:37:49 +00:00
Yash Katariya
cf6b5097d0 Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Rohit Santhanam
680096cf04 [ROCm] Fix ROCm dockerfile to remove 2020-resolver opt in.
The 2020-resolver opt in has been removed in pip 22.3 since it has now become the default.
2022-10-19 18:11:04 +00:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

 The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
Chao Chen
8c13142ae6 fixed build instructions typo 2022-10-11 07:30:49 -07:00
Jason Furmanek
34f6646050 Add default setting for TENSORFLOW_ROCM_COMMIT 2022-10-07 19:57:53 +00:00
jax authors
1f3048bcf2 Merge pull request #12573 from ROCmSoftwarePlatform:rocm-ci-update
PiperOrigin-RevId: 479369638
2022-10-06 11:33:41 -07:00
Rohit Santhanam
b815ac9d8e [ROCm] Upgrade to ROCm 5.3 and associated enhancements 2022-10-01 04:45:26 -07:00
Jason Furmanek
0783f8982d [ROCM] Add TENSORFLOW_ROCM_COMMIT parameter to ROCM ci build 2022-09-29 12:40:40 +00:00
jax authors
6c47dc51cb Merge pull request #12471 from ROCmSoftwarePlatform:rocm-dockerfile-update
PiperOrigin-RevId: 476387200
2022-09-23 09:16:38 -07:00
jax authors
254dc24a8b Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
2022-09-23 07:29:17 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jason Furmanek
9a11b61829 [ROCM] Update Dockerfil.rocm to Ubuntu20 2022-09-22 14:29:30 -04:00
Sharad Vikram
2d8b228706 Add function to visualize Shardings 2022-09-19 13:27:08 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Eugene Burmako
2186268ec7 Migrate from MLIR-HLO's CHLO to StableHLO's CHLO
Unlike StableHLO which is meant to coexist with MHLO, StableHLO's CHLO is meant to replace MLIR-HLO's CHLO.

This change is the final step towards enabling adoption of StableHLO. If we keep two copies of CHLO, then some users won't be able to depend on both MLIR-HLO and StableHLO, and that is a useful possibility to enable both in the short and in the long term.

C++:
  1) C++ dependency changes from `//third_party/tensorflow/compiler/xla/mlir_hlo` (includes CHLO, among other things) to `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:chlo_ops` (in CMake, from `ChloDialect` to `ChloOps`).
  2) .h include changes from `#include "third_party/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"` to `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/ChloOps.h"`.
  3) To register the CHLO dialect in C++, you'll need to depend on `//third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo:register`, include `#include "third_party/tensorflow/compiler/xla/mlir_hlo/stablehlo/stablehlo/dialect/Register.h"` and call `mlir::stablehlo::registerAllDialects(registry)`.
  4) C++ usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.

Python:
  5) Python dependency changes from `//third_party/py/mlir:mhlo_dialect` (includes CHLO, among other things) to `//third_party/py/mlir:chlo_dialect` (in CMake, from `MLIRHLOPythonModules` to `StablehloUnifiedPythonModules`).
  6) Python imports don't change.
  7) To register the CHLO dialect in Python, you'll need to change `chlo.register_chlo_dialect(context)` to `chlo.register_dialect(context)`.
  8) Python usage doesn't change - StableHLO's CHLO is an exact copy of MLIR-HLO's CHLO.
PiperOrigin-RevId: 470265566
2022-08-26 09:35:23 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Gordian Edenhofer
024ae47e79 Switch from pocketfft to ducc
All credit goes to Martin Reinecke <martin@mpa-garching.mpg.de>.
2022-08-26 13:36:25 +00:00
Peter Hawkins
0839958459 Be more selective about which MLIR pieces we build.
Reduces the size of the installed jaxlib by around 20MB.
2022-08-18 22:16:41 +00:00
Jake
21f82c6c0d Use the pjrt plugin device client. 2022-08-17 14:34:07 +01:00
Peter Hawkins
03876bd702 build.py fixes.
* Add aarch64 as a known target_cpu value.
* Only pass --bazel_options to build actions since they can make "bazel
  shutdown" fail.
* Pass the bazel startup options to "bazel shutdown".

Issue https://github.com/google/jax/issues/7097
Fixes https://github.com/google/jax/issues/7639
2022-08-16 15:47:15 +00:00
Jake VanderPlas
f7731c8a29 Tests: require pillow>=9.1.0 & remove backward compatibility 2022-08-12 13:34:56 -07:00
jax authors
e81578a9fa Merge pull request #11780 from ROCmSoftwarePlatform:rocm_update_dockerfile
PiperOrigin-RevId: 466756858
2022-08-10 12:19:13 -07:00
Rohit Santhanam
1b3542427e [ROCm] Update Dockerfile.rocm. 2022-08-09 11:09:10 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Vlad Feinberg
269067e3e8 Make LOBPCG test plots compatible with bazel.
bazel test invocations would previously not work, because the lobpcg_test did not include the appropriate flag parsing and absl test invocations when run as a script. This change fixes that, and in addition shards tests and removes needless and redundant slow tests with larger matrix sizes to make the tests finish in a smaller amount of time. Now, generated pngs with debug information are properly reported via the undeclared outputs directory when the environment variable to emit them, LOBPCG_EMIT_DEBUG_PLOTS, is set to a non-falsy value.

PiperOrigin-RevId: 465465731
2022-08-04 20:05:53 -07:00
Jake VanderPlas
c4169a0c76 make tests compatible with recent pillow versions 2022-07-22 13:09:52 -07:00
Parker Schuh
d8f0099f68 _mlirTransforms merged into _mlirRegisterEverything.
PiperOrigin-RevId: 462233907
2022-07-20 14:43:27 -07:00
Jake VanderPlas
00d8ce6c4a Populate long_description for jax & jaxlib 2022-07-13 14:03:32 -07:00
Peter Hawkins
4443705e0f Add script for parallel accelerator testing under Bazel. 2022-07-06 10:58:04 -04:00