mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Adds a debugging message to assert, otherwise the error is pretty cryptic.
PiperOrigin-RevId: 747657234
This commit is contained in:
parent
4fa3cd91d3
commit
0ed0fb7c54
@ -898,7 +898,11 @@ def parse_flatten_op_sharding(
|
||||
while dim_size > 1:
|
||||
axis = next(mesh_axis)
|
||||
axis_size = mesh_shape[axis]
|
||||
assert dim_size % axis_size == 0
|
||||
if dim_size % axis_size != 0:
|
||||
raise ValueError(
|
||||
f'{shape=} is incompatible with {mesh_shape=}: '
|
||||
f'{dim_size=} is not divisible by {axis_size=}.'
|
||||
)
|
||||
dim_size //= axis_size
|
||||
dim_partitions.append(axis)
|
||||
partitions.append(tuple(dim_partitions))
|
||||
|
Loading…
x
Reference in New Issue
Block a user