Imported from GitHub PR https://github.com/google/jax/pull/21371
Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow; examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, implementations like the Flash Attention algorithm aim to make better use of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-style large language models.
This PR proposes introducing a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise it falls back to a JAX implementation.
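For illustration, a minimal usage sketch (the name `jax.nn.dot_product_attention` and the `is_causal`/`implementation` arguments follow the PR description; treat the exact signature as an assumption here):

```python
import jax
import jax.numpy as jnp

B, T, N, H = 2, 128, 8, 64  # batch, sequence length, num heads, head dim
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

# implementation=None lets JAX pick: the cuDNN flash attention path when the
# config (dtype, head dim, hardware) is compatible, otherwise the XLA fallback.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation=None)
```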
cc. @nluehr @Cjkkkk @cliffwoolley
Copybara import of the project:
--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:
Add new SDPA API to jax.nn
Merging this change closes #21371
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
This is necessary to avoid a circular dependency
jax -> fused_attention_stablehlo -> experimental -> jax
in google/jax#21371.
PiperOrigin-RevId: 650201550
`sub_byte_element_size_in_bits` is a lowering-only concept for now (since we know the dtype of the aval, JAX can add the appropriate value). We can expose it in the user API if required.
Memory space is exposed via the JAX memories API, so it doesn't have to be in the layout API.
Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.
Add constructors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.
PiperOrigin-RevId: 647487510
This is part of the work to move the export APIs out
of jax.experimental. For now, the way to use this
implementation is still through `jax.experimental.export`.
Had to add a few `# type: ignore` comments to _export.py because
previously the file was exempt from internal pytype.
Will try to fix these in a later PR.
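A hedged usage sketch of the entry point that remains supported for now (the `Exported` fields shown are assumptions based on the existing API):

```python
import jax
import jax.numpy as jnp
from jax.experimental import export  # still the public entry point for now

def f(x):
  return jnp.sin(jnp.cos(x))

# Export a jitted function for an abstract input shape; the result can later
# be serialized and invoked elsewhere.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((2, 3), jnp.float32))
print(exported.in_avals, exported.out_avals)
```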
PiperOrigin-RevId: 641688200
As of this change, `XLACompatibleSharding` is an alias of `jax.sharding.Sharding` but it will be deprecated in a follow up change.
Why do this?
* All shardings JAX has are XLA-compatible. The reason `Sharding` was created was to allow non-XLA shardings, but that hasn't happened in the past 2 years. So let's simplify!
* Having these 2 types makes things very confusing. One example is:
* `jax.jit` only accepts XLACompatibleShardings.
* `jax.device_put` accepts `jax.sharding.Sharding`, but if you use `device_put` inside `jax.jit` with a memory_kind, then you can only pass an `XLACompatibleSharding`. This is contradictory and confusing, and we can simplify (see the sketch below).
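A minimal sketch of the confusing case above (assumes a backend with memory-kind support, e.g. TPU):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))
s_host = NamedSharding(mesh, P('x'), memory_kind='pinned_host')

@jax.jit
def f(x):
  # device_put with a memory kind inside jit is the path that only accepted
  # XLACompatibleSharding, while device_put outside jit accepted any
  # jax.sharding.Sharding. Making the two types the same removes the split.
  return jax.device_put(x, s_host)

x = jnp.arange(4.0 * len(jax.devices()))
y = f(x)
```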
PiperOrigin-RevId: 640527070
This lets us avoid bundling yet another copy of LLVM with JAX packages,
so we can finally start building Mosaic GPU by default.
PiperOrigin-RevId: 638569750
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.
It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.
All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
The primitive is currently only supported in Pallas GPU when lowering to Triton.
See documentation inline for the Triton-specific restrictions.
PiperOrigin-RevId: 636120214
The currently supported values for compute type are `device_host` and `device`; `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will orchestrate the execution of the computation on the host.
`cpu` as a compute_type is reserved for pure CPU-only computations without a device's PJRT client orchestrating the computation.
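A hedged sketch of how `device_host` is expected to be used via the `compute_on` decorator (assumes `jax.experimental.compute_on` and a backend that supports host offloading):

```python
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

# Offload part of a jitted computation to the host; the device's PJRT client
# orchestrates the host execution (compute_type 'device_host').
@compute_on('device_host')
@jax.jit
def on_host(x):
  return x * 2.0

@jax.jit
def f(x):
  y = on_host(x)   # runs on the host
  return y + 1.0   # runs on the device

print(f(jnp.ones(4)))
```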
PiperOrigin-RevId: 634909918
Almost nothing is supported yet; the missing pieces include
* PyTree inputs/outputs
* indexers
* non-trivial grids
* block specs
* any primitives beyond the ones added here
* etc etc
PiperOrigin-RevId: 633713366
Note that this adds only a minimal safety net to protect against
non-backwards-compatible changes. We really should have more tests
that cover more of the Triton MLIR.
Also enable serialization of such calls.
PiperOrigin-RevId: 630033989
The host_callback APIs are deprecated and will be removed. In order to
help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`)
that, when set to `False`, uses `io_callback` (and `pure_callback` and
`jax.debug.callback`) to implement the host_callback APIs.
See issue #20385 for more details.
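A hedged migration sketch showing the new APIs that the legacy calls map onto (the helper `log_value` is illustrative only):

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log_value(x):
  print("tapped:", x)

@jax.jit
def f(x):
  # Roughly what an id_tap-style tap becomes: a callback with no return value.
  jax.debug.callback(log_value, x)
  # Roughly what a call-style host computation becomes: io_callback with an
  # explicit result shape/dtype.
  y = io_callback(lambda v: v + 1.0, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
  return y * 2.0

print(f(jnp.arange(3.0)))
```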
We change the tests to accommodate slightly different results when using
the new callbacks. The tests that use `tap_with_device` and `call_with_device`
are disabled when using the new callbacks.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.
Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
We decided to expose a Python alternative again to make it easier for OSS users to see and customize the pipeline. The default is still to run the pipeline from XLA.
The original one was removed in cl/596464480 and cl/597332393.
PiperOrigin-RevId: 617291995
with help from @sharadmv, @yashkatariya, @dougalm, and others
The basic strategy is to apply discharge_state when lowering a jaxpr with state
effects to HLO, and update the dispatch path accordingly. Specifically:
1. in tests only for now, introduce a MutableArray data type;
2. teach jit to abstract it to a Ref(ShapedArray) type, register an input
handler, etc;
3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with
refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing;
4. teach the output side of the dispatch path to drop those outputs.
As an alternative to (3), we could potentially lower away the effects at a
higher level, like in _pjit_lower_cached. They are similar because
_pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation.
I decided to do it in lower_sharding_computation mainly because that's closer
to where we set up aliases, and I wanted to make mutable arrays correspond to
aliased inputs/outputs on the XLA computation.
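A hedged sketch of the intended behavior, using the test-only spelling at this point (the internal `mutable_array` constructor is an assumption and may change):

```python
import jax
import jax.numpy as jnp
from jax._src import core  # MutableArray is test-only for now

x_ref = core.mutable_array(jnp.zeros(3))

@jax.jit
def f(ref):
  # jit abstracts the MutableArray to a Ref(ShapedArray); the in-place write
  # is discharged in lower_sharding_computation into an extra, aliased output.
  ref[...] = ref[...] + 1.0
  return ref[...] * 2.0

print(f(x_ref))  # the update to ref persists in x_ref after the call
```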