cjkkkk
|
f9586737dc
|
init
|
2024-05-24 22:26:20 +00:00 |
|
Jake VanderPlas
|
329ab036ee
|
CI: fix mypy error
|
2024-05-20 13:23:15 -07:00 |
|
Shanbin Ke
|
06d2e489eb
|
Copybara import of the project:
--
f625317cc80639178882316df6f8775294adc6b7 by cjkkkk <ske@nvidia.com>:
init
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317cc80639178882316df6f8775294adc6b7
PiperOrigin-RevId: 635518631
|
2024-05-20 11:31:15 -07:00 |
|
kaixih
|
0489eee632
|
Support BNTH input formats
|
2024-04-03 20:48:37 +00:00 |
|
Cjkkkk
|
204ee7ff0b
|
add is_training && fix seqlen/head_dim checks
|
2024-03-14 14:34:40 -07:00 |
|
Jake VanderPlas
|
85f205bdc7
|
typing: fix incorrect tuple annotations
|
2024-02-26 10:53:19 -08:00 |
|
Benjamin Chetioui
|
5da43a4c55
|
[XLA:GPU] Fix misspelled cuDNN custom call targets.
PiperOrigin-RevId: 609024769
|
2024-02-21 09:35:03 -08:00 |
|
jax authors
|
7b05bbdda0
|
Merge pull request #18814 from Cjkkkk:spda
PiperOrigin-RevId: 606397276
|
2024-02-12 16:11:37 -08:00 |
|
Cjkkkk
|
916e53a8a2
|
add keyword-only argument & fix scale issue
|
2024-02-09 09:05:09 -08:00 |
|
Cjkkkk
|
59307e9625
|
add jax.cudnn & add check for bias/mask sharding
|
2024-02-09 09:05:09 -08:00 |
|
Cjkkkk
|
49f1537f98
|
rename tests with more descriptive name & Unify SDPA API
|
2024-02-09 09:05:09 -08:00 |
|
Cjkkkk
|
40eb11bc79
|
replace pjit with jit and only allow shardings on batch/head dim
|
2024-02-09 09:05:08 -08:00 |
|
Cjkkkk
|
5708fb955b
|
address some format issues
|
2024-02-09 09:05:08 -08:00 |
|
Cjkkkk
|
6957d26dd3
|
add newline
|
2024-02-09 09:05:08 -08:00 |
|
Cjkkkk
|
145cbb55d8
|
add license
|
2024-02-09 09:05:08 -08:00 |
|
Cjkkkk
|
c49a770a6b
|
add __init__.py
|
2024-02-09 09:05:08 -08:00 |
|
Cjkkkk
|
2d346149de
|
fix test
|
2024-02-09 09:05:08 -08:00 |
|
Cjkkkk
|
9b8a100039
|
add unit test and move to _src/cudnn dir
|
2024-02-09 09:05:08 -08:00 |
|