mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[sharding_in_types] Error out for reshape for splits like this: (4, 6, 8)
-> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
This commit is contained in:
parent
7cac76d346
commit
ce85b89884
@ -4963,6 +4963,8 @@ def _split_on_one_axis(op_shape, new_sizes, name):
|
||||
f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}')
|
||||
temp = [new_sizes[j]]
|
||||
while math.prod(temp) != op_shape[i]:
|
||||
if math.prod(temp) > op_shape[i]:
|
||||
return False, []
|
||||
j += 1
|
||||
temp.append(new_sizes[j])
|
||||
out.append(temp)
|
||||
|
@ -5346,6 +5346,12 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
('split_4_error', (4, 6, 8), (4, 2, 3, 8),
|
||||
P('x', 'y', None), None, 'Split axis cannot be sharded'
|
||||
),
|
||||
('split_5_error', (4, 6, 8), (4, 4, 2, 6),
|
||||
P('x', None, None), None, 'This reshape is not supported'
|
||||
),
|
||||
('split_6_error', (4, 8, 9), (4, 2, 2, 3, 3, 2),
|
||||
P('x', None, None), None, 'This reshape is not supported'
|
||||
),
|
||||
('merge_1', (4, 2, 3, 8), (4, 6, 8),
|
||||
P('x', None, None, 'y'), P('x', None, 'y'), ''
|
||||
),
|
||||
@ -5362,6 +5368,9 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
('merge_5_error', (4, 2, 3, 8), (4, 6, 8),
|
||||
P(None, 'y', None, 'x'), None, 'Merged axis cannot be sharded'
|
||||
),
|
||||
('merge_6_error', (4, 2, 3, 8), (4, 8, 6),
|
||||
P(None, 'y', None, 'x'), None, 'This reshape is not supported'
|
||||
),
|
||||
)
|
||||
@jtu.with_user_mesh((2, 2), ('x', 'y'))
|
||||
def test_reshape_split_merge_one_axis(self, src_shape, dst_shape, src_spec,
|
||||
|
Loading…
x
Reference in New Issue
Block a user