3 Commits

Author SHA1 Message Date
Bixia Zheng
0c3de93b79 [jax:custom_partitioning] Support SdyShardingRule with multiple leading
batching dimension groups.

Previously, we allow the use of ellipsis ... in the Einsum like notation to
represent leading batching dimensions in one group of operands and results. We
now allow the use of ellipsis optionally followed by a single digit, such as
...2, to represent leading batching dimensions for multiple groups of operands
and results.

Add tests.

PiperOrigin-RevId: 718875251
2025-01-23 08:20:38 -08:00
Bixia Zheng
d4899f7b9b [jax:custom_partitioning] Make SdyShardingRule a user facing class.
Move the parsing of a sharding rule string to a free function
str_to_sdy_sharding_rule. Move the building of the MLIR sharding rule to a free
function sdy_sharding_rule_to_mlir.

PiperOrigin-RevId: 704818640
2024-12-10 13:05:43 -08:00
Bixia Zheng
2a4a0e8d6f [jax:custom_partitioning] Implement SdyShardingRule to support
Shardy custom_partitioning.

The parsing of the sharding rule string very closely follows how einops parses
their rules in einops/parsing.py.

When a SdyShardingRule object is constructed, we check the syntax of the Einsum
like notation string and its consistency with the user provided factor_sizes,
and report errors accordingly. This is done during f.def_partition.

When SdyShardingRule.build is called, during JAX to MLIR lowering, we check
the consistency between the Einsum like notation string, the factor_sizes
and the MLIR operation, and report errors accordingly.

PiperOrigin-RevId: 703187962
2024-12-05 11:33:23 -08:00