Make custom_partitioning support multiple return values.

PiperOrigin-RevId: 533584581
Parker Schuh authored 2023-05-19 16:58:21 -07:00; committed by jax authors
parent 1d20d2f301
commit 56ca8af9bb
2 changed files with 103 additions and 56 deletions
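At the user level the change means a `custom_partitioning`-decorated function may now return multiple values, with `infer_sharding_from_operands` and `partition` returning a matching pytree of shardings rather than a single sharding. A minimal sketch of that API shape (the function `f`, its doubling computation, and the sharding choices below are illustrative, not part of this commit):

    import jax
    from jax.experimental.custom_partitioning import custom_partitioning

    @custom_partitioning
    def f(x):
      y = x * 2
      return y, y + 1                       # multiple outputs are now allowed

    def infer_sharding_from_operands(arg_shapes, result_shape):
      s = arg_shapes[0].sharding
      return s, s                           # one sharding per output, matching f's output pytree

    def partition(arg_shapes, result_shape):
      arg_shardings = jax.tree_map(lambda a: a.sharding, arg_shapes)
      out_shardings = jax.tree_map(lambda r: r.sharding, result_shape)
      def lower_fn(x):                      # traced at per-shard shapes
        y = x * 2
        return y, y + 1
      return lower_fn, out_shardings, arg_shardings

    f.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition)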


@@ -46,15 +46,26 @@ def _resolve_kwargs(fun, args, kwargs):
 class _ShardingCallbackInfo:
   def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding,
-               infer_sharding_from_operands, module_context, mesh, static_args):
+               in_tree, out_tree, infer_sharding_from_operands, module_context, mesh,
+               static_args):
     self.propagate_user_sharding = propagate_user_sharding
     self.partition = partition
     self.to_mesh_pspec_sharding = to_mesh_pspec_sharding
+    self.in_tree = in_tree
+    self.out_tree = out_tree
     self.infer_sharding_from_operands = infer_sharding_from_operands
     self.module_context = module_context
     self.mesh = mesh
     self.static_args = static_args
+  def unflatten_arg_shapes(self, arg_shapes, arg_shardings):
+    return self.in_tree.unflatten(
+        [
+            _to_jax_sharded_shape(s, self.to_mesh_pspec_sharding(sharding))
+            for s, sharding in zip(arg_shapes, arg_shardings)
+        ]
+    )
 _sharding_callbacks = weakref.WeakValueDictionary() # type: ignore
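The new in_tree / out_tree members are pytree treedefs. The C++ partitioner callbacks only see flat lists of XLA shapes and shardings, and unflattening rebuilds the structure the user's callbacks expect, as unflatten_arg_shapes does above. A toy round trip (the string leaves are placeholders):

    import jax

    # Capture the structure of a two-output result once...
    leaves, out_tree = jax.tree_util.tree_flatten(("first_output", "second_output"))
    # ...then rebuild that structure from a flat list handed back by XLA.
    print(out_tree.unflatten(["shape_and_sharding_0", "shape_and_sharding_1"]))
    # -> ('shape_and_sharding_0', 'shape_and_sharding_1')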
@@ -71,51 +82,81 @@ def _to_jax_sharded_shape(s, sharding):
   )
-def _custom_partitioning_propagate_user_sharding(sharding, shape, backend_string):
+def _pack_result_sharding(shape, result_shardings):
+  if shape.is_tuple():
+    return xc.HloSharding.tuple_sharding(shape, result_shardings)
+  else:
+    return result_shardings[0]
+def _flatten_sharding(tree, shardings, shapes):
+  return [
+      _to_hlo_sharding(sharding, len(shape.dimensions()))
+      for sharding, shape in zip(
+          tree.flatten_up_to(shardings), shapes
+      )
+  ]
+def _custom_partitioning_propagate_user_sharding(user_sharding, shape,
+                                                 backend_string):
   info = _sharding_callbacks[backend_string]
   if info.propagate_user_sharding is None:
-    return sharding
-  user_shape = _to_jax_sharded_shape(
-      shape, info.to_mesh_pspec_sharding(sharding)
-  )
-  result = info.propagate_user_sharding(*info.static_args, user_shape)
-  return xc.HloSharding.from_proto(
-      result._to_xla_op_sharding(len(user_shape.shape)))
+    return user_sharding
+  if shape.is_tuple():
+    user_shapes = shape.tuple_shapes()
+    user_shardings = user_sharding.tuple_elements()
+  else:
+    user_shapes = (shape,)
+    user_shardings = (user_sharding,)
+  user_shape = info.out_tree.unflatten(
+      [
+          _to_jax_sharded_shape(s, info.to_mesh_pspec_sharding(sharding))
+          for s, sharding in zip(user_shapes, user_shardings)
+      ]
+  )
+  result_sharding = info.propagate_user_sharding(*info.static_args, user_shape)
+  result_shardings = _flatten_sharding(
+      info.out_tree, result_sharding, user_shapes)
+  return _pack_result_sharding(shape, result_shardings)
+def _to_hlo_sharding(sharding, num_dimensions):
+  if not isinstance(sharding, jax.sharding.Sharding):
+    raise ValueError("Custom Partitioning rules must return shardings.")
+  return xc.HloSharding.from_proto(sharding._to_xla_op_sharding(num_dimensions))
 def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
                                    result_sharding, backend_string):
   info = _sharding_callbacks[backend_string]
+  if result_shape.is_tuple():
+    result_shapes = result_shape.tuple_shapes()
+    result_shardings = result_sharding.tuple_elements()
+  else:
+    result_shapes = (result_shape,)
+    result_shardings = (result_sharding,)
   lower_fn, result_sharding, arg_shardings = info.partition(
-      *info.static_args,
-      [
-          _to_jax_sharded_shape(
-              s, info.to_mesh_pspec_sharding(sharding)
-          )
-          for s, sharding in zip(arg_shapes, arg_shardings)
-      ],
-      _to_jax_sharded_shape(
-          result_shape, info.to_mesh_pspec_sharding(result_sharding)
-      )
+      *info.static_args, info.unflatten_arg_shapes(arg_shapes, arg_shardings),
+      info.out_tree.unflatten(
+          [
+              _to_jax_sharded_shape(s, info.to_mesh_pspec_sharding(sharding))
+              for s, sharding in zip(result_shapes, result_shardings)
+          ]
+      )
   )
   module_context = info.module_context
-  def to_hlo_sharding(sharding, shape):
-    return xc.HloSharding.from_proto(
-        sharding._to_xla_op_sharding(len(shape.dimensions())))
-  result_sharding = to_hlo_sharding(result_sharding, result_shape)
-  arg_shardings = [
-      to_hlo_sharding(sharding, s)
-      for sharding, s in zip(arg_shardings, arg_shapes)
-  ]
+  result_shardings = _flatten_sharding(
+      info.out_tree, result_sharding, result_shapes)
+  arg_shardings = _flatten_sharding(info.in_tree, arg_shardings, arg_shapes)
   tiled_args = [
       _to_jax_shape(sharding.tile(s))
       for sharding, s in zip(arg_shardings, arg_shapes)
   ]
   tiled_results = [
       _to_jax_shape(sharding.tile(s))
-      for sharding, s in zip([result_sharding], [result_shape])
+      for sharding, s in zip(result_shardings, result_shapes)
   ]
   closed_jaxpr = jax.make_jaxpr(
       lower_fn, axis_env=list(info.mesh.shape.items()))(*tiled_args)
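For intuition on the tile step feeding jax.make_jaxpr above: each per-operand and per-result HloSharding shrinks the global XLA shape down to the per-shard shape that lower_fn is traced at. The same arithmetic at the JAX level, under an assumed single-axis mesh (illustrative; requires at least four devices):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.asarray(jax.devices()[:4]), ('x',))
    sharding = NamedSharding(mesh, P('x'))
    print(sharding.shard_shape((8, 64)))  # (2, 64): the shape lower_fn sees per shard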
@@ -131,25 +172,26 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
       platform=module_context.platform,
       backend_or_name=module_context.backend_or_name,
       axis_context=axis_context.extend_manual(frozenset(info.mesh.axis_names)))
+  result_sharding = _pack_result_sharding(result_shape, result_shardings)
   return built, arg_shardings, result_sharding
 def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,
-                                                      shape, backend_string):
+                                                      result_shape,
+                                                      backend_string):
   info = _sharding_callbacks[backend_string]
-  result_shape = _to_jax_shape(shape)
-  result = info.infer_sharding_from_operands(
+  if result_shape.is_tuple():
+    result_shapes = result_shape.tuple_shapes()
+  else:
+    result_shapes = (result_shape,)
+  result_sharding = info.infer_sharding_from_operands(
       *info.static_args,
-      [
-          _to_jax_sharded_shape(
-              s, info.to_mesh_pspec_sharding(sharding)
-          )
-          for s, sharding in zip(arg_shapes, arg_shardings)
-      ],
-      result_shape
+      info.unflatten_arg_shapes(arg_shapes, arg_shardings),
+      info.out_tree.unflatten([_to_jax_shape(s) for s in result_shapes]),
   )
-  return xc.HloSharding.from_proto(
-      result._to_xla_op_sharding(len(result_shape.shape)))
+  result_shardings = _flatten_sharding(
+      info.out_tree, result_sharding, result_shapes)
+  return _pack_result_sharding(result_shape, result_shardings)
 custom_partitioning_p = core.Primitive("custom_partitioning")
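When a propagate_user_sharding callback is supplied for a multi-output function, the path above now hands it the whole output pytree of sharded shapes and expects a matching pytree of shardings back. A minimal sketch (the identity propagation below is illustrative, not from this commit):

    def propagate_user_sharding(user_shape):
      # `user_shape` mirrors the function's output pytree; each leaf is a
      # jax.ShapeDtypeStruct carrying the sharding requested downstream.
      # Returning those shardings unchanged simply accepts the user's choice.
      return jax.tree_map(lambda s: s.sharding, user_shape)

    # Registered alongside the other callbacks, e.g. for the hypothetical f sketched earlier:
    # f.def_partition(propagate_user_sharding=propagate_user_sharding,
    #                 infer_sharding_from_operands=..., partition=...)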
@@ -270,7 +312,7 @@ class custom_partitioning:
        arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
        return fft, \
            supported_sharding(arg_shardings[0], arg_shapes[0]), \
-           [supported_sharding(arg_shardings[0], arg_shapes[0])]
+           (supported_sharding(arg_shardings[0], arg_shapes[0]),)
    def infer_sharding_from_operands(arg_shapes, result_shape):
        arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
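Spelling out the convention the updated docstring example relies on (fft and supported_sharding are the helpers defined earlier in that docstring; the partition signature matches the call sites elsewhere in this commit): partition returns the per-shard lowering function, the sharding pytree for the outputs, and a sequence with one sharding per operand, which for this single-operand example becomes a one-element tuple rather than a list:

    def partition(arg_shapes, result_shape):
        arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
        return (
            fft,                                                      # lower_fn, run on each shard
            supported_sharding(arg_shardings[0], arg_shapes[0]),      # sharding of the single output
            (supported_sharding(arg_shardings[0], arg_shapes[0]),),   # one sharding per operand
        )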
@@ -434,10 +476,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
         op_sharding.to_proto(), mesh)[0].get_partition_spec()
     return jax.sharding.NamedSharding(mesh, pspec)
-  sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
-                                                 to_mesh_pspec_sharding,
-                                                 infer_sharding_from_operands,
-                                                 ctx.module_context, mesh, static_args)
+  sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding,
+      partition, to_mesh_pspec_sharding, in_tree, out_tree,
+      infer_sharding_from_operands, ctx.module_context, mesh, static_args)
   key = str(id(sharding_callback_info))
   _sharding_callbacks[key] = sharding_callback_info
   # We need to make sure `sharding_callback_info` is still alive when the SPMD
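The comment above refers to the keep-alive issue with _sharding_callbacks, a WeakValueDictionary: an entry disappears as soon as the last strong reference to its info object dies, so the lowering has to hold one somewhere for as long as the SPMD partitioner may call back. A standalone illustration of that behavior (assuming CPython's immediate refcounting):

    import weakref

    class Info:  # stand-in for _ShardingCallbackInfo
      pass

    registry = weakref.WeakValueDictionary()
    info = Info()
    registry["key"] = info
    print("key" in registry)  # True: a strong reference to `info` still exists
    del info                  # drop the last strong reference
    print("key" in registry)  # False: the weak entry vanished with it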


@@ -1163,9 +1163,12 @@ class CustomPartitionerTest(jtu.JaxTestCase):
   def test_custom_partitioner(self):
     self.skip_if_custom_partitioning_not_supported()
+    if xla_extension_version < 154:
+      self.skipTest('Requires xla_extension_version >= 154')
     def partition(precision, arg_shapes, result_shape):
       arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
-      result_sharding = result_shape.sharding
+      result_sharding = result_shape[0].sharding
       self.assertEqual(arg_shardings[0], result_sharding)
       self.assertEqual(P('x'), result_sharding.spec)
       self.assertEqual(P('y'), arg_shardings[1].spec)
@@ -1173,10 +1176,12 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       def lower_fn(x, y):
         axis_name = arg_shardings[1].spec[0][0]
         i = jax.lax.axis_index(axis_name)
-        return jax.lax.psum(
-            jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name))
+        z = jax.lax.psum(
+            jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name)
+        )
+        return z, z * z
-      return lower_fn, result_sharding, arg_shardings
+      return lower_fn, (result_sharding, result_sharding), arg_shardings
     def infer_sharding_from_operands(precision, arg_shapes, result_shape):
       arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
@@ -1186,11 +1191,13 @@ class CustomPartitionerTest(jtu.JaxTestCase):
           None for _ in range(len(x_shape.shape) - len(x_shard.spec)))
       y_names = tuple(y_shard.spec) + tuple(
           None for _ in range(len(y_shape.shape) - len(y_shard.spec)))
-      return NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))
+      z_shard = NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))
+      return z_shard, z_shard
     @partial(custom_partitioning, static_argnums=(2,))
     def f(x, y, precision=None):
-      return jnp.matmul(x, y, precision=precision)
+      z = jnp.matmul(x, y, precision=precision)
+      return z, z * z
     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
@@ -1216,7 +1223,7 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       return (
           lower_fn,
           arg_shapes[0].sharding,
-          [arg_shapes[0].sharding],
+          (arg_shapes[0].sharding,),
       )
     def infer_sharding_from_operands(arg_shapes, result_shape):
@@ -1242,7 +1249,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
     x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
     self.assertArraysEqual(x + x, pjit_f(x))
-  @jtu.with_mesh([('x', 4), ('y', 2)])
   def test_custom_partitioner_sharding_override(self):
     self.skip_if_custom_partitioning_not_supported()
@@ -1255,7 +1261,7 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       return (
           lower_fn,
           NamedSharding(y_shard.mesh, P(None)),
-          [NamedSharding(y_shard.mesh, P(None))],
+          (NamedSharding(y_shard.mesh, P(None)),),
       )
     def infer_sharding_from_operands(arg_shapes, result_shape):
@@ -1289,7 +1295,7 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       return (
           lower_fn,
           NamedSharding(y_shard.mesh, P(None)),
-          [NamedSharding(y_shard.mesh, P(None, 'x'))],
+          (NamedSharding(y_shard.mesh, P(None, 'x')),),
       )
     def infer_sharding_from_operands(arg_shapes, result_shape):