490 Commits

Author SHA1 Message Date
jax authors
8c694fd008 Merge pull request #10752 from cloudhan:fix-windows-copy
PiperOrigin-RevId: 449585033
2022-05-18 14:58:16 -07:00
Cloud Han
49d0c9c891 Refactor build_wheel.py
- hide `r.Rlocation`s
- get rid of platform dependent *.so files copy logic
2022-05-18 20:49:56 +08:00
Cloud Han
d0a9f29db4 Use pyext for all *.so files
Otherwise, `copy_to_jaxlib(r.Rlocation("__main__/jaxlib/cpu_feature_guard.so"))`
results an `AttributeError: 'NoneType' object has no attribute 'endswith'`
It seems that *.so file is not symlinked into sandbox anymore due to
pybind11 rules update.
2022-05-18 19:54:28 +08:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
d56601a896 Include mlir.transforms and mlir.passmanager in jaxlib BUILD.
These increase the binary size of jaxlib only a negligible amount, and allow running passes like canonicalization from Python.
2022-05-16 13:00:22 +00:00
Yash Katariya
46d034baab Add the nightly dev version to __version__ of jaxlib.
PiperOrigin-RevId: 448001375
2022-05-11 08:35:16 -07:00
Peter Hawkins
883cf2b1e9 Refactor custom call building code in jaxlib to use a helper function.
Refactoring only, no functional changes intended.

This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type.

Fixes https://github.com/google/jax/issues/10474

PiperOrigin-RevId: 447076192
2022-05-06 14:51:24 -07:00
Peter Hawkins
08c3c2ec24 Split CUDA and HIP C++ code in jaxlib into separate directories.
PiperOrigin-RevId: 447062506
2022-05-06 13:48:00 -07:00
Peter Hawkins
562e27d72d Merge remaining CUDA and ROCM Python code.
Completes work started in https://github.com/google/jax/pull/10556

PiperOrigin-RevId: 447005344
2022-05-06 09:35:01 -07:00
Peter Hawkins
83dee8f81e Make jaxlib extension libraries Bazel deps of //jaxlib.
Previously we depended on various .so files directly so they were pulled into the jaxlib wheel build, but it seems to work to add the libraries in question to //jaxlib and depend on that in the usual way.

It appears if a py_library() is used as a data-dependency of another rule, Bazel includes any transitive C++ extension deps, and that's what we want.

PiperOrigin-RevId: 446802592
2022-05-05 13:30:06 -07:00
Peter Hawkins
4618f9ce03 Consolidate hip_prng and cuda_prng.
The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.

This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.

PiperOrigin-RevId: 446761784
2022-05-05 10:55:29 -07:00
Peter Hawkins
9fb9e12169 Don't include PTX for older GPU generations.
See: https://github.com/tensorflow/tensorflow/pull/55613

For a CUDA build at head with the default compute capabilities, reduces wheel size from 141MB to 112MB.

Don't redundantly specify default compute capabilities in .bazelrc and in
build.py.
2022-05-02 20:27:37 +00:00
Aart Bik
c1261ccd27 Adds a wrapper to sparse tensor dialect, as part of an
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.

