Yash Katariya 34611be53d Add sharding rules to some more primitives so that the backward pass of minformer passes. There are several changes here:
* Handled transpose of `dot_general` correctly with shardings
* Handled transpose of `reduce_sum` correctly with shardings
* `ShapedArray.to_tangent_aval` now sets the sharding of the tangent (not handling unreduced yet).
* `ConcreteArray.aval` correctly sets the sharding which is extracted from the `val` attribute.
* (Paired with Dougal!) Added sharding rule for `reshape_p` only when singleton dims are added/removed.
* Added sharding rule for `select_n_p` because it gets called during `jax.grad` of minformer.
* Added `sharding` attribute to `broadcast_in_dim` because we need to provide the correct sharding to it during `full` and transpose of `reduce_sum`.
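
To illustrate the `reshape_p` item above (a sharding rule only for reshapes that add or remove singleton dims), here is a minimal pure-Python sketch. It is not JAX's implementation: the function name `reshape_sharding_spec` and the representation of a sharding as a tuple of mesh axis names (with `None` for an unsharded dim) are assumptions made for illustration.

```python
def reshape_sharding_spec(in_shape, in_spec, out_shape):
    """Hypothetical helper: carry a per-dimension sharding spec through a
    reshape that only inserts or drops size-1 dimensions.

    in_spec holds one entry per input dim: a mesh axis name or None.
    Any other kind of reshape is rejected rather than guessed at.
    """
    if len(in_shape) != len(in_spec):
        raise ValueError("in_spec must have one entry per input dimension")

    out_spec = []
    i = j = 0
    while i < len(in_shape) and j < len(out_shape):
        if in_shape[i] == out_shape[j]:
            # Same dimension on both sides: keep its sharding.
            out_spec.append(in_spec[i])
            i += 1
            j += 1
        elif in_shape[i] == 1:
            # A singleton input dim was removed: drop its entry.
            i += 1
        elif out_shape[j] == 1:
            # A singleton output dim was added: it is unsharded.
            out_spec.append(None)
            j += 1
        else:
            raise NotImplementedError(
                "only reshapes that add/remove singleton dims are supported")

    # Whatever is left over on either side must consist of singleton dims.
    if any(d != 1 for d in in_shape[i:]) or any(d != 1 for d in out_shape[j:]):
        raise NotImplementedError(
            "only reshapes that add/remove singleton dims are supported")
    out_spec.extend([None] * (len(out_shape) - j))
    return tuple(out_spec)


# Adding a singleton dim keeps the existing axes and leaves the new dim unsharded.
print(reshape_sharding_spec((8, 16), ("x", "y"), (8, 1, 16)))
# Removing it again recovers the original spec.
print(reshape_sharding_spec((8, 1, 16), ("x", None, "y"), (8, 16)))
```

Rejecting every other reshape, for example `(8, 16) -> (128,)`, mirrors the restriction stated in the commit message: once dims are merged or split, there is no single obvious sharding for the result.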

PiperOrigin-RevId: 689837320
2024-10-25 10:35:25 -07:00