70 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
6c59d72c75 Bump the minimum jaxlib version to 0.3.15. 2022-09-08 16:43:46 -04:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Peter Hawkins
b63801b4db Fixes for PocketFFT->ducc migration.
* Rename modules from pocketfft to ducc.
* Fix up strides at their generation point rather than where they are
  consumed.
2022-08-26 14:30:03 +00:00
Kuangyuan Chen
12f85c8e23 Introduce class PyShardedBuffer that contains a vector of PjRtBuffer and cache it in GDA.
PiperOrigin-RevId: 470024839
2022-08-25 10:32:39 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Kuangyuan Chen
ed9d7c25fb Add PyShardedToken for sharded execution.
This avoids creating too many python objects, which adds overhead.

PiperOrigin-RevId: 468536900
2022-08-18 13:41:50 -07:00
Sharad Vikram
3ec1b1b987 Check if XLA Executable has execute_with_token before using it
PiperOrigin-RevId: 466470801
2022-08-09 14:34:57 -07:00
jax authors
4abf7cab5a Merge pull request #11672 from hawkinsp:visible
PiperOrigin-RevId: 466050997
2022-08-08 08:38:02 -07:00
Peter Hawkins
c3af67ebcc Add configuration options jax_cuda_visible_devices and jax_rocm_visible_devices.
Example usage:
jax.config.update("jax_cuda_visible_devices", "1,3")
2022-07-29 17:51:49 -04:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
Peter Hawkins
41b015ab0c Remove stale code from jax/_src/lib/__init__.py
Remove inaccurate/stale __all__.
Remove unused alias _xla_extension_version.
2022-07-08 11:09:58 -04:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
jax authors
fb7e39b13e Merge pull request #11390 from hawkinsp:distributed_init
PiperOrigin-RevId: 459518348
2022-07-07 08:23:26 -07:00
Peter Hawkins
bdbdecd458 Refactor distributed GPU device initialization.
Avoid reregistering backend factories; instead simply have the usual
factory function support distributed GPU.
2022-07-07 00:45:19 +00:00
Robert Suderman
45046857f6 Fix ModuleNotFoundError for phawkins only with version 2022-06-28 22:42:45 +00:00
Robert Suderman
64aaeb2da9 Make ml_program import conditional 2022-06-28 20:43:50 +00:00
Robert Suderman
499a4e733c Expose ml_program dialect for MLIR builder
We now have an ml_program dialect that describes global variables
including load and store operations. Expose this dialect to allow
exporting variables and constants.
2022-06-28 20:29:41 +00:00
Shiva Shahrokhi
df8c6263de Change JAX_PLATFORMS to raise an exception when platform initialization fails 2022-06-24 21:54:53 +00:00
jax authors
e42c39ec0b Merge pull request #10910 from hawkinsp:macarm
PiperOrigin-RevId: 452308003
2022-06-01 09:04:52 -07:00
Peter Hawkins
07d112f08a Don't warn when falling back to CPU on Mac.
On Mac, we don't support anything else, so the warning is pointless.
2022-05-31 21:33:42 -04:00
Peter Hawkins
d546503c74 Remove experimental warning from Mac ARM wheels.
Enough JAX developers now have Mac ARM machines that we can be reasonably confident of noticing problems, even if we don't yet have CI coverage.
2022-05-31 21:00:03 -04:00
Tom Hennigan
b308874880 Only register backend factories once when jax_platforms config value is set.
PiperOrigin-RevId: 450005138
2022-05-20 09:46:37 -07:00
Peter Hawkins
dfb32fc047 [JAX] Do not build the GPU or TPU extensions if not enabled in the JAX build.
There's no point building extensions that cannot be used, and they sometimes cause problems (e.g., TPU support on Windows builds).

Fixes https://github.com/google/jax/issues/10687

PiperOrigin-RevId: 449287997
2022-05-17 12:32:03 -07:00
Hyojun Kim
bc5a5e17a5 Add jax_xla_profile_version configuration
A new config named jax_xla_profile_version is added
to support XLA compilation profile.

PiperOrigin-RevId: 449276852
2022-05-17 11:44:58 -07:00
Peter Hawkins
337ec47d13 Fix jax 0.3.11 GPU breakge when used with jaxlib 0.3.10. 2022-05-16 00:24:04 +00:00
Peter Hawkins
562e27d72d Merge remaining CUDA and ROCM Python code.
Completes work started in https://github.com/google/jax/pull/10556

PiperOrigin-RevId: 447005344
2022-05-06 09:35:01 -07:00
Peter Hawkins
4618f9ce03 Consolidate hip_prng and cuda_prng.
The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it.

