But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**.
PiperOrigin-RevId: 744888557
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.
Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
Once the compiler supports it, we can make it public and allow users to pass those values. Right now, only `major_to_minor` is supported.
But a valid question is why even keep them as arguments in the constructor?
It's because we need to translate `PjRtLayout` which we get from the executable to `DeviceLocalLayout` and preserve the `tiling` and `sub_byte_element_size_in_bits` info that we get from the compiler. This has helped catch bugs before when the compiler was not doing the right thing in layout propagation pass.
PiperOrigin-RevId: 651644934
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.
memory space is exposed via JAX memories API so it doesn't have to be in the layout API.
Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.
Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.
PiperOrigin-RevId: 647487510
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.
Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
We expose 3 modes:
* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.
* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.
* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.
Public API coming soon.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036