mirrors/rocm_jax
Mirror of https://github.com/ROCm/jax.git, last synced 2025-04-19 05:16:06 +00:00.

rocm_jax/jax/interpreters

Latest commit e21c29476d by Yash Katariya (2023-01-12 21:16:59 -08:00):
Add batch_jaxpr2 which tells the caller where batch dims are.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 501746795

__init__.py (2022-09-22 12:27:19 -07:00)
  Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".

ad.py (2022-12-20 14:49:27 -08:00)
  Move jax.linear_util to jax._src.linear_util

batching.py (2023-01-12 21:16:59 -08:00)
  Add batch_jaxpr2 which tells the caller where batch dims are.

mlir.py (2023-01-11 13:02:48 +01:00)
  [jax2tf] Fixed the shape-polymorphic lowering for lax.pad and dynamic_slice

partial_eval.py (2023-01-11 09:30:49 -08:00)
  add pjit partial_eval_jaxpr_custom rule

pxla.py (2023-01-12 17:24:32 -08:00)
  Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)

xla.py (2022-12-27 15:42:49 -08:00)
  Error on numpy masked array inputs.