Adds a debugging message to assert, otherwise the error is pretty cryptic.

PiperOrigin-RevId: 747657234
This commit is contained in:
Mark Sandler 2025-04-14 19:08:48 -07:00 committed by jax authors
parent 4fa3cd91d3
commit 0ed0fb7c54

View File

@ -898,7 +898,11 @@ def parse_flatten_op_sharding(
while dim_size > 1:
axis = next(mesh_axis)
axis_size = mesh_shape[axis]
assert dim_size % axis_size == 0
if dim_size % axis_size != 0:
raise ValueError(
f'{shape=} is incompatible with {mesh_shape=}: '
f'{dim_size=} is not divisible by {axis_size=}.'
)
dim_size //= axis_size
dim_partitions.append(axis)
partitions.append(tuple(dim_partitions))