24 Commits

Author SHA1 Message Date
Shanbin Ke
ab9fc2d839 PR #22404: [cuDNN SDPA] fix bias sharding check and dbias calculation
Imported from GitHub PR https://github.com/google/jax/pull/22404

* Only check the bias batch/num_head sharding spec when it is present; both dims could be broadcast.
* The dbias calculation is incorrect under SPMD; an all_reduce is required (see the sketch below).
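
A minimal sketch of the second point, assuming a batch-sharded input and a bias that is broadcast over the batch (this is not the actual kernel wiring; the mesh axis name 'data' and the toy loss are placeholders):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical sketch: each device sees only its local batch slice, so the
# local bias gradient is partial and must be all-reduced (psum) over the
# data-parallel mesh axis before it matches the unsharded gradient.
mesh = Mesh(jax.devices(), axis_names=('data',))

def local_loss(x, bias):
    return jnp.sum((x + bias) ** 2)

@jax.jit
def dbias_with_allreduce(x, bias):
    def shard_fn(x_shard, bias_replicated):
        grad_bias = jax.grad(local_loss, argnums=1)(x_shard, bias_replicated)
        # Without this psum the result would only be the local partial gradient.
        return jax.lax.psum(grad_bias, axis_name='data')
    return shard_map(shard_fn, mesh=mesh,
                     in_specs=(P('data'), P()), out_specs=P())(x, bias)

# Usage (assumes the batch dim divides the device count):
x = jnp.ones((8, 4))
bias = jnp.zeros((4,))
dbias = dbias_with_allreduce(x, bias)
```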
Copybara import of the project:

--
cb81b80626bcf17db875bad5526cd2f24c989049 by cjkkkk <ske@nvidia.com>:

fix sharding

Merging this change closes #22404

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22404 from Cjkkkk:fix_bias_sharding_dbias_all_reduce cb81b80626bcf17db875bad5526cd2f24c989049
PiperOrigin-RevId: 653335832
2024-07-17 13:07:15 -07:00
Sebastian Bodenstein
c9534b315e Raise NotImplementedError instead of assert for unsupported Q dtype in fused attention.
This currently causes incorrect behaviour in jax.nn.dot_product_attention: it should raise a catchable error rather than fail with an assert.
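
An illustration of the pattern (the dtype set shown is an assumption, not the exact list checked by the fused-attention path):

```python
import jax.numpy as jnp

# Illustrative only: the supported query dtypes here are assumed.
_SUPPORTED_QUERY_DTYPES = (jnp.float16, jnp.bfloat16)

def check_query_dtype(query):
    if query.dtype not in _SUPPORTED_QUERY_DTYPES:
        # A NotImplementedError can be caught by callers and by dispatch logic,
        # whereas a bare assert just aborts with an opaque failure.
        raise NotImplementedError(
            f"Fused attention does not support query dtype {query.dtype}; "
            f"expected one of {_SUPPORTED_QUERY_DTYPES}.")
```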

PiperOrigin-RevId: 650621750
2024-07-09 07:37:13 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While there are various variants, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, newer implementations like the Flash Attention algorithm aim to improve utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-based large language models.
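
For reference, the core scaled-dot-product-attention computation that these variants share can be sketched in a few lines of JAX (a naive reference, not the fused cuDNN kernel):

```python
import jax
import jax.numpy as jnp

def reference_sdpa(q, k, v, bias=None):
    # Naive reference of the shared workflow: softmax(q k^T / sqrt(d) + bias) v.
    # Layout here is (batch, num_heads, seq_len, head_dim); bias is an optional
    # additive term broadcastable to the attention logits.
    d = q.shape[-1]
    logits = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(jnp.float32(d))
    if bias is not None:
        logits = logits + bias
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('bhqk,bhkd->bhqd', weights, v)
```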

This PR introduces a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise it falls back to a pure-JAX implementation.
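
A hedged usage sketch of the new API; the keyword names follow the current public `jax.nn.dot_product_attention` signature and may differ slightly from the exact version introduced in this PR:

```python
import jax
import jax.numpy as jnp

# Shapes: (batch, seq_len, num_heads, head_dim).
B, T, N, H = 2, 128, 8, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

# With the default implementation=None, the cuDNN flash attention path is used
# when the configuration is compatible; otherwise JAX falls back to the
# pure-XLA implementation. Passing implementation='cudnn' forces the cuDNN path.
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)
```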

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
Jake VanderPlas
f851099649 Fix mypy error 2024-07-07 19:45:50 -07:00
Shanbin Ke
ceea8dc3fe PR #22078: [cuDNN SDPA] combine mask with bias to support public SDPA API
Imported from GitHub PR https://github.com/google/jax/pull/22078

* cuDNN SDPA no longer supports a mask input, so we combine the bias and mask manually to align with the public SDPA API design (see the sketch below).
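
A minimal sketch of the combination (the helper name and the additive-mask convention are assumptions for illustration):

```python
import jax.numpy as jnp

def combine_bias_and_mask(bias, mask, dtype):
    # Fold a boolean mask into the additive bias so the kernel only needs a
    # single bias tensor. Masked-out positions get a large negative value so
    # their attention weights vanish after the softmax.
    if mask is None:
        return bias
    mask_bias = jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)
    return mask_bias if bias is None else bias.astype(dtype) + mask_bias
```

A real implementation may use a somewhat smaller constant than `finfo(dtype).min` to avoid overflow when an existing bias is added on top.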
Copybara import of the project:

--
f99c152cb968205e205b6d7de867f09b3b275541 by cjkkkk <ske@nvidia.com>:

combine mask with bias

Merging this change closes #22078

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22078 from Cjkkkk:sdpa_combine_mask_bias f99c152cb968205e205b6d7de867f09b3b275541
PiperOrigin-RevId: 650059019
2024-07-07 14:42:45 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
cjkkkk
f9586737dc init 2024-05-24 22:26:20 +00:00
Jake VanderPlas
329ab036ee CI: fix mypy error 2024-05-20 13:23:15 -07:00
Shanbin Ke
06d2e489eb Copybara import of the project:
--
f625317cc80639178882316df6f8775294adc6b7 by cjkkkk <ske@nvidia.com>:

init

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317cc80639178882316df6f8775294adc6b7
PiperOrigin-RevId: 635518631
2024-05-20 11:31:15 -07:00
kaixih
0489eee632 Support BNTH input formats 2024-04-03 20:48:37 +00:00
Cjkkkk
204ee7ff0b add is_training && fix seqlen/head_dim checks 2024-03-14 14:34:40 -07:00
Jake VanderPlas
85f205bdc7 typing: fix incorrect tuple annotations 2024-02-26 10:53:19 -08:00
Benjamin Chetioui
5da43a4c55 [XLA:GPU] Fix misspelled cuDNN custom call targets.
PiperOrigin-RevId: 609024769
2024-02-21 09:35:03 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Cjkkkk
916e53a8a2 add keyword-only argument & fix scale issue 2024-02-09 09:05:09 -08:00
Cjkkkk
59307e9625 add jax.cudnn & add check for bias/mask sharding 2024-02-09 09:05:09 -08:00
Cjkkkk
49f1537f98 rename tests with more descriptive name & Unify SDPA API 2024-02-09 09:05:09 -08:00
Cjkkkk
40eb11bc79 replace pjit with jit and only allow shardings on batch/head dim 2024-02-09 09:05:08 -08:00
Cjkkkk
5708fb955b address some format issues 2024-02-09 09:05:08 -08:00
Cjkkkk
6957d26dd3 add newline 2024-02-09 09:05:08 -08:00
Cjkkkk
145cbb55d8 add license 2024-02-09 09:05:08 -08:00
Cjkkkk
c49a770a6b add __init__.py 2024-02-09 09:05:08 -08:00
Cjkkkk
2d346149de fix test 2024-02-09 09:05:08 -08:00
Cjkkkk
9b8a100039 add unit test and move to _src/cudnn dir 2024-02-09 09:05:08 -08:00