batching dimension groups.
Previously, we allow the use of ellipsis ... in the Einsum like notation to
represent leading batching dimensions in one group of operands and results. We
now allow the use of ellipsis optionally followed by a single digit, such as
...2, to represent leading batching dimensions for multiple groups of operands
and results.
Add tests.
PiperOrigin-RevId: 718875251
Move the parsing of a sharding rule string to a free function
str_to_sdy_sharding_rule. Move the building of the MLIR sharding rule to a free
function sdy_sharding_rule_to_mlir.
PiperOrigin-RevId: 704818640
Shardy custom_partitioning.
The parsing of the sharding rule string very closely follows how einops parses
their rules in einops/parsing.py.
When a SdyShardingRule object is constructed, we check the syntax of the Einsum
like notation string and its consistency with the user provided factor_sizes,
and report errors accordingly. This is done during f.def_partition.
When SdyShardingRule.build is called, during JAX to MLIR lowering, we check
the consistency between the Einsum like notation string, the factor_sizes
and the MLIR operation, and report errors accordingly.
PiperOrigin-RevId: 703187962