Support tuples in custom_partitioning.

PiperOrigin-RevId: 738154413
This commit is contained in:
Parker Schuh 2025-03-18 14:56:23 -07:00 committed by jax authors
parent 080804c78d
commit 0fb59747f0
2 changed files with 42 additions and 1 deletions

View File

@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
for sharding, s in zip(result_shardings, result_shapes)
]
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
*tiled_args
*info.in_tree.unflatten(tiled_args)
)
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
[(t.shape, t.dtype) for t in tiled_results]):

View File

@ -1680,6 +1680,47 @@ class CustomPartitionerTest(jtu.JaxTestCase):
jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
self.assertArraysEqual(x, jit_f(x))
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_custom_partitioner_pytree_inputs(self):
self.skip_if_custom_partitioning_not_supported()
def partition(mesh, arg_shapes, result_shape):
def lower_fn(xs):
x, y, z = xs
return x + y + z
return (
mesh,
lower_fn,
arg_shapes[0][0].sharding,
jax.tree.map(lambda x: x.sharding, arg_shapes),
)
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
return arg_shapes[0][0].sharding
def propagate_user_sharding(mesh, user_shape):
return user_shape.sharding
@custom_partitioning
def f(xs):
x, y, z = xs
return x + y + z
f.def_partition(
infer_sharding_from_operands=infer_sharding_from_operands,
partition=partition,
propagate_user_sharding=propagate_user_sharding,
sharding_rule='i j, i j, i j -> i j',
)
def f2(a):
return a + f((a, a, a))
pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
self.assertArraysEqual(x * 4, pjit_f(x))
@jtu.pytest_mark_if_available('multiaccelerator')
class AutoShardingPjitTest(jtu.JaxTestCase):