245 Commits

Author SHA1 Message Date
Adam Paszke
8b21614973 [Pallas:MGPU] Add FlashAttention3 as an example
PiperOrigin-RevId: 690977852
2024-10-29 05:21:43 -07:00
Hyeontaek Lim
77797f434d [JAX] Add the function API of jax.experimental.colocated_python
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.

It is currently in an early development stage. Please proceed with a caution
when using the API.

PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00
Sergei Lebedev
dfa6fcd56b [pallas:mosaic_gpu] Extracted a basic emit_pipeline API from the in kernel pipelining test
PiperOrigin-RevId: 690619853
2024-10-28 08:25:47 -07:00
Sergei Lebedev
5a2128e44b [pallas] Removed deprecated aliases to CostEstimate and run_scoped
PiperOrigin-RevId: 689871787
2024-10-25 12:16:58 -07:00
Sergei Lebedev
06c08bd118 Renamed :pallas_gpu to :pallas_triton
:pallas_gpu is now an umbrella target for Triton and (hopefully soon)
Mosaic GPU backends.

PiperOrigin-RevId: 683145270
2024-10-07 05:44:00 -07:00
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00
Tom Natan
ed5ba633d4 Reverts 6cf09f8c24c67ff650b95d174501fff3cb59db0d
PiperOrigin-RevId: 682440543
2024-10-04 13:56:27 -07:00
Justin Fu
350afaa7b6 [Pallas] Clean up lowering exceptions.
PiperOrigin-RevId: 681073628
2024-10-01 10:26:40 -07:00
Tom Natan
6cf09f8c24 Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3
PiperOrigin-RevId: 678756266
2024-09-25 10:33:53 -07:00
Tom Natan
eff00cc449 [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See https://github.com/openxla/stablehlo/pull/2259

PiperOrigin-RevId: 678649138
2024-09-25 04:53:11 -07:00
jax authors
9465d427c0 Merge pull request #22302 from yhtang:add-k8s-initialize
PiperOrigin-RevId: 676962862
2024-09-20 14:03:50 -07:00
Yu-Hang Tang
c88c3aecae add k8s cluster environment 2024-09-20 17:26:53 +00:00
Jevin Jiang
839ce9a11d [Pallas TPU] Refactor ref indexers to transforms and support ref bitcast.
This cl refactors Pallas memref indexers to transforms which can support different ref transforms: indexing, bitcast (added in this cl), reshape (to be added) and others. Like indexer, user can apply multiple transforms to same memref, eg:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```

Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
2024-09-15 17:53:29 -07:00
jax authors
02b7a76768 Add frontend attributes to Jax. This allows Jax users to annotate Jax code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.
PiperOrigin-RevId: 671930431
2024-09-06 16:44:56 -07:00
Yash Katariya
a144eb234b Add compute_on_context_manager to thread local jit state. This is to avoid getting false cache hits
PiperOrigin-RevId: 671507042
2024-09-05 14:16:13 -07:00
Justin Fu
2d74c6aa05 Add TritonCompilerParams for specifying compiler arguments instead of a dict.
PiperOrigin-RevId: 671081069
2024-09-04 13:32:25 -07:00
Yash Katariya
252caebce3 Create jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None) API to make it easier to create a mesh and reduce a ton of boilerplate.
`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.

PiperOrigin-RevId: 670707995
2024-09-03 14:32:03 -07:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Jieying Luo
a3ae5e18d3 Remove build_cuda_plugin_from_source flag which is no longe used.
751b5742fd

PiperOrigin-RevId: 661370449
2024-08-09 12:54:14 -07:00
Jake VanderPlas
48c5fab023 [array api] fix deprecation to support old import pattern 2024-08-01 14:38:59 -07:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
Christos Perivolaropoulos
80a193d5db [pallas] Use the same primitive run_scoped_p for moth mosaic and mosaic_gpu
PiperOrigin-RevId: 655751205
2024-07-24 17:14:30 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Christos Perivolaropoulos
4186824b34 [pallas:mosaic_gpu] Add support for run_scoped
PiperOrigin-RevId: 655338646
2024-07-23 16:13:00 -07:00
Adam Paszke
2ea222544e Add a Promela spec generator for Pallas TPU kernels
This adds a simple model extractor for TPU kernels that generates a Promela spec
outlining the semantics of semaphores and DMAs. The model can be fed into SPIN
and used to e.g. verify the lack of data races or deadlocks. While compelte verification
is very expensive, the tool seems especially good at finding races that are really there.

PiperOrigin-RevId: 653198263
2024-07-17 05:29:22 -07:00
Peter Hawkins
e80d143bed Create a bazel visibility list for experimental_array_api.
PiperOrigin-RevId: 651059464
2024-07-10 10:02:03 -07:00
Sergei Lebedev
65ab63bfd0 Registered a deprecation for the old `pl.BlockSpec` argument order
PiperOrigin-RevId: 650682044
2024-07-09 10:43:08 -07:00
jax authors
0d57c72644 Merge pull request #20174 from coreyjadams:main
PiperOrigin-RevId: 650334673
2024-07-08 12:19:18 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While there exist various variants, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), global query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to enhance the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can result in a 1.3x end-to-end speedup for training large language models based on GPT alone.

This PR proposes introducing a new API in the `jax.nn` module to handle attention. It will first try to use the cudnn flash attention execution path when the config is compatible. Otherwise it falls back to a jax implementation.

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
Sergei Lebedev
740945a724 Moved the implementation of `custom_partitioning` into jax/_src
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in google/jax#21371.

PiperOrigin-RevId: 650201550
2024-07-08 04:31:44 -07:00
Ayaka
6c05aa2f32 Clean up 2024-07-04 17:16:32 +04:00
Kyle Gerard Felker
ffc9292365 Squashed commit of the following:
commit 79b8cbf0cb47e32743e0970bc1abeb6a673866a8
Author: Corey Adams <corey.adams@anl.gov>
Date:   Mon Jul 1 14:14:15 2024 -0500

    Fix mypy issues; change variable name to more universally known name

commit 10edc866f568908e536e5c7bd6b59b4e5351781e
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Jun 27 13:25:32 2024 -0500

    Change copyright year to the year this was authored

commit f7086cb44cc98d58a96ae804dcd1787bc31470f7
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Jun 27 13:15:32 2024 -0500

    Update build file to include mpi4py cluster.

commit 6235eb311b9fca2bd81fe1c49456d164b7332753
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:11:48 2024 -0500

    Update distributed.py

    Clean up documentation slightly.

commit ef3a2e220945b2158cf20edeb1e04bbbf8f290ff
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:09:37 2024 -0500

    Update mpi4py_cluster.py

    Further clean up unneeded comments.

commit 6cc07a9a52fc202ecc65c04c513096391c27d02d
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:08:38 2024 -0500

    Update mpi4py_cluster.py

    Remove unneeded commented code.

commit 6701bd1a9d645a0e08d95df1692f43946f0a5eb8
Merge: 5a91ac342 98b87540a
Author: Corey adams <coreyjadams@gmail.com>
Date:   Thu Jun 27 12:07:25 2024 -0500

    Merge branch 'google:main' into main

commit 5a91ac34248afa6f65af3cae66df7d0d122c1d26
Merge: 301bbc67f 6c51234f9
Author: Corey adams <coreyjadams@gmail.com>
Date:   Tue May 28 22:14:08 2024 -0500

    Merge branch 'google:main' into main

commit 301bbc67f938bc30c543cf300cec8a9c75f3eef8
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 11:34:51 2024 -0500

    Add test to verify mpi4py based distributed initialization

commit 19e66949a36bb0edb4cd66b0f170f42b326928ec
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 11:14:40 2024 -0500

    Unify variable naming and fix function argument ordering

commit 72fe093042519e48d9c26b7ede3b266c7a850be6
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue May 28 10:56:25 2024 -0500

    Remove unmerged code

commit 3a96e738a3cdf9b6ed194cb764fa5640a37f6b95
Merge: e4fd97e19 ff3db9b3a
Author: Corey adams <coreyjadams@gmail.com>
Date:   Tue May 28 10:51:41 2024 -0500

    Merge branch 'google:main' into main

commit e4fd97e197211921fb6911054592041015af94ef
Merge: a69729900 72a81e58e
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon May 13 16:01:35 2024 -0500

    Merge branch 'google:main' into main

commit a6972990070d5d2f405d5ede9f82d35c7e6d157a
Merge: 85bcf42bd 1e48adc69
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon May 13 14:21:32 2024 -0500

    Merge branch 'google:main' into main

commit 85bcf42bdd36ad88a3d287c357cd12fde74c7fc0
Merge: af1a4f0a1 06cd05d1d
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 09:09:31 2024 -0500

    Merge branch 'main' of https://github.com/google/jax

commit af1a4f0a12008780e9507d1bdd91e9d11ec35916
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 08:58:33 2024 -0500

    update documentation and elaborate on spec_detect_method variable

commit 01f4709d5ecd4af675f4fb23d02d6a69b927adac
Author: Corey Adams <corey.adams@anl.gov>
Date:   Tue Apr 16 08:45:38 2024 -0500

    Address feedback and comments on PR 20174; fix typo in documentation.

commit 4f22d86e7358c29ed588267a7d91fe55fb94f143
Merge: 900a0372f 71ec6e33c
Author: Corey adams <coreyjadams@gmail.com>
Date:   Mon Mar 11 11:51:30 2024 -0500

    Merge branch 'google:main' into main

