21 Commits

Author SHA1 Message Date
Yash Katariya
0a72e856cf Add **experimental** with_dll_constraint API. This is for cases when the users wants to let SPMD decide the sharding.
But this is a contradiction since layouts apply to device local shape and without knowing the sharding, you can't decide the layout. But there are cases where you don't care what the sharding is, you just want to force a row-major layout (for example). **This API should only be used for those cases**.

PiperOrigin-RevId: 744888557
2025-04-07 16:21:58 -07:00
Jake VanderPlas
8948e6de58 sharding cleanup: use inline checks for unimplemented and auto 2024-10-25 04:22:40 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Yash Katariya
ff3dc0f5fb Add check_compatible_aval checks to Layout. It checks if len(major_to_minor) == len(aval.shape).
PiperOrigin-RevId: 651777179
2024-07-12 08:10:43 -07:00
Yash Katariya
ff18dedf99 Make tiling and sub_byte_element_size_in_bits private arguments of DeviceLocalLayout. This is because XLA does not respect the values passed to it.
Once the compiler supports it, we can make it public and allow users to pass those values. Right now, only `major_to_minor` is supported.

But a valid question is why even keep them as arguments in the constructor?

It's because we need to translate `PjRtLayout` which we get from the executable to `DeviceLocalLayout` and preserve the `tiling` and `sub_byte_element_size_in_bits` info that we get from the compiler. This has helped catch bugs before when the compiler was not doing the right thing in layout propagation pass.

PiperOrigin-RevId: 651644934
2024-07-11 22:06:02 -07:00
Yash Katariya
e1a496d3b6 Add concrete layout API to JAX. The API takes major_to_minor: tuple[int, ...] and tiling: tuple[tuple[int, ...], ...] as the arguments. Allows users to pass layouts to with_sharding_constraint to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.

memory space is exposed via JAX memories API so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
Sai-Suraj-27
5564521308 Prefer raising of TypeError for invalid types instead of ValueError. 2024-04-08 13:08:24 +05:30
Yash Katariya
c3f5af7d46 Delete deprecated AOT layouts API.
PiperOrigin-RevId: 622666838
2024-04-07 14:15:36 -07:00
Yash Katariya
5cbb26f36d Make device_local_layout and sharding optional in Layout. Also only accept Layout class to _in_layouts and _out_layouts.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
2024-04-03 18:37:32 -07:00
Yash Katariya
d790c88da9 Rename layout.AUTO to DeviceLocalLayout.AUTO
PiperOrigin-RevId: 621684185
2024-04-03 17:23:35 -07:00
Yash Katariya
92326dbc71 Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put.

Note: This currently only works on TPU.
PiperOrigin-RevId: 621668247
2024-04-03 16:13:31 -07:00
Yash Katariya
6557f680fd Rename SpecifiedLayout to DeviceLocalLayout
PiperOrigin-RevId: 620934348
2024-04-01 13:19:46 -07:00
Yash Katariya
25d01e983c [Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd

PiperOrigin-RevId: 618878870
2024-03-25 10:08:43 -07:00
jax authors
cd79e71d85 Reverts 0e092a77067dbbce33cfd6d54a46e743b779919b
PiperOrigin-RevId: 618127324
2024-03-22 03:46:09 -07:00
Yash Katariya
0e092a7706 Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
PiperOrigin-RevId: 618050680
2024-03-21 21:02:40 -07:00
Yue Sheng
291a5cd3e0 [PJRT][IFRT] Update PJRT, IFRT, and Py executable getters to return PjRtLayouts
PiperOrigin-RevId: 617889924
2024-03-21 10:30:57 -07:00
Yash Katariya
c8ef37507b Make the SpecifiedLayout class opaque.
Also need to enabling pickling to xc.Layout so that AOT serialization continues to work.

PiperOrigin-RevId: 583684299
2023-11-18 15:17:16 -08:00
Yash Katariya
439b89e47f Remove DefaultLayout and make None same as DefaultLayout
PiperOrigin-RevId: 583221970
2023-11-16 18:01:27 -08:00
Yash Katariya
5c3da219c0 Add a private API to allow setting layouts on jitted computations.
We expose 3 modes:

* `SpecifiedLayout`: User specifies the `minor_to_major` field of the layout. Tiling not exposed yet.

* `DefaultLayout`: PJRT chooses the layout. It defaults to the current behavior.

* `AUTO`: Compiler chooses the layout. This field is not a layout per se. It's a request to get the layout from the compiler. This field cannot be on an Array or other data types. It can only be on jit.

Public API coming soon.

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 582692036
2023-11-15 08:48:53 -08:00