mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Support tuples in custom_partitioning.
PiperOrigin-RevId: 738154413
This commit is contained in:
parent
080804c78d
commit
0fb59747f0
@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
|
||||
for sharding, s in zip(result_shardings, result_shapes)
|
||||
]
|
||||
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
|
||||
*tiled_args
|
||||
*info.in_tree.unflatten(tiled_args)
|
||||
)
|
||||
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
|
||||
[(t.shape, t.dtype) for t in tiled_results]):
|
||||
|
@ -1680,6 +1680,47 @@ class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
jit_f = jax.jit(f, in_shardings=s, out_shardings=s)
|
||||
self.assertArraysEqual(x, jit_f(x))
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_custom_partitioner_pytree_inputs(self):
|
||||
self.skip_if_custom_partitioning_not_supported()
|
||||
|
||||
def partition(mesh, arg_shapes, result_shape):
|
||||
def lower_fn(xs):
|
||||
x, y, z = xs
|
||||
return x + y + z
|
||||
|
||||
return (
|
||||
mesh,
|
||||
lower_fn,
|
||||
arg_shapes[0][0].sharding,
|
||||
jax.tree.map(lambda x: x.sharding, arg_shapes),
|
||||
)
|
||||
|
||||
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
||||
return arg_shapes[0][0].sharding
|
||||
|
||||
def propagate_user_sharding(mesh, user_shape):
|
||||
return user_shape.sharding
|
||||
|
||||
@custom_partitioning
|
||||
def f(xs):
|
||||
x, y, z = xs
|
||||
return x + y + z
|
||||
|
||||
f.def_partition(
|
||||
infer_sharding_from_operands=infer_sharding_from_operands,
|
||||
partition=partition,
|
||||
propagate_user_sharding=propagate_user_sharding,
|
||||
sharding_rule='i j, i j, i j -> i j',
|
||||
)
|
||||
|
||||
def f2(a):
|
||||
return a + f((a, a, a))
|
||||
|
||||
pjit_f = pjit(f2, in_shardings=(P(None, 'x')), out_shardings=P('x'))
|
||||
x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
|
||||
self.assertArraysEqual(x * 4, pjit_f(x))
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class AutoShardingPjitTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user