commit 900a0372f6147d3c9ab53c95b6a4262e5cfe4457
Author: Corey Adams <corey.adams@anl.gov>
Date:   Mon Mar 11 11:50:48 2024 -0500

    Auto-detect of mpi4py-based configuration is now strictly opt-in.

commit 1992969da6164e456492fe0f9cd4287f6d8f03cf
Author: Corey Adams <corey.adams@anl.gov>
Date:   Thu Mar 7 12:27:43 2024 -0600

    Enable automatic detection of distrbuted variables with any configuration of MPI, as long as mpi4py is available
2024-07-02 13:18:05 -05:00
George Necula
24b42eed5e [export] Clean up BUILD targets for jax.experimental.export
jax.experimental.export is deprecated and will be removed in a future version of JAX.

See migration guide at: https://jax.readthedocs.io/en/latest/export/export.html#migration-guide-from-jax-experimental-export

PiperOrigin-RevId: 647562073
2024-06-27 23:08:48 -07:00
Yash Katariya
e1a496d3b6 Add concrete layout API to JAX. The API takes major_to_minor: tuple[int, ...] and tiling: tuple[tuple[int, ...], ...] as the arguments. Allows users to pass layouts to with_sharding_constraint to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.

memory space is exposed via JAX memories API so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
Peter Hawkins
945fde41e4 Update minimum Python version to 3.10. 2024-06-26 13:47:14 -04:00
Justin Fu
8ba8f3bf65 [Pallas] Implement block-invariant sampling.
PiperOrigin-RevId: 646161271
2024-06-24 11:20:39 -07:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
jax authors
ce4a56a137 Merge pull request #21394 from ayaka14732:lru-cache
PiperOrigin-RevId: 642333998
2024-06-11 11:29:18 -07:00
Ayaka
1a3a15c9e3 Implement LRU cache eviction for persistent compilation cache
Co-authored-by: Sergei Lebedev <slebedev@google.com>
2024-06-11 21:48:35 +04:00
Sergei Lebedev
f8473509cf Removed kernel_regeneration_util from Mosaic
It was only used for persisting kernel metadata, and that can be done via
jax.named_scope instead.

PiperOrigin-RevId: 642195336
2024-06-11 02:36:41 -07:00
Justin Fu
9439f63645 [Pallas] Add pallas TPU random key impls and lowering rules for basic prng ops (seed/foldin/bits/unwrap/wrap).
PiperOrigin-RevId: 642085019
2024-06-10 18:08:19 -07:00
Sergei Lebedev
5e7ad600e2 Removed the double re-exporting of Pallas GPU/TPU APIs
jax.experimental.pallas.{gpu,tpu} now import directly from the relevant
jax._src.pallas.{triton,mosaic} submodules.

PiperOrigin-RevId: 641875127
2024-06-10 05:59:09 -07:00
George Necula
14d87d3bf7 [export] Move the export implementation to jax._src.export.
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.

Had to add a few "#type ignore" to the _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.

PiperOrigin-RevId: 641688200
2024-06-09 08:59:50 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Yash Katariya
9e3f290de3 Delete XLACompatibleSharding and replace with jax.sharding.Sharding.
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.

Why do this?

* All shardings JAX has are XLA Compatible. The reason why `Sharding` was created was to allow non-xla shardings but that's not happened in the past 2 years. So let's simplify!

* Having these 2 types makes things very confusing. One example is:
  * `jax.jit` only accepts XLACompatibleShardings.
  * `jax.device_put` accepts `jax.sharding.Sharding` but if you use `device_put` inside `jax.jit` with a memory_kind then you can only pass `XLACompatibleSharding`. This is contradicting and confusing and we can simplify.

PiperOrigin-RevId: 640527070
2024-06-05 08:03:23 -07:00
Sergei Lebedev
40f107e5a5 Moved Pallas GPU ops into pallas/ops/gpu
PiperOrigin-RevId: 640439838
2024-06-05 01:34:46 -07:00
George Necula
39ac584729 [shape_poly] Move to jax._src in preparation for adding to AOT APIs.
The shape polymorphism APIs are still private and are only exposed through `jax.experimental.export` as before.

PiperOrigin-RevId: 640393089
2024-06-04 22:03:24 -07:00
Yash Katariya
1273028018 Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes.
PiperOrigin-RevId: 639920049
2024-06-03 14:52:50 -07:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Adam Paszke
cfe64cd5ce [Mosaic GPU] Integrate the ExecutionEngine with the jaxlib GPU plugin
This lets us avoid bundling a whole another copy of LLVM with JAX packages
and so we can finally start building Mosaic GPU by default.

PiperOrigin-RevId: 638569750
2024-05-30 01:46:23 -07:00