Imported from GitHub PR https://github.com/google/jax/pull/22404
* Only check the bias batch/num_heads sharding spec when present, since both dims could be broadcast.
* The dbias calculation is incorrect under SPMD; an all_reduce is required (see the sketch below).
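A minimal sketch of the issue, assuming the batch dimension is sharded across a mesh axis named `data` while the bias is broadcast over batch (the axis name, shapes, and helper are illustrative, not the actual kernel code): each device only sees its local contribution to dbias, so a `psum` across the sharded axis is required to recover the full gradient.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('data',))

# Upstream gradient with a sharded batch dim; the bias it flows back to
# is broadcast over batch, so its gradient is a sum over the batch dim.
dout = jnp.ones((8, 4, 16, 16))  # (batch, num_heads, q_seq, kv_seq)

def dbias_local(dout_shard):
    partial = jnp.sum(dout_shard, axis=0)  # only this shard's batch slice
    return jax.lax.psum(partial, 'data')   # the required all_reduce

dbias = shard_map(dbias_local, mesh=mesh,
                  in_specs=P('data'), out_specs=P())(dout)
print(dbias.shape)  # (num_heads, q_seq, kv_seq), summed over all shards
```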
Copybara import of the project:
--
cb81b80626bcf17db875bad5526cd2f24c989049 by cjkkkk <ske@nvidia.com>:
fix sharding
Merging this change closes #22404
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22404 from Cjkkkk:fix_bias_sharding_dbias_all_reduce cb81b80626bcf17db875bad5526cd2f24c989049
PiperOrigin-RevId: 653335832
This currently causes incorrect behaviour for jax.nn.dot_product_attention: invalid inputs should raise an error rather than fail with an assert. See the sketch below.
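A sketch of the pattern this refers to; the condition, function, and message are illustrative, not the actual check:

```python
import jax.numpy as jnp

def _check_query(query):
    # Before: an internal assert, which surfaces as a bare AssertionError
    # (and is stripped entirely under `python -O`):
    #   assert query.ndim == 4
    # After: an explicit, user-facing error.
    if query.ndim != 4:
        raise ValueError(
            "jax.nn.dot_product_attention requires a rank-4 query, "
            f"got query.ndim={query.ndim}")

_check_query(jnp.zeros((2, 16, 8, 64)))  # passes
```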
PiperOrigin-RevId: 650621750
Imported from GitHub PR https://github.com/google/jax/pull/21371
Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow: examples include standard multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, newer implementations such as the Flash Attention algorithm aim to improve utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-style large language models.
This PR proposes a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible, and otherwise falls back to a JAX implementation. A usage sketch follows.
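A minimal usage sketch of the new API, with argument names following the public `jax.nn.dot_product_attention` signature; the backend-selection comment restates this PR's description rather than a guarantee:

```python
import jax
import jax.numpy as jnp

B, T, S, N, H = 2, 128, 128, 8, 64   # batch, q_len, kv_len, heads, head_dim
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
key   = jax.random.normal(kk, (B, S, N, H), dtype=jnp.bfloat16)
value = jax.random.normal(kv, (B, S, N, H), dtype=jnp.bfloat16)

# implementation=None tries the cuDNN flash attention path when the
# config is compatible and otherwise falls back to the XLA (JAX)
# implementation; implementation='cudnn' or 'xla' forces a backend.
out = jax.nn.dot_product_attention(query, key, value, is_causal=True,
                                   implementation=None)
print(out.shape)  # (B, T, N, H)
```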
cc. @nluehr @Cjkkkk @cliffwoolley
Copybara import of the project:
--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:
Add new SDPA API to jax.nn
Merging this change closes #21371
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
Imported from GitHub PR https://github.com/google/jax/pull/22078
* cuDNN SDPA no longer supports a mask input, so we combine the bias and mask manually to align with the public SDPA API design (see the sketch below).
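A sketch of the standard way to fold a boolean mask into an additive bias; the helper name and the choice of large-negative constant are illustrative. Masked positions receive a large negative additive term so they vanish after the softmax.

```python
import jax.numpy as jnp

def combine_bias_and_mask(bias, mask, dtype):
    """Fold a boolean mask (True = keep) into an additive bias."""
    # 0 where kept, a large negative value where masked out. A finite
    # large-negative value is used instead of -inf to avoid NaNs when
    # an entire row is masked.
    big_neg = jnp.finfo(dtype).min / 2
    additive = jnp.where(mask, jnp.asarray(0, dtype), big_neg)
    return additive if bias is None else bias.astype(dtype) + additive

causal = jnp.tril(jnp.ones((1, 1, 8, 8), dtype=bool))  # (B, N, T, S)
combined = combine_bias_and_mask(None, causal, jnp.float32)
```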
Copybara import of the project:
--
f99c152cb968205e205b6d7de867f09b3b275541 by cjkkkk <ske@nvidia.com>:
combine mask with bias
Merging this change closes #22078
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22078 from Cjkkkk:sdpa_combine_mask_bias f99c152cb968205e205b6d7de867f09b3b275541
PiperOrigin-RevId: 650059019