68 Commits

Author SHA1 Message Date
jax authors
74f1d887eb Merge pull request #28018 from Cjkkkk:disable_packed_layout_at_ampere
PiperOrigin-RevId: 748349568
2025-04-16 10:54:25 -07:00
cjkkkk
760d0e0e97 disable packed layout test on architectures older than Hopper 2025-04-14 20:33:30 +00:00
jax authors
ac285a138b Merge pull request #27685 from Cjkkkk:return_cudnn_sdpa_residual
PiperOrigin-RevId: 746397395
2025-04-11 03:51:40 -07:00
cjkkkk
d19a458f32 fix docstring 2025-04-08 22:48:47 +00:00
Dan Foreman-Mackey
5a3fc606d4 Deprecate public export of mlir.custom_call.
PiperOrigin-RevId: 744722183
2025-04-07 07:58:20 -07:00
Peter Hawkins
70485e31b9 Remove accidental exports jax.interpreters.mlir.{hlo,func_dialect}.
These are available via jax.extend.mlir.dialects.

No deprecation period because jax.interpreters.mlir is not a stable API.

PiperOrigin-RevId: 744712537
2025-04-07 07:20:24 -07:00
cjkkkk
5e0ccb40d6 add option to expose attention residual to user 2025-04-02 22:55:58 +00:00
Shu Wang
aaa3ebfb8a Add optimization barrier. 2025-03-31 12:05:30 -05:00
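For reference, JAX exposes this primitive publicly as jax.lax.optimization_barrier; a minimal generic usage sketch (not this commit's actual change):

    import jax
    import jax.numpy as jnp

    def f(x):
        y = x * 2.0
        # The barrier is an identity on values but stops XLA from fusing
        # or reordering computation across this point.
        y = jax.lax.optimization_barrier(y)
        return y + 1.0

    print(jax.jit(f)(jnp.ones(4)))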
shuw
c7d04cc75a Improve based on review 2 2025-03-27 05:09:25 +00:00
shuw
1fe24ca755 Improve based on review 1 2025-03-20 23:26:21 +00:00
shuw
549c669451 Straight-through estimator for nvfp4 2025-03-20 19:30:14 +00:00
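A straight-through estimator quantizes on the forward pass but treats the quantizer as the identity when differentiating. A minimal sketch with jax.custom_vjp, using a hypothetical coarse-rounding stand-in for the actual nvfp4 quantizer:

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def ste_quantize(x):
        # Hypothetical stand-in for nvfp4 quantization: round to a coarse grid.
        return jnp.round(x * 4.0) / 4.0

    def ste_fwd(x):
        return ste_quantize(x), None

    def ste_bwd(_, g):
        # Straight-through: pass the cotangent through unchanged, as if
        # quantization were the identity function.
        return (g,)

    ste_quantize.defvjp(ste_fwd, ste_bwd)

    # Gradient is all ones despite the non-differentiable rounding.
    print(jax.grad(lambda x: ste_quantize(x).sum())(jnp.linspace(-1.0, 1.0, 5)))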
shuw
f9aef8a189 Support nvfp4 2025-03-11 19:33:25 +00:00
jax authors
f3b2c84126 Merge pull request #26627 from Cjkkkk:remove_fmha_rewriter
PiperOrigin-RevId: 733690769
2025-03-05 05:20:25 -08:00
jax authors
c7ca35fe32 Merge pull request #26345 from wenscarl:scaled_matmul
PiperOrigin-RevId: 731865430
2025-02-27 14:24:48 -08:00
Shu Wang
7f0a5bc83e Add apache header. 2025-02-26 15:26:56 -06:00
shuw
17088e9025 Improve after review # 2 2025-02-26 04:48:25 +00:00
shuw
bfb9d3ca4b Improve based on comment # 1 2025-02-21 17:32:57 +00:00
cjkkkk
3a80080392 fix unit tests to not use fmha rewriter 2025-02-20 00:41:04 +00:00
Shu Wang
ae111f7c97 Rename custom-call name. 2025-02-19 16:46:44 -06:00
Shu Wang
4a395956cb Improve comments. 2025-02-13 09:14:44 -06:00
shuw
332af58765 block_scale_config 2025-02-13 04:35:06 +00:00
shuw
061d4acbfb Scaled matmul for mxfp8 2025-02-05 23:25:51 +00:00
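The reference semantics of a block-scaled matmul are dequantize-then-dot. mxfp8 actually scales 32-element blocks along the contraction dimension; this per-row/per-column simplification only conveys the idea (a sketch, not the fused kernel):

    import jax.numpy as jnp

    def scaled_matmul_reference(a_q, a_scale, b_q, b_scale):
        # a_q: (M, K) quantized values, a_scale: (M, 1) scales
        # b_q: (K, N) quantized values, b_scale: (1, N) scales
        # A fused kernel applies the scales inside the matmul instead of
        # materializing the dequantized operands.
        return jnp.dot(a_q * a_scale, b_q * b_scale)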
cjkkkk
553199e4dc disable head dim 256 in the backward pass for now 2025-02-05 22:07:48 +00:00
cjkkkk
8c4d6d6903 fix lint 2025-02-03 06:09:05 +00:00
cjkkkk
ba6b1fdd09 address lint and typecheck 2025-01-30 22:12:26 +00:00
cjkkkk
bf4c3a77da address comments 2025-01-27 20:04:50 +00:00
cjkkkk
28b642aa22 add segment packing 2025-01-09 19:25:51 +00:00
kaixih
307ea87a8d support head size of 256
Test large head size only on Hopper+ GPUs

Test large head size only on cuDNN 9.5+
2024-12-19 18:38:06 +00:00
wenscarl
c67b651314 Support FP8 for dot_product_attention 2024-12-13 17:56:16 +00:00
Jake VanderPlas
6541a62099 jax.core: deprecate a number of APIs 2024-12-10 11:11:32 -08:00
jax authors
95029abc18 drop compute capability check
PiperOrigin-RevId: 700052796
2024-11-25 11:42:56 -08:00
jax authors
423cd2ad5e Simplified conditional in flash attention.
PiperOrigin-RevId: 691972341
2024-10-31 16:28:11 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
kaixih
d29a757e30 fix bwd batcher for unsupported dbias 2024-09-23 17:43:25 +00:00
Ilia Sergachev
85d792a92d Add cudnn_fusion decorator lowering computations to XLA cuDNN fusions. 2024-09-05 01:25:54 +02:00
Sergei Lebedev
ccabd21084 Fixed rules where `sliding_window_length` was not forwarded
This is follow up to #23284.

PiperOrigin-RevId: 670531634
2024-09-03 06:24:01 -07:00
jax authors
cf936a6d20 Merge pull request #23284 from Cjkkkk:sliding_window_attention
PiperOrigin-RevId: 670160204
2024-09-02 03:54:53 -07:00
cjkkkk
a64b9a543e add sliding window attn 2024-08-30 22:43:05 +00:00
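The masking idea behind sliding-window attention: each query position attends only to keys within a trailing window. A minimal sketch of such a mask, independent of the cuDNN path:

    import jax.numpy as jnp

    def sliding_window_mask(seq_len, window):
        # Query i may attend to keys j with i - window < j <= i,
        # i.e. causal attention restricted to a trailing window.
        i = jnp.arange(seq_len)[:, None]
        j = jnp.arange(seq_len)[None, :]
        return (j <= i) & (j > i - window)

    print(sliding_window_mask(6, 3).astype(int))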
jax authors
2785a08ca9 Merge pull request #22882 from wenscarl:attn_layout_fix
PiperOrigin-RevId: 668636119
2024-08-28 15:44:23 -07:00
Jake VanderPlas
68be5b5085 CI: update ruff to v0.6.1 2024-08-27 14:54:11 -07:00
shuw
8105930a94 Add test 2024-08-08 19:13:30 +00:00
shuw
4d9d622dda Fix check_is_flash_attention 2024-08-05 15:04:31 -07:00
kaixih
6ff6501aa2 Init commit 2024-08-01 19:39:34 +00:00
cjkkkk
c8b474e1f2 add sm86/sm89 2024-07-24 21:12:37 +00:00
Shanbin Ke
ab9fc2d839 PR #22404: [cuDNN SDPA] fix bias sharding check and dbias calculation
Imported from GitHub PR https://github.com/google/jax/pull/22404

* only check bias batch/num_head sharding spec if present; both dims could be broadcast.
* dbias calculation is incorrect under SPMD and an all_reduce is required.
Copybara import of the project:

--
cb81b80626bcf17db875bad5526cd2f24c989049 by cjkkkk <ske@nvidia.com>:

fix sharding

Merging this change closes #22404

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22404 from Cjkkkk:fix_bias_sharding_dbias_all_reduce cb81b80626bcf17db875bad5526cd2f24c989049
PiperOrigin-RevId: 653335832
2024-07-17 13:07:15 -07:00
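The gist of the dbias fix: under SPMD the bias is replicated while activations are sharded over the batch, so each device's locally computed bias gradient must be summed across the mesh. A minimal sketch, assuming an illustrative mesh axis named "batch" (to be called inside shard_map/pmap):

    import jax

    def dbias_all_reduce(local_dbias, axis_name="batch"):
        # Each device holds the bias gradient from its local batch shard;
        # the true dbias is the sum of the per-shard contributions.
        return jax.lax.psum(local_dbias, axis_name)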
Sebastian Bodenstein
c9534b315e Raise NotImplementedError instead of assert for unsupported Q dtype in fused attention.
This currently causes incorrect behaviour for jax.nn.dot_product_attention: it should raise an error rather than failing with an assert.

PiperOrigin-RevId: 650621750
2024-07-09 07:37:13 -07:00
Kaixi Hou
df6080f346 PR #21371: [NVIDIA] Add new SDPA API to jax.nn
Imported from GitHub PR https://github.com/google/jax/pull/21371

Attention plays a crucial role in modern transformer-based models. While many variants exist, they generally follow the same workflow. Examples include the typical multi-head attention (MHA), grouped-query attention (GQA), and multi-query attention (MQA). Additionally, new implementations like the Flash Attention algorithm aim to improve the utilization of accelerator devices. For instance, NVIDIA cuDNN supports Flash Attention and, through its API, can deliver a 1.3x end-to-end speedup when training GPT-style large language models.

This PR proposes introducing a new API in the `jax.nn` module to handle attention. It first tries the cuDNN flash attention execution path when the configuration is compatible; otherwise it falls back to a JAX implementation.

cc. @nluehr @Cjkkkk @cliffwoolley

Copybara import of the project:

--
39a11d91632aab1af5aeec1e92990a7aaeea0cca by kaixih <kaixih@nvidia.com>:

Add new SDPA API to jax.nn

Merging this change closes #21371

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21371 from kaixih:jax_sdpa_dev 39a11d91632aab1af5aeec1e92990a7aaeea0cca
PiperOrigin-RevId: 650225872
2024-07-08 06:16:04 -07:00
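The API landed as jax.nn.dot_product_attention. A brief usage sketch, assuming the (batch, seq_len, num_heads, head_dim) layout the function documents:

    import jax
    import jax.numpy as jnp

    B, T, N, H = 2, 128, 8, 64  # batch, sequence, heads, head dim
    keys = jax.random.split(jax.random.PRNGKey(0), 3)
    q, k, v = (jax.random.normal(kk, (B, T, N, H), jnp.bfloat16) for kk in keys)

    # implementation=None lets JAX pick the cuDNN flash attention path when
    # the configuration is compatible and fall back to XLA otherwise;
    # implementation="cudnn" forces the cuDNN path.
    out = jax.nn.dot_product_attention(q, k, v, is_causal=True, implementation=None)
    print(out.shape)  # (2, 128, 8, 64)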
Jake VanderPlas
f851099649 Fix mypy error 2024-07-07 19:45:50 -07:00
Shanbin Ke
ceea8dc3fe PR #22078: [cuDNN SDPA] combine mask with bias to support public SDPA API
Imported from GitHub PR https://github.com/google/jax/pull/22078

* cuDNN SDPA no longer supports a mask input, so we combine the bias and mask manually to align with the public SDPA API design.
Copybara import of the project:

--
f99c152cb968205e205b6d7de867f09b3b275541 by cjkkkk <ske@nvidia.com>:

combine mask with bias

Merging this change closes #22078

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22078 from Cjkkkk:sdpa_combine_mask_bias f99c152cb968205e205b6d7de867f09b3b275541
PiperOrigin-RevId: 650059019
2024-07-07 14:42:45 -07:00
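A minimal sketch of the combining idea (the actual cuDNN-path code differs): fold a boolean mask, where True means "attend", into the additive bias so masked positions get approximately zero softmax weight:

    import jax.numpy as jnp

    def combine_mask_into_bias(bias, mask):
        # Masked-out positions receive a large negative bias, driving
        # their post-softmax attention weights to (approximately) zero.
        neg_inf = jnp.finfo(bias.dtype).min
        return jnp.where(mask, bias, neg_inf)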
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00