[jax:custom_partitioning] Make propagate_user_sharding default to None.

Revise documentation for sharding_rule and add a link to jax-shardy-guide.

PiperOrigin-RevId: 721001922
This commit is contained in:
Bixia Zheng 2025-01-29 09:14:00 -08:00 committed by jax authors
parent 955e7c4793
commit 20843643ab

View File

@ -288,17 +288,12 @@ class custom_partitioning:
f.def_partition(partition, propagate_user_sharding,
infer_sharding_from_operands=infer_sharding_from_operands,
sharding_rule='i j -> 'i j')
When config.use_shardy_partitioner.value is True, the sharding_rule is
used; otherwise, propagate_user_sharding and infer_sharding_from_operands
are used.
Instead of using an Einsum-like notation string, sharding_rule can also be
a SdyShardingRule object, such as sharding_rule=SdyShardingRule(("i", "j"), ("i", "j")).
The args to ``def_partition`` are as follows:
* ``propagate_user_sharding``: Callable which takes the sharding of a user (in the dag)
and returns a suggestion for a new `NamedSharding`. The default
implementation is just to return the suggested sharding.
and returns a suggestion for a new `NamedSharding`. The default value is None.
A trivial implementation is just to return the input sharding.
* ``partition``: Callable which takes the SPMD suggested partition shapes and
partition specs and returns the mesh, a per-shard lowering function, and the final
input and output sharding specs (the SPMD partitioner will repartition the
@ -312,7 +307,13 @@ class custom_partitioning:
* ``sharding_rule``: an SdyShardingRule object, an Einsum-like notation string
that describes the sharding rule, or a Callable that produces either of
these. We borrow the idea from the einops.rearrange string , to use a space
separator between factors and allow multiple letters factor names.
separator between factors and allow multiple letters factor names. See
[jax-shardy-guide](https://colab.sandbox.google.com/github/openxla/shardy/blob/main/docs/getting_started_jax.ipynb)
for more details and examples on how to use this.
When config.use_shardy_partitioner.value is True, `sharding_rule` is used;
otherwise, `propagate_user_sharding` and `infer_sharding_from_operands` are
used.
Positional arguments can be specified as static using static_argnums. JAX uses
:code:`inspect.signature(fun)` to resolve these positional arguments.
@ -451,7 +452,7 @@ class custom_partitioning:
__getattr__: Any = custom_api_util.forward_attr
def def_partition(self, partition, infer_sharding_from_operands,
def def_partition(self, partition, infer_sharding_from_operands=None,
propagate_user_sharding=None, decode_shardings=True,
sharding_rule=None):
if config.use_shardy_partitioner.value: