diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py
index 658a6f7a2..537407151 100644
--- a/jax/_src/custom_partitioning.py
+++ b/jax/_src/custom_partitioning.py
@@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
       for sharding, s in zip(result_shardings, result_shapes)
   ]
   closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
-      *tiled_args
+      *info.in_tree.unflatten(tiled_args)
   )
   if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
       [(t.shape, t.dtype) for t in tiled_results]):
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 76336920b..293b37a9f 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1680,6 +1680,47 @@ class CustomPartitionerTest(jtu.JaxTestCase):
     jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
     self.assertArraysEqual(x, jit_f(x))
 
+  @jtu.with_mesh([('x', 4), ('y', 2)])
+  def test_custom_partitioner_pytree_inputs(self):
+    self.skip_if_custom_partitioning_not_supported()
+
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(xs):
+        x, y, z = xs
+        return x + y + z
+
+      return (
+          mesh,
+          lower_fn,
+          arg_shapes[0][0].sharding,
+          jax.tree.map(lambda x: x.sharding, arg_shapes),
+      )
+
+    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      return arg_shapes[0][0].sharding
+
+    def propagate_user_sharding(mesh, user_shape):
+      return user_shape.sharding
+
+    @custom_partitioning
+    def f(xs):
+      x, y, z = xs
+      return x + y + z
+
+    f.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+        propagate_user_sharding=propagate_user_sharding,
+        sharding_rule='i j, i j, i j -> i j',
+    )
+
+    def f2(a):
+      return a + f((a, a, a))
+
+    pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
+    x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
+    self.assertArraysEqual(x * 4, pjit_f(x))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class AutoShardingPjitTest(jtu.JaxTestCase):
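
Why the one-line fix matters: _custom_partitioning_partition builds a jaxpr by splatting the flat tiled_args into lower_fn, so a custom_partitioning function whose argument is a pytree (such as the tuple in the new test) received three positional arguments instead of one tuple and failed at lowering. Unflattening through info.in_tree hands lower_fn the same structure the caller passed. Below is a minimal standalone sketch of the now-supported pattern, adapted from the test above; the add3 name, the 4x2 mesh, and the explicit device_put are illustrative assumptions (it needs at least 8 devices and a JAX version whose def_partition accepts sharding_rule), not part of the patch.

import jax
import numpy as np
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def partition(mesh, arg_shapes, result_shape):
  # With the fix, arg_shapes (and the tiled args handed to lower_fn) keep
  # the caller's pytree structure: one tuple of three leaves, not a flat
  # list of three separate arguments.
  def lower_fn(xs):
    x, y, z = xs
    return x + y + z

  return (
      mesh,
      lower_fn,
      arg_shapes[0][0].sharding,                       # result sharding
      jax.tree.map(lambda s: s.sharding, arg_shapes),  # per-leaf arg shardings
  )

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  return arg_shapes[0][0].sharding

@custom_partitioning
def add3(xs):  # a single pytree argument, unpacked inside
  x, y, z = xs
  return x + y + z

add3.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
    sharding_rule='i j, i j, i j -> i j',
)

# Illustrative 4x2 mesh over the first 8 devices (assumption: >= 8 devices).
mesh = Mesh(np.asarray(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
x = jax.device_put(np.ones((32, 16), np.float32),
                   NamedSharding(mesh, P(None, 'x')))
out = jax.jit(add3)((x, x, x))  # the tuple survives through partitioning
np.testing.assert_array_equal(np.asarray(out), 3 * np.ones((32, 16)))

The sketch drops the optional propagate_user_sharding callback that the test supplies; the tuple round-trips through partitioning either way, since the in_tree unflattening happens in _custom_partitioning_partition itself.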