mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax:custom_partitioning] Make propagate_user_sharding default to None.
Revise documentation for sharding_rule and add a link to jax-shardy-guide. PiperOrigin-RevId: 721001922
This commit is contained in:
parent
955e7c4793
commit
20843643ab
@ -288,17 +288,12 @@ class custom_partitioning:
|
||||
f.def_partition(partition, propagate_user_sharding,
|
||||
infer_sharding_from_operands=infer_sharding_from_operands,
|
||||
sharding_rule='i j -> 'i j')
|
||||
When config.use_shardy_partitioner.value is True, the sharding_rule is
|
||||
used; otherwise, propagate_user_sharding and infer_sharding_from_operands
|
||||
are used.
|
||||
Instead of using an Einsum-like notation string, sharding_rule can also be
|
||||
a SdyShardingRule object, such as sharding_rule=SdyShardingRule(("i", "j"), ("i", "j")).
|
||||
|
||||
The args to ``def_partition`` are as follows:
|
||||
|
||||
* ``propagate_user_sharding``: Callable which takes the sharding of a user (in the dag)
|
||||
and returns a suggestion for a new `NamedSharding`. The default
|
||||
implementation is just to return the suggested sharding.
|
||||
and returns a suggestion for a new `NamedSharding`. The default value is None.
|
||||
A trivial implementation is just to return the input sharding.
|
||||
* ``partition``: Callable which takes the SPMD suggested partition shapes and
|
||||
partition specs and returns the mesh, a per-shard lowering function, and the final
|
||||
input and output sharding specs (the SPMD partitioner will repartition the
|
||||
@ -312,7 +307,13 @@ class custom_partitioning:
|
||||
* ``sharding_rule``: an SdyShardingRule object, an Einsum-like notation string
|
||||
that describes the sharding rule, or a Callable that produces either of
|
||||
these. We borrow the idea from the einops.rearrange string , to use a space
|
||||
separator between factors and allow multiple letters factor names.
|
||||
separator between factors and allow multiple letters factor names. See
|
||||
[jax-shardy-guide](https://colab.sandbox.google.com/github/openxla/shardy/blob/main/docs/getting_started_jax.ipynb)
|
||||
for more details and examples on how to use this.
|
||||
|
||||
When config.use_shardy_partitioner.value is True, `sharding_rule` is used;
|
||||
otherwise, `propagate_user_sharding` and `infer_sharding_from_operands` are
|
||||
used.
|
||||
|
||||
Positional arguments can be specified as static using static_argnums. JAX uses
|
||||
:code:`inspect.signature(fun)` to resolve these positional arguments.
|
||||
@ -451,7 +452,7 @@ class custom_partitioning:
|
||||
|
||||
__getattr__: Any = custom_api_util.forward_attr
|
||||
|
||||
def def_partition(self, partition, infer_sharding_from_operands,
|
||||
def def_partition(self, partition, infer_sharding_from_operands=None,
|
||||
propagate_user_sharding=None, decode_shardings=True,
|
||||
sharding_rule=None):
|
||||
if config.use_shardy_partitioner.value:
|
||||
|
Loading…
x
Reference in New Issue
Block a user