Yash Katariya 958234a9c1

Thread the mesh context manager to the place where we recover `out_shardings` from `GSPMDSharding`s.

Previously, given a program like this:
```python
with mesh:
  out = pjit(lambda: 1)()
```

the sharding of `out` was a `GSPMDSharding`, which is not ideal because it does not carry the mesh or its axis names. This change fixes that: `out` now gets a `NamedSharding` instead.
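
A self-contained version of the example above, offered as a sketch. It assumes a JAX release that still exports `pjit` from `jax.experimental.pjit` and still accepts a `Mesh` as a context manager; the mesh axis name `'x'` is an arbitrary choice. It prints which `Sharding` subclass `out` carries. On a release that includes this change that should be `NamedSharding` rather than `GSPMDSharding`, though the exact printed form may differ between JAX versions.

```python
import jax
import numpy as np
from jax.experimental.pjit import pjit  # assumption: still importable in your JAX version
from jax.sharding import Mesh

# Build a 1-D mesh over every available device; the axis name 'x' is arbitrary.
mesh = Mesh(np.array(jax.devices()), ('x',))

with mesh:
    # No out_shardings are passed, so pjit must recover the output sharding
    # from the one the compiler chose.
    out = pjit(lambda: 1)()

# Report which Sharding subclass the output ended up with.
print(type(out.sharding).__name__, out.sharding)
```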

This is also required for `Shardy` integration.

PiperOrigin-RevId: 658842350
2024-08-02 11:04:48 -07:00