PiperOrigin-RevId: 443438043
2022-04-21 11:48:44 -07:00
jax authors
86c8446c00 Merge pull request #10229 from hyeontaek:transfer-guard-remove-compat-code
PiperOrigin-RevId: 441490830
2022-04-13 08:45:28 -07:00
Yash Katariya
6ba9fb699d Upgrade the bazel version to 5.1.1
PiperOrigin-RevId: 441338363
2022-04-12 17:48:09 -07:00
Hyeontaek Lim
36df8619d7 Bump minimum jaxlib version to 0.3.2 and remove transfer guard compatibility code 2022-04-11 15:33:27 +00:00
Yash Katariya
5fdad0ebf5 Roll forward manylinux2014 builds after fixes.
PiperOrigin-RevId: 440589273
2022-04-09 10:19:44 -07:00
Yash Katariya
e9f95fa5fa Make jaxlib builds manylinux2014 compliant.
PiperOrigin-RevId: 440497401
2022-04-08 18:51:46 -07:00
Yash Katariya
506a85b7ff Make jaxlib builds manylinux2014 compliant.
PiperOrigin-RevId: 440476417
2022-04-08 16:21:56 -07:00
Jake VanderPlas
1f300e729b CI: pin pillow<9.1 to prevent deprecation warnings 2022-04-01 09:23:27 -07:00
Yash Katariya
aa5d6b4a58 Fix the breakage by including --experimental_cc_shared_library as done by TF.
PiperOrigin-RevId: 438746867
2022-03-31 23:07:42 -07:00
Rohit Santhanam
190501bc19 Upgrade ROCm build docker to ROCm version 5.1. 2022-03-31 11:56:38 +00:00
jax authors
cf9a900d78 Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
Peter Hawkins
33aa9286ed Replace instances of std with func to fix breakage after the MLIR std dialect was renamed to the func dialect.
PiperOrigin-RevId: 432222826
2022-03-03 10:15:20 -08:00
jax authors
6c45969fe4 Integrate LLVM at llvm/llvm-project@eb27da7dec
Updates LLVM usage to match
[eb27da7dec67](https://github.com/llvm/llvm-project/commit/eb27da7dec67)

PiperOrigin-RevId: 432199388
2022-03-03 08:24:39 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Peter Hawkins
901d459e0d Add cloudpickle as a test requirement.
We have at least one test that tests pickling JAX objects.
2022-02-16 15:04:56 -05:00
Peter Hawkins
2e0cfe8e42 Update the list of default CUDA capabilities used for wheel builds to match build.py. 2022-02-15 09:23:28 -05:00
Hyeontaek Lim
beaa00c460 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
2022-02-14 13:11:49 -08:00
Peter Hawkins
74506c7dda Rollback of: Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; s...

PiperOrigin-RevId: 427576107
2022-02-09 14:44:45 -08:00
Hyeontaek Lim
b7e1fec250 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 427562278
2022-02-09 13:50:25 -08:00
Peter Hawkins
2388e353da Increase bazel version to 5.0.0 to match TensorFlow
(8871926b0a).
2022-01-28 21:11:02 +00:00
Peter Hawkins
04369a3588 Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00
Peter Hawkins
66823d1392 Include compute capability 8.0 SASS in jaxlib wheels.
Drop compute capability 6.1 to avoid growing the wheel size.

Also fix an unrelated build error due to a gcc warning in boringssl.
2021-12-14 14:27:19 -05:00
Peter Hawkins
ffb7ec1651 Update Bazel to 4.2.1.
Fixes #8573
2021-12-02 09:11:38 -05:00
Cloud Han
317edcdacd fix mlir capi dll building and linking 2021-11-25 00:07:25 +08:00
Peter Hawkins
ce7ae6bd76 Make MLIR bindings build work under Bazel.
Tested on Linux and Mac, but not Windows.
2021-11-12 12:16:32 -05:00
Peter Hawkins
11f6c535ae Add MLIR:Python bindings to jaxlib build.
PiperOrigin-RevId: 407657331
2021-11-04 13:29:58 -07:00
Peter Hawkins
9212d5c83b Print the bazel version from build.py.
Increment the minimum version that build.py checks for to 3.7.2.
2021-11-03 12:12:54 -04:00
Yash Katariya
ac0796048f Move cuda .py files to :gpu_support so that if :gpu_support is not present, then internal jaxlib will act like a CPU jaxlib even if --config=cuda is specified.
PiperOrigin-RevId: 403170945
2021-10-14 13:20:01 -07:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Skye Wanderman-Milne
2fcf3f7270 Remove .[minimum-jaxlib] from test-requirements.txt
This means that jax and its dependencies (e.g. jaxlib) must be
manually installed before running the tests. This is useful for
testing an existing jax install, e.g. a later version of jaxlib, GPU
jaxlib, etc.
2021-09-23 12:24:24 -07:00
brett koonce
9c5009efd5 tweak cuda/rocm targets
Closes #7955.
2021-09-18 14:29:13 -05:00
yashkatariya
d0acd9f343 Add flags to configure the cuda_compute_capability and rocm_amd_targets 2021-09-17 08:43:25 -07:00
Peter Hawkins
94f97b920f Refactor JAX CPU kernels to make them usable from C++.
Most of the work here is porting the LAPACK interface from Cython to plain C++. This is something I wanted to do anyway to make use of C++ templating facilities: the code is noticeably shorter in C++.

This change removes the only use of Cython in JAX. It also removes the need for a build-time dependency on Scipy, which we only needed for Cython cimport reasons.

When using C++, we most likely do not want to fetch LAPACK and BLAS kernels from Python. Therefore we add another option: we define the LAPACK functions we need using weak symbols where supported; the user can then simply link against LAPACK to provide the necessary symbols.

Added a jaxlib:cpu_kernels module to facilitate using the JAX CPU kernels from C++.

PiperOrigin-RevId: 394705605
2021-09-03 10:03:54 -07:00
Peter Hawkins
f004bcb7b8 [JAX] Refactor JAX custom kernels to split kernel implementations from Python bindings.
Some folks want to be able to run JAX-generated HLO computations from C++, and those computations may refer to JAX's custom kernels. This change splits the custom kernels into separate modules that may be used independently of Python.

The general pattern is that each extension now has two parts:
* xyz_kernels.{cc, h} — the C++ parts
* xyz.cc — Python bindings around the C++ parts, including code to build any descriptor objects.

There's also a new (minimally supported) module named "gpu_kernels.cc" which registers JAX's GPU kernels with the XLA C++ custom kernel registry.

PiperOrigin-RevId: 394460343
2021-09-02 07:53:09 -07:00
Yash Katariya
a08a9ad42f Delete the generate_release_indexes file.
PiperOrigin-RevId: 394081682
2021-08-31 14:07:07 -07:00
Jake VanderPlas
a5b6a4e6a9 CI: remove flake8 from test requirements. 2021-08-25 11:07:09 -07:00
dependabot[bot]
9f2863c66b Copybara import of the project:
--
57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f by dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>:

Bump flatbuffers from 1.12 to 2.0

Bumps [flatbuffers](https://github.com/google/flatbuffers) from 1.12 to 2.0.
- [Release notes](https://github.com/google/flatbuffers/releases)
- [Commits](https://github.com/google/flatbuffers/compare/v1.12.0...v2.0.0)

---
updated-dependencies:
- dependency-name: flatbuffers
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7686 from google:dependabot/pip/flatbuffers-2.0 57572d861a8bfe42a3b34b19a6e25a0b7ea4f22f
PiperOrigin-RevId: 392097862
2021-08-20 17:13:26 -07:00