rocm_jax/jax/experimental
Yash Katariya d02f28199b Clean up pjit after jax.Array
* Remove `{in|out}_positional_semantics` from `pjit_p.bind`
* Remove `in_is_global` from `lower_sharding_computation`
* Remove `local_to_global` and `global_to_local`
* Clean up arguments of `sharded_lowering` that are no longer needed

PiperOrigin-RevId: 517469390
2023-03-17 11:53:00 -07:00