mirrors/rocm_jax
Mirror of https://github.com/ROCm/jax.git (last synced 2025-04-16 11:56:07 +00:00)
rocm_jax/jax/_src/pallas
Latest commit c36e1f7c1a by Dougal Maclaurin, 2024-10-29 11:04:31 -07:00:
Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
...
PiperOrigin-RevId: 691086496
mosaic
Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
2024-10-29 11:04:31 -07:00
mosaic_gpu
[Pallas:MGPU] Add FlashAttention3 as an example
2024-10-29 05:21:43 -07:00
triton
Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here:
2024-10-25 10:35:25 -07:00
__init__.py
[Pallas] Upstream pallas to JAX
2023-08-01 16:43:13 -07:00
BUILD
Clean up BUILD files.
2024-08-26 09:11:17 -07:00
core.py
Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
2024-10-29 11:04:31 -07:00
pallas_call.py
Implements an alternate version of ragged_attention, wherein, the actual attention kernel itself is dense. Meaning, this kernel does not have the compute saving (@when wrapped kernel) or prefetch/index skipping (via index rewriting) as part of the kernel. Rather, the kernel is invoked with a Jumble (A ragged type representation) and pallas takes care of applying the correct work skipping and index rewriting.
2024-10-25 12:07:34 -07:00
primitives.py
Make trace dispatch purely a function of context rather than a function of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.
2024-10-29 11:04:31 -07:00
utils.py
[Pallas TPU] Add lowering for 64 bit
2024-09-19 16:42:45 +01:00