This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP.

PiperOrigin-RevId: 446761784
2022-05-05 10:55:29 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Peter Hawkins
5c4636c983 [JAX] Validate that platforms passed to MHLO lowering are known to exist.
In at least one instance a user was passing a XLA client object rather than the name of a platform.

PiperOrigin-RevId: 445510282
2022-04-29 14:44:32 -07:00
Aart Bik
c1261ccd27 Adds a wrapper to sparse tensor dialect, as part of an
an initial prototype of an alternate JAX compilation path
that emits the MLIR MHLO/CHLO dialects instead of classic XLA HLO
together with sparse tensor types.

PiperOrigin-RevId: 443438043
2022-04-21 11:48:44 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Peter Hawkins
2a68e7c975 Add libstdc++ workaround for conda users. 2022-04-13 14:38:09 -04:00
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
Hyeontaek Lim
36df8619d7 Bump minimum jaxlib version to 0.3.2 and remove transfer guard compatibility code 2022-04-11 15:33:27 +00:00
Rohit Santhanam
6c560b14a7 Consolidation of hipsolver/cusolver APIs. 2022-04-07 01:46:43 +00:00
jax authors
8c3385c542 Expose AutoSharding's mesh_shape and mesh_ids options to JAX.
PiperOrigin-RevId: 438874347
2022-04-01 11:47:56 -07:00
Reza Rahimi
8cd02946b5 Fix for hipsparse in ROCm. 2022-03-25 17:41:42 +00:00
jax authors
4848c75b88 Adds use_auto_spmd_partitioning and propagates its value down the stack
PiperOrigin-RevId: 434823543
2022-03-15 12:23:01 -07:00
jax authors
cf9a900d78 Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
jax authors
6c45969fe4 Integrate LLVM at llvm/llvm-project@eb27da7dec
Updates LLVM usage to match
[eb27da7dec67](https://github.com/llvm/llvm-project/commit/eb27da7dec67)

PiperOrigin-RevId: 432199388
2022-03-03 08:24:39 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Hyeontaek Lim
beaa00c460 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 428590081
2022-02-14 13:11:49 -08:00
Peter Hawkins
74506c7dda Rollback of: Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; s...

PiperOrigin-RevId: 427576107
2022-02-09 14:44:45 -08:00
Hyeontaek Lim
b7e1fec250 Implement the JAX transfer guard API
Adds `--jax_transfer_guard` flag and `jax.transfer_guard()` context manager that allows logging or disallowing unintended transfers.

The API distinguishes between two types of transfers:
* explicit transfers: `jax.device_put*()` and `jax.device_get()` calls.
* implicit transfers: Other transfers (e.g., printing a `DeviceArray`).

The transfer guard can take an action based on its guard level:

* "allow": Silently allow all transfers (default; same as the previous behavior).
* "log": Log and allow implicit transfers. Silently allow explicit transfers.
* "disallow": Disallow implicit transfers. Silently allow explicit transfers.
* "log_explicit": Log and allow all transfers.
* "disallow_explicit": Disallow all transfers.

The API also allows fine-control the transfer guard level of individual transfer directions. Their flag and context manager names are suffixed with the transfer direction:

* "host_to_device": Converting a Python value into a `DeviceBuffer`.
* "device_to_device": Copying a `DeviceBuffer` to a different device.
* "device_to_host": Fetching the value of a `DeviceBuffer`.

Example:
```
x = jnp.array(1)
y = jnp.array(2)
z = jnp.array(3)

print(x)  # No error
with jax.transfer_guard("disallow"):
  print(x)  # No error; x is already fetched
  print(jax.device_get(y))  # No error
  print(z)  # Error!
```

PiperOrigin-RevId: 427562278
2022-02-09 13:50:25 -08:00
Peter Hawkins
8be057de1f Introduce a new jax/jaxlib versioning scheme.
Adds a design note that describes the scheme and how the jax and jaxlib versions
are related.
2022-02-07 17:59:42 -05:00
jax authors
b5dace5e55 Merge pull request #9331 from pschuh:id-assignment
PiperOrigin-RevId: 424449474
2022-01-26 14:55:21 -08:00
Peter Hawkins
adda0a42f2 [JAX:IREE] Handle ImportError when iree isn't installed.
PiperOrigin-RevId: 424334198
2022-01-26 06:48:28 -08:00
Parker Schuh
c06933e2a2 Update get_compile_options to accept Device objects. 2022-01-25 16:42:26 -08:00