87 Commits

Author SHA1 Message Date
Peter Hawkins
e02faabfb2 Replace references to jax.readthedocs.io with docs.jax.dev.
PiperOrigin-RevId: 745156931
2025-04-08 08:33:49 -07:00
Yash Katariya
cccc34dc23 Raise an error if the type passed to axis_types argument of Mesh and AbstractMesh is not jax.sharding.AxisType.
PiperOrigin-RevId: 744602037
2025-04-06 23:38:09 -07:00
Yash Katariya
5b3e419515 Add auto_axes, explicit_axes and manual_axes properties to Mesh and AbstractMesh
PiperOrigin-RevId: 743767895
2025-04-03 18:35:28 -07:00
Yash Katariya
c1904dc7eb Update the Mesh docstring to use computation-follows-data and the jax.jit APIs. Fixes https://github.com/jax-ml/jax/issues/27390
PiperOrigin-RevId: 740104692
2025-03-24 16:07:12 -07:00
Yash Katariya
663ef7ae01 Check the type of mesh in use_abstract_mesh and use_concrete_mesh
PiperOrigin-RevId: 738190879
2025-03-18 16:57:40 -07:00
Ayaka
9b0ace4a11 Support error checking in explicit mode
PiperOrigin-RevId: 737051146
2025-03-14 18:58:26 -07:00
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Peter Hawkins
6fa98fc0a4 Use "x is y" rather than "id(x) == id(y)".
The latter involves at least two object constructions.

PiperOrigin-RevId: 736878098
2025-03-14 08:54:46 -07:00
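
As a rough illustration of the point in the commit message above (this is not code from the commit), the sketch below compares the two spellings on a hypothetical `Node` class. `x is y` compares identity directly, while `id(x) == id(y)` calls `id` twice and creates two integer objects before comparing them.

```python
class Node:
    """Hypothetical class, used only for this illustration."""

    def __eq__(self, other):
        # Deliberately "equal to everything" to show that `is` ignores __eq__.
        return True


def same_object_slow(x, y):
    # Two id() calls, each returning an int object, then an int comparison.
    return id(x) == id(y)


def same_object_fast(x, y):
    # A single identity comparison; no intermediate objects.
    return x is y


a, b = Node(), Node()
```

Both helpers agree on every input; the second simply does less work per call.
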
Peter Hawkins
1507754408 Precompute the __hash__ of AbstractMesh.
We use this frequently and it saves time to precompute it.

PiperOrigin-RevId: 736650750
2025-03-13 15:01:31 -07:00
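
A minimal sketch of the precomputed-hash idea described above, assuming an immutable object whose fields never change after construction. `FrozenMeshKey` is a hypothetical class for illustration only and is not the actual `AbstractMesh` implementation.

```python
class FrozenMeshKey:
    """Hypothetical immutable key; not the actual AbstractMesh class."""

    __slots__ = ("axis_sizes", "axis_names", "_hash")

    def __init__(self, axis_sizes, axis_names):
        self.axis_sizes = tuple(axis_sizes)
        self.axis_names = tuple(axis_names)
        # Hash once at construction; safe only because the fields never change.
        self._hash = hash((self.axis_sizes, self.axis_names))

    def __hash__(self):
        # Every later lookup reuses the stored value.
        return self._hash

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, FrozenMeshKey):
            return NotImplemented
        return (self.axis_sizes, self.axis_names) == (
            other.axis_sizes, other.axis_names)


key = FrozenMeshKey((2, 4), ("x", "y"))
cache = {key: "compiled"}
```

The trade-off is one extra field per instance in exchange for constant-time hashing on every dictionary or cache lookup.
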
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
Yash Katariya
a4ca0dbc6c Change the signature of AbstractMesh to AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes an `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore.

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
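
The accepted forms of `axis_types` listed above can be sketched with a small normalizer. This is a hedged illustration, not JAX's code: the enum is a local stand-in that uses the later `AxisType` name and the Auto/Explicit/Manual members mentioned elsewhere in this log, `normalize_axis_types` is a hypothetical helper, and treating a single value as a one-element tuple is an assumption of the sketch.

```python
import enum


class AxisType(enum.Enum):
    # Local stand-in for illustration; not jax.sharding.AxisType.
    Auto = enum.auto()
    Explicit = enum.auto()
    Manual = enum.auto()


def normalize_axis_types(axis_types, axis_names):
    """Return one AxisType per axis name (hypothetical helper)."""
    if axis_types is None:
        # Default described elsewhere in this log: every axis is Auto.
        return (AxisType.Auto,) * len(axis_names)
    if isinstance(axis_types, AxisType):
        # Assumption of this sketch: a bare value becomes a 1-tuple.
        axis_types = (axis_types,)
    axis_types = tuple(axis_types)
    for t in axis_types:
        if not isinstance(t, AxisType):
            raise TypeError(f"expected AxisType, got {type(t).__name__}")
    if len(axis_types) != len(axis_names):
        raise ValueError("axis_types must have one entry per axis name")
    return axis_types
```

For example, `normalize_axis_types(None, ("x", "y"))` yields two `Auto` entries, while passing strings such as `("Auto",)` raises `TypeError`.
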
Yash Katariya
47480b4493 Add a set_mesh API to jax.sharding. set_mesh sets the sharding and never unsets it, i.e., this is just the __enter__ of a context manager without the __exit__.
PiperOrigin-RevId: 736261724
2025-03-12 14:12:47 -07:00
Matthew Johnson
7c2f842353 shard_map and other fixes to direct-linearize
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2025-03-07 21:02:40 +00:00
Emily Fertig
82124da5cd Redefine is_fully_addressable in shardings to support zero local devices for McJAX.
PiperOrigin-RevId: 731526750
2025-02-26 18:17:35 -08:00
Peter Hawkins
256e37af5f Port many uses of contextlib.contextdecorator to explicit context manager classes.
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.

PiperOrigin-RevId: 730939406
2025-02-25 10:31:05 -08:00
Yash Katariya
7c4fe2a7cc [sharding_in_types] Allow auto_axes and explicit_axes to take numpy arrays, python scalars.
PiperOrigin-RevId: 729729215
2025-02-21 18:49:02 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If an axis in shmap(..., auto=...) is an explicit axis in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
Yash Katariya
a3edfb43ef Now that sharding_in_types config flag is True, remove the config and all the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Yash Katariya
8bcbf585df Make device_put resharding on single device array input work under use_mesh. Fixes https://github.com/jax-ml/jax/issues/26552
PiperOrigin-RevId: 728382461
2025-02-18 15:22:39 -08:00
Bart Chrzaszcz
7bfa420d6a Cache shape property on Mesh.
This is already done on `AbstractMesh`. Should be done since `OrderedDict`s are expensive to create.

PiperOrigin-RevId: 727705113
2025-02-16 21:43:03 -08:00
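
A minimal sketch of caching a derived `shape` mapping, assuming the inputs are fixed after construction. `TinyMesh` is a hypothetical stand-in, not the real `Mesh` class, and `functools.cached_property` is one possible caching mechanism rather than necessarily the one used in the commit.

```python
import collections
import functools


class TinyMesh:
    """Hypothetical stand-in for a mesh; not the real jax.sharding.Mesh."""

    def __init__(self, axis_names, axis_sizes):
        self.axis_names = tuple(axis_names)
        self.axis_sizes = tuple(axis_sizes)

    @functools.cached_property
    def shape(self):
        # Built on first access, then stored on the instance and reused.
        return collections.OrderedDict(zip(self.axis_names, self.axis_sizes))


mesh = TinyMesh(("data", "model"), (4, 2))
```

Because every access returns the same `OrderedDict` object, callers should treat it as read-only.
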
Yash Katariya
0944e5202e Create _BaseMesh so that properties can be shared between Mesh and AbstractMesh without duplicating code
PiperOrigin-RevId: 726193613
2025-02-12 14:14:48 -08:00
Yash Katariya
d58c3a4722 [sharding_in_types] Fix some properties that assumed axis_types always existed.
PiperOrigin-RevId: 726187278
2025-02-12 13:57:19 -08:00
Yash Katariya
2d01df760b [sharding_in_types] Make the typing checks and sharding rule checks a little bit less strict when the current or aval mesh is empty/unset. Also some more changes as listed below:
* get_aval is not context dependent

* canonicalization does not happen for avals on an empty mesh

* jax.jit does not set abstract mesh context anymore before tracing

* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode

* Even if use_mesh is not used in explicit sharding mode, computation follows data works!

* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)

* The check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the avals has an empty mesh.

As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.

PiperOrigin-RevId: 726097292
2025-02-12 10:03:01 -08:00
Yash Katariya
b4b4a98db7 [sharding_in_types] When caching mesh with axis_types, make sure the data structure is (axis_size, axis_names, tuple(axis_types))
PiperOrigin-RevId: 726064530
2025-02-12 08:23:52 -08:00
Peter Hawkins
f21b0f03b4 Speed up NamedSharding construction.
* Compute the size of a mesh eagerly. We're almost always going to need this, because NamedSharding's constructor asks for it.
* Speed up mesh equality. It's likely we have only one mesh, and the identity equality test will hit. Do it first.
* Don't call _prepare_axis_resources in ParsedPartitionSpec construction. This does a bunch of pointless tree flattening and list manipulation but we know we have exactly one PartitionSpec and can directly do the check we need, which is _check_unique_resources.
* Only call _check_unique_resources on PartitionSpecs; it's easy to avoid doing it in other cases and then we don't need a bunch of isinstance checks.
* Avoid use of collections.Counter when checking for unique resources. collections.Counter has a surprisingly slow isinstance test.

PiperOrigin-RevId: 724431847
2025-02-07 12:20:51 -08:00
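
The last bullet above can be illustrated with a uniqueness check that uses a plain `set` instead of `collections.Counter`. This is a sketch under stated assumptions: `check_unique_resources` is a hypothetical helper, not JAX's `_check_unique_resources`, and each spec entry is assumed to be `None`, an axis name, or a tuple of axis names.

```python
def check_unique_resources(spec_entries):
    """Raise ValueError if any axis name appears more than once.

    Hypothetical helper; each entry is None, an axis name, or a tuple of names.
    """
    seen = set()
    for entry in spec_entries:
        if entry is None:
            continue
        names = entry if isinstance(entry, tuple) else (entry,)
        for name in names:
            if name in seen:
                raise ValueError(f"axis {name!r} is used more than once")
            seen.add(name)


check_unique_resources(("x", None, ("y", "z")))  # passes silently
try:
    check_unique_resources(("x", ("y", "x")))
    duplicate_detected = False
except ValueError:
    duplicate_detected = True
```

The loop stops at the first duplicate, so the common case of a short, valid spec costs only a few set insertions.
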
Yash Katariya
307006e194 Set the mesh as manual during partial_eval_custom in shard_map so that _add_reshapes happens under the correct mesh.
PiperOrigin-RevId: 723268798
2025-02-04 16:36:08 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can at least upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Yash Katariya
d28c3fa409 Replace Hidden/Visible/Collective AxisTypes names with Auto/Explicit/Manual.
PiperOrigin-RevId: 719561729
2025-01-24 23:21:13 -08:00
Yash Katariya
704b2e5fba [sharding_in_types] Make vmap work with shard_map + pallas
PiperOrigin-RevId: 718578207
2025-01-22 16:48:32 -08:00
Yash Katariya
695c02b1c4 [sharding_in_types] Rename sharding_cast to mesh_cast and add a few restrictions:
* mesh_cast only works when the axis types differ between the src and dst meshes. Hence the name!

* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.

* src and dst mesh axis_names and axis_sizes should be the same.

TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
2025-01-17 10:53:43 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Yash Katariya
c72ed260fe [sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by registering the sharding on the aval properly created by SDS in its pytype_aval_mapping.
Also, if we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
Yash Katariya
a817f532b4 [sharding_in_types] Introduce auto_mode, user_mode, auto_mode_ctx and user_mode_ctx as **private** APIs to make writing auto/user sharding in types code way easier and noise-free.
These can be made public in the future under different names.

PiperOrigin-RevId: 714169304
2025-01-10 14:14:25 -08:00
Sergei Lebedev
90201ce2b7 Removed leftover mentions of xmap from the code
PiperOrigin-RevId: 713202387
2025-01-08 01:39:38 -08:00
Yash Katariya
e854f1657a Allow P.UNCONSTRAINED in out_shardings at top level jit. This is required for sharding in types to work properly when out_avals contain UNCONSTRAINED specs.
This also simplifies the `impl` rule of `sharding_cast`.

PiperOrigin-RevId: 707349491
2024-12-17 19:18:24 -08:00
Yash Katariya
473e2bf527 Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrip.
Also allow users to enter into `Auto`/`User` mode inside jit along all or some axes.

Add checks to make sure that avals inside a context match the surrounding context. This check happens inside `abstract_eval` rules but maybe we need a more central place for it which we can create later on.

PiperOrigin-RevId: 707128096
2024-12-17 09:17:21 -08:00
George Necula
afcb62ea20 [export] Expand exporting to work with AbstractMesh.
This is a follow up from #25640 that enabled lowering with
AbstractMesh.

This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
2024-12-16 10:30:46 +02:00
Yash Katariya
41f490aef4 [sharding_in_types] Default axis_types to Auto for all axis_names if user does not set any AxisType. Also resolve some TODOs now that we have a way for user to set the mesh.
PiperOrigin-RevId: 704944255
2024-12-10 20:20:23 -08:00
Yash Katariya
b5e4fd161d [sharding_in_types] Enforce AxisTypes to always exist if set_mesh is used.
Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims as `unspecified_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further which is not what we want if some of the dims are user sharded.

PiperOrigin-RevId: 704911253
2024-12-10 18:03:21 -08:00
Loren Maggiore
208194f9a5 context manager methods for AbstractMesh to appease type checker.
PiperOrigin-RevId: 702890537
2024-12-04 15:58:03 -08:00
Yash Katariya
a735bf83e5 Simplify abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
2024-12-04 14:04:25 -08:00
Yash Katariya
653f65452d Fix the broken behavior of not resetting the abstract_mesh and device_context properly during __exit__.
PiperOrigin-RevId: 702762477
2024-12-04 09:59:23 -08:00
Yash Katariya
0d2dfea4b1 Add a private set_mesh API to enter into sharding_in_types mode. This is how users will enable sharding in types mode (with correct axis types set too but that doesn't work yet).
Also adding a device_context so `set_mesh` sets the devices the computation should run on correctly. The device_context however enters concrete devices into tracing and lowering cache but this should be fixed with the other jax context work going on.

PiperOrigin-RevId: 700537898
2024-11-26 20:01:04 -08:00
Yash Katariya
6763fcfb4e Fix a weird interaction with set_local and empty tuples passed to it.
PiperOrigin-RevId: 700392735
2024-11-26 10:50:05 -08:00
Yash Katariya
627debc78b Create a null_mesh_context internal context manager to handle null contexts properly.
PiperOrigin-RevId: 700167406
2024-11-25 18:32:05 -08:00
Yash Katariya
40fc6598f9 [sharding_in_types] Make flash_attention forward pass in TPU pallas work nicely with sharding in types. Backward pass is still busted which I will fix in follow up CLs.
Set the abstract mesh context manager at the jit tracing boundary by looking at the mesh on the avals. In the future, this context manager will be user settable too.

Abstract mesh context manager is a new context manager with a new context variable and new trace_context entry which governs the cache behavior. If the abstract mesh context manager is not set, the default is `None`.

PiperOrigin-RevId: 698493184
2024-11-20 13:07:30 -08:00
Yash Katariya
8525ef2b23 [sharding_in_types] Don't emit a wsc under full manual mode to avoid increasing HLO size by a lot
PiperOrigin-RevId: 697048126
2024-11-15 17:42:16 -08:00
Yash Katariya
9a0e9e55d8 [sharding_in_types] Handle collective axes in lowering rules more generally. If any axis is collective, set all dims of aval to unspecified dims in wrap_with_sharding_op.
Also lower shardings with `Collective` axes correctly to HloSharding.

PiperOrigin-RevId: 696703030
2024-11-14 17:32:01 -08:00
Yash Katariya
05716b58b0 [sharding_in_types] Support shard_map with sharding in types. Right now only full manual mode is supported.
This change also adds AxisTypes to Mesh which are `User`, `Auto` and `Collective`.

In the following changes, I'll remove the `config.sharding_in_types` flag and we'll enter into various modes via AxisTypes mentioned on the mesh.

PiperOrigin-RevId: 696559375
2024-11-14 09:58:03 -08:00