10 Commits

Author SHA1 Message Date
Yash Katariya
88d4bc3d45 Rename AxisTypes enum to AxisType
PiperOrigin-RevId: 736935746
2025-03-14 11:48:21 -07:00
Yash Katariya
e615e2acb3 Raise a better error with more info when we see duplicate axis in a PartitionSpec resulting from a sharding rule.
Previously it was:

`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`

Now it is:

`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`

PiperOrigin-RevId: 736657644
2025-03-13 15:24:10 -07:00
Yash Katariya
c6dcbb6759 [sharding_in_types] Rework the axis_types argument in Mesh and AbstractMesh APIs. The changes are:
1. axis_types now takes a `AxisTypes | tuple[AxisTypes, ...] | None`. It doesn't take a dictionary anymore

2. `jax.make_mesh` also takes the same `axis_types` tuple as in point 1.

PiperOrigin-RevId: 736360041
2025-03-12 20:41:50 -07:00
Yash Katariya
07f192cd48 Merge _check_mesh_resource_axis and _check_axis_type_consistency into 1 function.
PiperOrigin-RevId: 731830347
2025-02-27 12:51:25 -08:00
Yash Katariya
c265568530 Remove parsed_pspec from NamedSharding constructor
PiperOrigin-RevId: 731820173
2025-02-27 12:24:17 -08:00
Yash Katariya
d69da3b012 More cleanups around ParsedPartitionSpec. In a follow up CL, I can remove it from NamedSharding constructor. Deleting ParsedPartitionSpec is remaining but that's after 0.5.2 release.
PiperOrigin-RevId: 731785005
2025-02-27 10:51:04 -08:00
Yash Katariya
034a827a4d Remove _parsed_pspec from everywhere in JAX except for NamedSharding constructor. I'll do that in the next CL since that has a dependency on C++ so needs guards.
PiperOrigin-RevId: 731772222
2025-02-27 10:17:06 -08:00
Yash Katariya
177e1f6ed9 Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec. We need to do this after sharding-in-types to speed up NamedSharding construction and remove a lot of tech debt and unnecessary complexity.
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.

* Cache the creating of sharding on ShapedArray since it's expensive to do it a lot of times

* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.

PiperOrigin-RevId: 731745062
2025-02-27 08:59:25 -08:00
Emily Fertig
82124da5cd Redefine is_fully_addressable in shardings to support zero local devices for McJAX.
PiperOrigin-RevId: 731526750
2025-02-26 18:17:35 -08:00
Yash Katariya
229aa65a3e Split NamedSharding into a separate file called named_sharding.py so that we can import it in core.py and break the cyclic dependency.
PiperOrigin-RevId: 726566863
2025-02-13 11:22:54 -08:00