68 Commits

Author SHA1 Message Date
cjkkkk
f9586737dc init 2024-05-24 22:26:20 +00:00
Jake VanderPlas
329ab036ee CI: fix mypy error 2024-05-20 13:23:15 -07:00
Shanbin Ke
06d2e489eb Copybara import of the project:
--
f625317cc80639178882316df6f8775294adc6b7 by cjkkkk <ske@nvidia.com>:

init

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317cc80639178882316df6f8775294adc6b7
PiperOrigin-RevId: 635518631
2024-05-20 11:31:15 -07:00
kaixih
0489eee632 Support BNTH input formats 2024-04-03 20:48:37 +00:00
Cjkkkk
204ee7ff0b add is_training && fix seqlen/head_dim checks 2024-03-14 14:34:40 -07:00
Jake VanderPlas
85f205bdc7 typing: fix incorrect tuple annotations 2024-02-26 10:53:19 -08:00
Benjamin Chetioui
5da43a4c55 [XLA:GPU] Fix misspelled cuDNN custom call targets.
PiperOrigin-RevId: 609024769
2024-02-21 09:35:03 -08:00
jax authors
7b05bbdda0 Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
2024-02-12 16:11:37 -08:00
Cjkkkk
916e53a8a2 add keyword-only argument & fix scale issue 2024-02-09 09:05:09 -08:00
Cjkkkk
59307e9625 add jax.cudnn & add check for bias/mask sharding 2024-02-09 09:05:09 -08:00
Cjkkkk
49f1537f98 rename tests with more descriptive name & Unify SDPA API 2024-02-09 09:05:09 -08:00
Cjkkkk
40eb11bc79 replace pjit with jit and only allow shardings on batch/head dim 2024-02-09 09:05:08 -08:00
Cjkkkk
5708fb955b address some format issues 2024-02-09 09:05:08 -08:00
Cjkkkk
6957d26dd3 add newline 2024-02-09 09:05:08 -08:00
Cjkkkk
145cbb55d8 add license 2024-02-09 09:05:08 -08:00
Cjkkkk
c49a770a6b add __init__.py 2024-02-09 09:05:08 -08:00
Cjkkkk
2d346149de fix test 2024-02-09 09:05:08 -08:00
Cjkkkk
9b8a100039 add unit test and move to _src/cudnn dir 2024-02-09 09:05:08 -08:00