mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Make custom_partitioning support multiple return values.
PiperOrigin-RevId: 533584581
commit 56ca8af9bb
parent 1d20d2f301
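
For orientation before the diff itself, below is a condensed usage sketch assembled from the updated test_custom_partitioner test in the hunks further down. It shows the shape of the API after this change: the partitioned function returns two values, infer_sharding_from_operands returns one sharding per output, and partition returns a tuple of per-output result shardings. This is an illustration, not part of the commit; the unpacking lines (x_shard, y_shard = arg_shardings and x_shape, y_shape = arg_shapes) and the closing note about the device mesh are reconstructed glue rather than verbatim lines from the diff.

# Condensed from the updated test_custom_partitioner test in this diff.
# f returns two values, so the callbacks describe one sharding per output.
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import NamedSharding, PartitionSpec as P


def partition(precision, arg_shapes, result_shape):
  arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
  # result_shape is now a pytree (here a 2-tuple) of per-output shapes.
  result_sharding = result_shape[0].sharding

  def lower_fn(x, y):
    axis_name = arg_shardings[1].spec[0][0]
    i = jax.lax.axis_index(axis_name)
    z = jax.lax.psum(
        jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name))
    return z, z * z

  # One result sharding per output of f.
  return lower_fn, (result_sharding, result_sharding), arg_shardings


def infer_sharding_from_operands(precision, arg_shapes, result_shape):
  arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
  x_shard, y_shard = arg_shardings
  x_shape, y_shape = arg_shapes
  x_names = tuple(x_shard.spec) + tuple(
      None for _ in range(len(x_shape.shape) - len(x_shard.spec)))
  y_names = tuple(y_shard.spec) + tuple(
      None for _ in range(len(y_shape.shape) - len(y_shard.spec)))
  z_shard = NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))
  # One sharding per output of f.
  return z_shard, z_shard


@partial(custom_partitioning, static_argnums=(2,))
def f(x, y, precision=None):
  z = jnp.matmul(x, y, precision=precision)
  return z, z * z


f.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)

# Running f under jit/pjit additionally needs a 2-D device mesh with axes
# ('x', 'y') and sharded inputs, as set up in the test itself.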
@@ -46,15 +46,26 @@ def _resolve_kwargs(fun, args, kwargs):
 class _ShardingCallbackInfo:

   def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding,
-               infer_sharding_from_operands, module_context, mesh, static_args):
+               in_tree, out_tree, infer_sharding_from_operands, module_context, mesh,
+               static_args):
     self.propagate_user_sharding = propagate_user_sharding
     self.partition = partition
     self.to_mesh_pspec_sharding = to_mesh_pspec_sharding
+    self.in_tree = in_tree
+    self.out_tree = out_tree
     self.infer_sharding_from_operands = infer_sharding_from_operands
     self.module_context = module_context
     self.mesh = mesh
     self.static_args = static_args

+  def unflatten_arg_shapes(self, arg_shapes, arg_shardings):
+    return self.in_tree.unflatten(
+        [
+            _to_jax_sharded_shape(s, self.to_mesh_pspec_sharding(sharding))
+            for s, sharding in zip(arg_shapes, arg_shardings)
+        ]
+    )


 _sharding_callbacks = weakref.WeakValueDictionary()  # type: ignore

@@ -71,51 +82,81 @@ def _to_jax_sharded_shape(s, sharding):
   )


-def _custom_partitioning_propagate_user_sharding(sharding, shape, backend_string):
+def _pack_result_sharding(shape, result_shardings):
+  if shape.is_tuple():
+    return xc.HloSharding.tuple_sharding(shape, result_shardings)
+  else:
+    return result_shardings[0]
+
+
+def _flatten_sharding(tree, shardings, shapes):
+  return [
+      _to_hlo_sharding(sharding, len(shape.dimensions()))
+      for sharding, shape in zip(
+          tree.flatten_up_to(shardings), shapes
+      )
+  ]
+
+
+def _custom_partitioning_propagate_user_sharding(user_sharding, shape,
+                                                 backend_string):
   info = _sharding_callbacks[backend_string]
   if info.propagate_user_sharding is None:
-    return sharding
-  user_shape = _to_jax_sharded_shape(
-      shape, info.to_mesh_pspec_sharding(sharding)
-  )
-  result = info.propagate_user_sharding(*info.static_args, user_shape)
-  return xc.HloSharding.from_proto(
-      result._to_xla_op_sharding(len(user_shape.shape)))
+    return user_sharding
+  if shape.is_tuple():
+    user_shapes = shape.tuple_shapes()
+    user_shardings = user_sharding.tuple_elements()
+  else:
+    user_shapes = (shape,)
+    user_shardings = (user_sharding,)
+  user_shape = info.out_tree.unflatten(
+      [
+          _to_jax_sharded_shape(s, info.to_mesh_pspec_sharding(sharding))
+          for s, sharding in zip(user_shapes, user_shardings)
+      ]
+  )
+  result_sharding = info.propagate_user_sharding(*info.static_args, user_shape)
+  result_shardings = _flatten_sharding(
+      info.out_tree, result_sharding, user_shapes)
+  return _pack_result_sharding(shape, result_shardings)
+
+
+def _to_hlo_sharding(sharding, num_dimensions):
+  if not isinstance(sharding, jax.sharding.Sharding):
+    raise ValueError("Custom Partitioning rules must return shardings.")
+  return xc.HloSharding.from_proto(sharding._to_xla_op_sharding(num_dimensions))


 def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
                                    result_sharding, backend_string):
   info = _sharding_callbacks[backend_string]
+  if result_shape.is_tuple():
+    result_shapes = result_shape.tuple_shapes()
+    result_shardings = result_sharding.tuple_elements()
+  else:
+    result_shapes = (result_shape,)
+    result_shardings = (result_sharding,)
   lower_fn, result_sharding, arg_shardings = info.partition(
-      *info.static_args,
-      [
-          _to_jax_sharded_shape(
-              s, info.to_mesh_pspec_sharding(sharding)
-          )
-          for s, sharding in zip(arg_shapes, arg_shardings)
-      ],
-      _to_jax_sharded_shape(
-          result_shape, info.to_mesh_pspec_sharding(result_sharding)
-      )
+      *info.static_args, info.unflatten_arg_shapes(arg_shapes, arg_shardings),
+      info.out_tree.unflatten(
+          [
+              _to_jax_sharded_shape(s, info.to_mesh_pspec_sharding(sharding))
+              for s, sharding in zip(result_shapes, result_shardings)
+          ]
+      )
   )
   module_context = info.module_context

-  def to_hlo_sharding(sharding, shape):
-    return xc.HloSharding.from_proto(
-        sharding._to_xla_op_sharding(len(shape.dimensions())))
-
-  result_sharding = to_hlo_sharding(result_sharding, result_shape)
-  arg_shardings = [
-      to_hlo_sharding(sharding, s)
-      for sharding, s in zip(arg_shardings, arg_shapes)
-  ]
+  result_shardings = _flatten_sharding(
+      info.out_tree, result_sharding, result_shapes)
+  arg_shardings = _flatten_sharding(info.in_tree, arg_shardings, arg_shapes)
   tiled_args = [
       _to_jax_shape(sharding.tile(s))
       for sharding, s in zip(arg_shardings, arg_shapes)
   ]
   tiled_results = [
       _to_jax_shape(sharding.tile(s))
-      for sharding, s in zip([result_sharding], [result_shape])
+      for sharding, s in zip(result_shardings, result_shapes)
   ]
   closed_jaxpr = jax.make_jaxpr(
       lower_fn, axis_env=list(info.mesh.shape.items()))(*tiled_args)
@@ -131,25 +172,26 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
       platform=module_context.platform,
       backend_or_name=module_context.backend_or_name,
       axis_context=axis_context.extend_manual(frozenset(info.mesh.axis_names)))
+  result_sharding = _pack_result_sharding(result_shape, result_shardings)
   return built, arg_shardings, result_sharding


 def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,
-                                                      shape, backend_string):
+                                                      result_shape,
+                                                      backend_string):
   info = _sharding_callbacks[backend_string]
-  result_shape = _to_jax_shape(shape)
-  result = info.infer_sharding_from_operands(
+  if result_shape.is_tuple():
+    result_shapes = result_shape.tuple_shapes()
+  else:
+    result_shapes = (result_shape,)
+  result_sharding = info.infer_sharding_from_operands(
       *info.static_args,
-      [
-          _to_jax_sharded_shape(
-              s, info.to_mesh_pspec_sharding(sharding)
-          )
-          for s, sharding in zip(arg_shapes, arg_shardings)
-      ],
-      result_shape
+      info.unflatten_arg_shapes(arg_shapes, arg_shardings),
+      info.out_tree.unflatten([_to_jax_shape(s) for s in result_shapes]),
   )
-  return xc.HloSharding.from_proto(
-      result._to_xla_op_sharding(len(result_shape.shape)))
+  result_shardings = _flatten_sharding(
+      info.out_tree, result_sharding, result_shapes)
+  return _pack_result_sharding(result_shape, result_shardings)


 custom_partitioning_p = core.Primitive("custom_partitioning")
@@ -270,7 +312,7 @@ class custom_partitioning:
         arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
         return fft, \
             supported_sharding(arg_shardings[0], arg_shapes[0]), \
-            [supported_sharding(arg_shardings[0], arg_shapes[0])]
+            (supported_sharding(arg_shardings[0], arg_shapes[0]),)

       def infer_sharding_from_operands(arg_shapes, result_shape):
         arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
@@ -434,10 +476,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
         op_sharding.to_proto(), mesh)[0].get_partition_spec()
     return jax.sharding.NamedSharding(mesh, pspec)

-  sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
-                                                 to_mesh_pspec_sharding,
-                                                 infer_sharding_from_operands,
-                                                 ctx.module_context, mesh, static_args)
+  sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding,
+      partition, to_mesh_pspec_sharding, in_tree, out_tree,
+      infer_sharding_from_operands, ctx.module_context, mesh, static_args)
   key = str(id(sharding_callback_info))
   _sharding_callbacks[key] = sharding_callback_info
   # We need to make sure `sharding_callback_info` is still alive when the SPMD
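
The hunks above change the library-side callbacks; the hunks below update the CustomPartitionerTest tests. In between, here is a minimal, self-contained sketch (not part of the commit) of the pytree round-trip the new in_tree / out_tree plumbing relies on: rules return per-output shardings as a pytree matching the function's outputs, the callbacks flatten that pytree to align with XLA's flat result shapes, and unflatten rebuilds the user-facing structure. The function f, the string placeholders, and the printed values are illustrative stand-ins for real shardings and shapes.

# Illustration only (not from the commit): the out_tree round-trip used by
# _flatten_sharding and unflatten_arg_shapes, with strings standing in for
# shardings and shapes.
import jax


def f(x):
  z = x * 2
  return z, z * z          # two outputs, so out_tree is a 2-tuple


out_tree = jax.tree_util.tree_structure(f(1))

# A rule returns one sharding per output, as a pytree matching out_tree...
rule_result = ("sharding_for_z", "sharding_for_z_squared")

# ...which the callback flattens into a flat list aligned with the flat XLA
# result shapes (cf. tree.flatten_up_to(shardings) in _flatten_sharding).
flat_shardings = out_tree.flatten_up_to(rule_result)
print(flat_shardings)  # ['sharding_for_z', 'sharding_for_z_squared']

# Going the other way, unflatten rebuilds the user-facing pytree from the
# flat per-output shapes (cf. info.out_tree.unflatten(...) above).
print(out_tree.unflatten(["shape_of_z", "shape_of_z_squared"]))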
@@ -1163,9 +1163,12 @@ class CustomPartitionerTest(jtu.JaxTestCase):
   def test_custom_partitioner(self):
     self.skip_if_custom_partitioning_not_supported()

+    if xla_extension_version < 154:
+      self.skipTest('Requires xla_extension_version >= 154')
+
     def partition(precision, arg_shapes, result_shape):
       arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
-      result_sharding = result_shape.sharding
+      result_sharding = result_shape[0].sharding
       self.assertEqual(arg_shardings[0], result_sharding)
       self.assertEqual(P('x'), result_sharding.spec)
       self.assertEqual(P('y'), arg_shardings[1].spec)
@@ -1173,10 +1176,12 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       def lower_fn(x, y):
         axis_name = arg_shardings[1].spec[0][0]
         i = jax.lax.axis_index(axis_name)
-        return jax.lax.psum(
-            jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name))
+        z = jax.lax.psum(
+            jax.lax.dynamic_slice(x, (0, i * 8), (8, 8)) @ y, (axis_name)
+        )
+        return z, z * z

-      return lower_fn, result_sharding, arg_shardings
+      return lower_fn, (result_sharding, result_sharding), arg_shardings

     def infer_sharding_from_operands(precision, arg_shapes, result_shape):
       arg_shardings = jax.tree_map(lambda s: s.sharding, arg_shapes)
@@ -1186,11 +1191,13 @@ class CustomPartitionerTest(jtu.JaxTestCase):
           None for _ in range(len(x_shape.shape) - len(x_shard.spec)))
       y_names = tuple(y_shard.spec) + tuple(
           None for _ in range(len(y_shape.shape) - len(y_shard.spec)))
-      return NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))
+      z_shard = NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))
+      return z_shard, z_shard

     @partial(custom_partitioning, static_argnums=(2,))
     def f(x, y, precision=None):
-      return jnp.matmul(x, y, precision=precision)
+      z = jnp.matmul(x, y, precision=precision)
+      return z, z * z

     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
@@ -1216,7 +1223,7 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       return (
           lower_fn,
           arg_shapes[0].sharding,
-          [arg_shapes[0].sharding],
+          (arg_shapes[0].sharding,),
       )

     def infer_sharding_from_operands(arg_shapes, result_shape):
@@ -1242,7 +1249,6 @@ class CustomPartitionerTest(jtu.JaxTestCase):
     x = np.asarray(np.random.randint(0, 20, (32, 16)), dtype=np.float32)
     self.assertArraysEqual(x + x, pjit_f(x))

-
   @jtu.with_mesh([('x', 4), ('y', 2)])
   def test_custom_partitioner_sharding_override(self):
     self.skip_if_custom_partitioning_not_supported()
@@ -1255,7 +1261,7 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       return (
           lower_fn,
           NamedSharding(y_shard.mesh, P(None)),
-          [NamedSharding(y_shard.mesh, P(None))],
+          (NamedSharding(y_shard.mesh, P(None)),),
       )

     def infer_sharding_from_operands(arg_shapes, result_shape):
@@ -1289,7 +1295,7 @@ class CustomPartitionerTest(jtu.JaxTestCase):
       return (
           lower_fn,
           NamedSharding(y_shard.mesh, P(None)),
-          [NamedSharding(y_shard.mesh, P(None, 'x'))],
+          (NamedSharding(y_shard.mesh, P(None, 'x')),),
       )

     def infer_sharding_from_operands(arg_shapes, result_shape):