mirrors/rocm_jax
Mirror of https://github.com/ROCm/jax.git (synced 2025-04-24 15:56:09 +00:00)
Path: rocm_jax/jax/interpreters
Latest commit 864d640ee1 by Yash Katariya (2023-01-22 14:07:03 -08:00): Set committed=True for nested pjits/with_sharding_constraint if any jaxpr_sharding is not UNSPECIFIED. PiperOrigin-RevId: 503833657
__init__.py (2022-09-22 12:27:19 -07:00): Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
ad.py (2022-12-20 14:49:27 -08:00): Move jax.linear_util to jax._src.linear_util
batching.py (2023-01-12 21:16:59 -08:00): Add batch_jaxpr2 which tells the caller where batch dims are.
mlir.py (2023-01-13 18:12:12 -08:00): Expose fp8 in jax dtypes and mlir builder.
partial_eval.py (2023-01-20 18:04:26 -08:00): Add forwarding support to pjit which was introduced as an optimization. The inputs that are forwarded to outputs are pruned from the outputs of a known_jaxpr.
pxla.py (2023-01-22 14:07:03 -08:00): Set committed=True for nested pjits/with_sharding_constraint if any jaxpr_sharding is not UNSPECIFIED.
xla.py (2022-12-27 15:42:49 -08:00): Error on numpy masked array inputs.