mirrors/rocm_jax
Mirror of https://github.com/ROCm/jax.git, synced 2025-04-25 03:06:04 +00:00
rocm_jax/jax/experimental/pallas
Latest commit 1c796c0ff4 (Sharad Vikram, 2023-10-02 17:04:15 -07:00):
[Pallas] Automatically turn mesh indices -> physical ids for remote DMAs
...
PiperOrigin-RevId: 570221510
ops/: Allow head_dim <= 128 in Pallas:TPU flash attention implementation (2023-09-25 11:25:29 -07:00)
__init__.py: [Pallas] Refactor memory space handling (2023-09-07 17:08:57 -07:00)
gpu.py: [Pallas] Upstream pallas to JAX (2023-08-01 16:43:13 -07:00)
tpu.py: [Pallas] Automatically turn mesh indices -> physical ids for remote DMAs (2023-10-02 17:04:15 -07:00)