Yash Katariya bcfe95e98e
Initial integration of sharding in types in JAX. For now, only n-ary ops are supported, and sharding propagation is forward-only. This functionality is experimental and hidden behind the jax_sharding_in_types config flag.
More improvements and clarifications of the semantics will follow as it is integrated further into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00