mirrors/rocm_jax
Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 04:46:06 +00:00)
rocm_jax/jax/_src/interpreters
Latest commit 84156f359f by Yash Katariya, 2024-03-28 18:20:32 -07:00
Add identity jit tests to go from pinned_host -> device and vice versa
PiperOrigin-RevId: 620114420
__init__.py
Reapply: move jax.interpreters.ad to jax._src.interpreters.ad
2023-02-02 09:29:05 -08:00

ad.py
[xmap-removal] remove reduce_axes from grad / vjp / backward_pass
2024-02-25 15:50:54 -08:00

batching.py
Reland https://github.com/google/jax/pull/10573.
2024-02-16 05:57:12 -08:00

mlir.py
Add identity jit tests to go from pinned_host -> device and vice versa
2024-03-28 18:20:32 -07:00

partial_eval.py
add forwarding optimization test for shard_map
2024-03-15 15:11:16 -07:00

pxla.py
Allow allow_spmd_propagation_to_output to be generated for outputs annotated with pjit.AUTO
2024-03-27 12:04:03 -07:00

xla.py
Raise a better error message when an invalid input is passed to jit call.
2024-03-21 17:46:32 -07:00