17407 Commits

Author SHA1 Message Date
Jieying Luo
91fbf9da26 [PJRT C API] Set up jax xla cuda package.
Add a build wheel, pyproject.toml and setup.py.

The directory structure in the jax repo is:
jax/
└── plugins/
     └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py

The installed package structure is:
jax_plugins/
     └── xla_cuda_cu12/
           ├── __init__.py
           └── xla_cuda_plugin.so

The major cuda version will be part of the package name.

The plugin wheel can be built with the command:
python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"

PiperOrigin-RevId: 565187954
2023-09-13 16:03:53 -07:00
Peter Hawkins
a38a152737 Disable lobpcg_test on GPU.
This test is failing in CI, disable it while we debug.

PiperOrigin-RevId: 565180431
2023-09-13 15:44:02 -07:00
Yash Katariya
ebc24c737b Pass sharded inputs to remat offloading tests. Once these tests execute, the sharded inputs will help validate the correctness of the compiler passes.
PiperOrigin-RevId: 565180089
2023-09-13 15:43:40 -07:00
Sharad Vikram
44bd916911 [Pallas] Add support for local DMAs on TPU backend
PiperOrigin-RevId: 565179693
2023-09-13 15:32:44 -07:00
jax authors
11c2f167a4 Merge pull request #17594 from jakevdp:dep-prngkey
PiperOrigin-RevId: 565163390
2023-09-13 14:33:56 -07:00
Peter Hawkins
306c60d4c7 Remove references to deprecated "tpu_se" build configuration.
PiperOrigin-RevId: 565156675
2023-09-13 14:10:30 -07:00
Jake VanderPlas
4e6c1b68c7 Deprecate random.KeyArray and random.PRNGKeyArray 2023-09-13 14:05:42 -07:00
Jake VanderPlas
270cc6014c Update internal callers to avoid PRNGKeyArray 2023-09-13 14:05:42 -07:00
Peter Hawkins
729752b32b Disable XLA detailed logging and dumping for small computations.
This significantly reduces the amount of logging from XLA on TPU.

PiperOrigin-RevId: 565148809
2023-09-13 13:45:00 -07:00
Jake VanderPlas
eeb32a7d1f Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped
PiperOrigin-RevId: 565142019
2023-09-13 13:21:46 -07:00
Jake VanderPlas
22ff7bd19a Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct
These have been deprecated in JAX following similar deprecations in numpy v1.25.0
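
For callers migrating off the removed names, the replacements are the same ones NumPy recommends. A short sketch (the input array is purely illustrative):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

all_positive = jnp.all(x > 0)      # replaces jnp.alltrue
any_large = jnp.any(x > 3)         # replaces jnp.sometrue
total_product = jnp.prod(x)        # replaces jnp.product
running_product = jnp.cumprod(x)   # replaces jnp.cumproduct
```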

PiperOrigin-RevId: 565122288
2023-09-13 12:07:36 -07:00
Yash Katariya
8340149336 Check that a donated input is actually deleted, in addition to the AOT check.
PiperOrigin-RevId: 565098239
2023-09-13 10:50:16 -07:00
Roy Frostig
6abefa1977 fast dispatch for functions over typed PRNG key arrays
Before this change, JAX could dispatch compiled functions over new-style (typed)
RNG key arrays, but it would always do so off of the fast (C++-based) dispatch
path. In other words, switching from old-style `uint32` RNG keys to new-style
keys would regress dispatch times. With this change, dispatch happens on the
fast path again and performance regressions ought to be minimal.

We currently maintain only one pytree registry, for all registered pytree node
types. We want RNG key arrays to also be treated as pytree leaves everywhere
*except* during dispatch. In other words: we want operations on (typed) RNG key
arrays to appear in Jaxpr, but we want to unravel those arrays into their
underlying `uint32` arrays only during dispatch.

To do this, we add a new internal pytree registry that dispatch respects
uniquely. This registry includes all items in the default registry, but also the
RNG key array type.
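
The two-registry idea can be illustrated with a toy model in plain Python. This is only a sketch of the concept, not JAX's pytree implementation: `ToyKeyArray`, `default_registry`, `dispatch_registry`, and `flatten` are hypothetical names invented for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical stand-in for a typed RNG key array wrapping raw key data.
@dataclass(frozen=True)
class ToyKeyArray:
    raw: bytes

Flattener = Callable[[Any], list]

# Default registry: node types that are unravelled everywhere.
default_registry: dict[type, Flattener] = {
    tuple: list,
    list: list,
}

# Dispatch registry: everything in the default registry, plus the key type.
dispatch_registry: dict[type, Flattener] = {
    **default_registry,
    ToyKeyArray: lambda k: [k.raw],
}

def flatten(tree, registry):
    """Return the leaves of `tree`, unravelling only registered node types."""
    flattener = registry.get(type(tree))
    if flattener is None:
        return [tree]  # unregistered types are treated as leaves
    leaves = []
    for child in flattener(tree):
        leaves.extend(flatten(child, registry))
    return leaves

key = ToyKeyArray(raw=b"\x00\x2a")
args = [key, 1.0]
leaves_default = flatten(args, default_registry)    # key stays a leaf
leaves_dispatch = flatten(args, dispatch_registry)  # key is unravelled
```

With the default registry the key array is a leaf, so operations on it can still be traced as such; only the dispatch registry exposes the underlying data.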

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 565077758
2023-09-13 09:43:58 -07:00
jax authors
84951288bd Merge pull request #17584 from cabbagepatchman:cabbagepatchman-fix-sp
PiperOrigin-RevId: 565066492
2023-09-13 09:03:13 -07:00
jax authors
5a15ba90db Update XLA dependency to use revision 551a620daf.

PiperOrigin-RevId: 564978874
2023-09-13 02:10:36 -07:00
Adam Paszke
4ded017121 [Mosaic] Fix an incorrectly implemented heuristic for selecting tile sizes
It did not properly account for the minimum tile size requirements on TPUv2 and v3.

PiperOrigin-RevId: 564977666
2023-09-13 02:00:40 -07:00
Peter Hawkins
d7a6eed9e1 Update jax2tf/tests/sharding_test.py for TPU runtime changes.
Removes support for an older runtime (StreamExecutor) on TPU.

PiperOrigin-RevId: 564927177
2023-09-12 22:08:26 -07:00
Yash Katariya
c41d271175 Add memories support to remat.
This PR adds basic support to remat for transferring intermediates (activations) to a destination memory in the forward pass. Currently JAX only supports the host memory kind, but the API allows transferring to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.

Introduce two singletons, `Recompute` and `Saveable`, and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool: True means saveable, False means recompute on the backward pass. This change is backwards compatible, i.e. policies can still return a bool.

A very basic offloadable policy can look like this:

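```
from jax._src import ad_checkpoint  # assumed import path; not stated in the commit

def policy(prim, *avals, **params):
  # Offload every residual from TPU HBM to unpinned host memory.
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```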
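
The backwards-compatibility rule described above (a bool is still accepted alongside the new return types) can be sketched in plain Python. This is not JAX's implementation: the classes below are hypothetical stand-ins for the names in this commit, and `normalize_policy_result` is an invented helper.

```python
from typing import NamedTuple

class _Marker:
    """Minimal singleton-style marker used for the stand-ins below."""
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return self.name

# Hypothetical stand-ins for the singletons and NamedTuple named in the commit.
Recompute = _Marker("Recompute")
Saveable = _Marker("Saveable")

class Offloadable(NamedTuple):
    src: str
    dst: str

def normalize_policy_result(result):
    """Map a legacy bool onto the new result types; pass new types through."""
    if isinstance(result, bool):
        return Saveable if result else Recompute
    if result is Recompute or result is Saveable or isinstance(result, Offloadable):
        return result
    raise TypeError(f"unexpected policy result: {result!r}")

offload = Offloadable(src='tpu_hbm', dst='unpinned_host')
```

Under this sketch, `normalize_policy_result(True)` yields `Saveable`, `normalize_policy_result(False)` yields `Recompute`, and an `Offloadable` value is returned unchanged.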

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
2023-09-12 20:50:05 -07:00
cabbagepatchman
125343b7f9 Update README.md 2023-09-12 20:39:06 -07:00
jax authors
7c0abb1c85 Merge pull request #17564 from andportnoy:aportnoy/scipy-spatial-test-increase-tolerance
PiperOrigin-RevId: 564891104
2023-09-12 18:21:31 -07:00
Jevin Jiang
d4b564a263 [Mosaic] Support relayout from (1,128) to (8,128) when dst.offset is (0, 0).
PiperOrigin-RevId: 564882618
2023-09-12 17:35:09 -07:00
jax authors
d3950b93cb Merge pull request #17568 from zhenying-liu:testdotgeneral
PiperOrigin-RevId: 564871912
2023-09-12 16:46:39 -07:00
Jane Liu
90da4f153a Fix an A100 nightly unit test failure on testDotGeneral() by replacing TF32 with float32 2023-09-12 16:31:22 -07:00
Jake VanderPlas
56791eb9ec lax_test: adjust TPU tolerance for igamma & friends
PiperOrigin-RevId: 564859109
2023-09-12 15:59:41 -07:00
Benjamin Kramer
a26125c49e Integrate LLVM at llvm/llvm-project@c1796be93f
Updates LLVM usage to match
[c1796be93fe5](https://github.com/llvm/llvm-project/commit/c1796be93fe5)

PiperOrigin-RevId: 564842806
2023-09-12 15:00:09 -07:00
Jevin Jiang
801cbef011 [Mosaic] Use strided load to load one entire row more efficiently.
PiperOrigin-RevId: 564831610
2023-09-12 14:19:35 -07:00
jax authors
c617bcb515 Merge pull request #17566 from jakevdp:prngarray-type
PiperOrigin-RevId: 564828726
2023-09-12 14:09:02 -07:00
Jevin Jiang
9d8642122a [Mosaic] Use strided store to store one row.
PiperOrigin-RevId: 564821813
2023-09-12 13:56:58 -07:00
Jake VanderPlas
ea5f126e85 [custom prng] make PRNGKeyArray a subclass of jax.Array 2023-09-12 13:48:12 -07:00
jax authors
b20b93e9c9 Merge pull request #17565 from jakevdp:fix-array-type
PiperOrigin-RevId: 564821273
2023-09-12 13:46:16 -07:00
Jake VanderPlas
d44b0389dd [typing] fix a few array type declarations 2023-09-12 13:21:48 -07:00
Andrey Portnoy
34ea2b2e8a Increase comparison tolerance in SciPy spatial RotationMean subtest
Previous value leads to failures on A100 runners in
github.com/NVIDIA/JAX-Toolbox CI:
https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014

The suspected reason is the use of TF32 math for matmuls: decorating the
function with @jax.default_matmul_precision("float32") allows the test to pass.
We thought it better to loosen the tolerance while preserving the original
execution mode.
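
For reference, the precision override mentioned above can also be used as a context manager. A minimal sketch with illustrative shapes and values (this is the alternative that was not adopted for the test):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((4, 8), dtype=jnp.float32)
b = jnp.ones((8, 2), dtype=jnp.float32)

# Request full float32 matmul precision for this scope only.
with jax.default_matmul_precision("float32"):
    product = jnp.matmul(a, b)
```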

The fully qualified test case is
tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0
2023-09-12 16:12:14 -04:00
jax authors
d0df18a76b Merge pull request #17562 from jakevdp:version-import
PiperOrigin-RevId: 564807097
2023-09-12 12:57:29 -07:00
Jake VanderPlas
1800015884 Import jax.version first 2023-09-12 12:27:20 -07:00
jax authors
63e51fe1e6 Merge pull request #17524 from jakevdp:pref-eltype-test
PiperOrigin-RevId: 564792174
2023-09-12 12:01:21 -07:00
Adam Paszke
dbb0e8f214 [Mosaic] Add a pass for instantiating memory spaces
PiperOrigin-RevId: 564723473
2023-09-12 08:05:26 -07:00
Peter Hawkins
7dddb507e9 [XLA:Python] Remove the use_tfrt flag from make_cpu_client().
use_tfrt=True has been the default for over a year, and the flag currently does nothing.
PiperOrigin-RevId: 564712316
2023-09-12 07:15:04 -07:00
jax authors
462c0bd30d Update XLA dependency to use revision cd7379d9af.

PiperOrigin-RevId: 564660543
2023-09-12 02:59:25 -07:00
Yash Katariya
2a7b8e6278 Add gpu_common_utils to build_wheel to fix the gpu wheels build
PiperOrigin-RevId: 564562958
2023-09-11 18:40:55 -07:00
Yash Katariya
76a5dc3cac Move memories_test.py to JAX
PiperOrigin-RevId: 564551723
2023-09-11 17:41:55 -07:00
Ruoxin Sang
3e06dc8b77 Update jax_spmd_mode flag docstring and remove unused allow_pjit option.
PiperOrigin-RevId: 564543943
2023-09-11 17:08:35 -07:00
Qiao Zhang
d4adf0095f Add default jvp and transpose rule for jax.lax.reduce_precision.
PiperOrigin-RevId: 564536160
2023-09-11 16:35:44 -07:00
John QiangZhang
997b35e1d9 Improve the GPU lowering error message shown when users forget to link the GPU library.
PiperOrigin-RevId: 564530960
2023-09-11 16:14:18 -07:00
jax authors
6c3b42d33c Add flags to exclude from cache-key generation.
Some flags do not affect the compilation output. These should
not be part of the cache key, otherwise changing them will change
the key causing an unnecessary cache miss.

Synchronize the exclusions between the command-line flags and
DebugOptions. Add if-this-then-that lint checks to keep them
in sync.
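
The exclusion idea can be sketched in a few lines of plain Python. This is an illustration only, not the JAX cache-key code: the flag names and the `cache_key` helper are hypothetical, since the commit does not list the excluded flags.

```python
import hashlib

# Hypothetical flag names; the commit does not list the excluded flags.
FLAGS_EXCLUDED_FROM_CACHE_KEY = frozenset({"log_dir", "dump_to"})

def cache_key(computation: bytes, flags: dict) -> str:
    """Hash the computation plus only the flags that can affect its output."""
    hasher = hashlib.sha256()
    hasher.update(computation)
    for name in sorted(flags):  # sorted for a stable key
        if name in FLAGS_EXCLUDED_FROM_CACHE_KEY:
            continue
        hasher.update(f"{name}={flags[name]!r};".encode("utf-8"))
    return hasher.hexdigest()

base = cache_key(b"module", {"opt_level": 3, "log_dir": "/tmp/a"})
same = cache_key(b"module", {"opt_level": 3, "log_dir": "/tmp/b"})
different = cache_key(b"module", {"opt_level": 2, "log_dir": "/tmp/a"})
```

Changing only an excluded flag leaves the key unchanged (`base == same`), while changing a flag that affects compilation produces a new key.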

PiperOrigin-RevId: 564474189
2023-09-11 12:57:45 -07:00
jax authors
05d2432e9e Merge pull request #17527 from ROCmSoftwarePlatform:rocm_build_updates_1
PiperOrigin-RevId: 564472739
2023-09-11 12:48:10 -07:00
jax authors
23e4f0b471 Hash serialized topology description for new cache key generation.
The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription which is supported for all accelerators.

Note:
- CPU and the PjRt C API are not supported yet.
- Stream Executor will not be supported.

Testing: revised unit test.
PiperOrigin-RevId: 564461564
2023-09-11 12:08:26 -07:00
Yash Katariya
a36598b2a7 Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows existing code to work without any changes.

If a non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.

PiperOrigin-RevId: 564457393
2023-09-11 11:55:09 -07:00
jax authors
bfc12bdda9 Update XLA dependency to use revision 96beae40a1.

PiperOrigin-RevId: 564327668
2023-09-11 03:36:30 -07:00
jax authors
292deef6fd Update XLA dependency to use revision c227585959.

PiperOrigin-RevId: 563982388
2023-09-09 04:01:14 -07:00
Rahul Batra
4091ac646c [ROCm]: Fix duplicate deps include 2023-09-08 22:56:59 +00:00