[sharding_in_types] Error out for reshape for splits like this: (4, 6, 8) -> (4, 4, 2, 6)

PiperOrigin-RevId: 716653203
This commit is contained in:
Yash Katariya 2025-01-17 06:57:57 -08:00 committed by jax authors
parent 7cac76d346
commit ce85b89884
2 changed files with 11 additions and 0 deletions

View File

@ -4963,6 +4963,8 @@ def _split_on_one_axis(op_shape, new_sizes, name):
f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}')
temp = [new_sizes[j]]
while math.prod(temp) != op_shape[i]:
if math.prod(temp) > op_shape[i]:
return False, []
j += 1
temp.append(new_sizes[j])
out.append(temp)

View File

@ -5346,6 +5346,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
('split_4_error', (4, 6, 8), (4, 2, 3, 8),
P('x', 'y', None), None, 'Split axis cannot be sharded'
),
('split_5_error', (4, 6, 8), (4, 4, 2, 6),
P('x', None, None), None, 'This reshape is not supported'
),
('split_6_error', (4, 8, 9), (4, 2, 2, 3, 3, 2),
P('x', None, None), None, 'This reshape is not supported'
),
('merge_1', (4, 2, 3, 8), (4, 6, 8),
P('x', None, None, 'y'), P('x', None, 'y'), ''
),
@ -5362,6 +5368,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
('merge_5_error', (4, 2, 3, 8), (4, 6, 8),
P(None, 'y', None, 'x'), None, 'Merged axis cannot be sharded'
),
('merge_6_error', (4, 2, 3, 8), (4, 8, 6),
P(None, 'y', None, 'x'), None, 'This reshape is not supported'
),
)
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_reshape_split_merge_one_axis(self, src_shape, dst_shape, src_spec,