For custom_partitioning, directly emit the call when inside a shard_map.

PiperOrigin-RevId: 592011427
Parker Schuh 2023-12-18 14:32:04 -08:00 committed by jax authors
parent afdb7370b9
commit 7ba8622719
3 changed files with 60 additions and 0 deletions


@@ -481,6 +481,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        static_args):
   mesh = mesh_lib.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context
+  if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
+      set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
+    return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment

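For illustration only (not part of the commit), here is a minimal, self-contained sketch of the case the new branch above targets: a custom_partitioning function applied under shard_map across every mesh axis, so the axis context is fully manual and the wrapped call is lowered directly rather than going through the usual custom-partitioning lowering. It assumes at least eight local devices arranged as a 2x2x2 ('z', 'x', 'y') mesh and mirrors the test added below; names like fwd are illustrative.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.shard_map import shard_map

# Hypothetical 2x2x2 mesh over eight devices; adjust to the devices available.
mesh = Mesh(np.array(jax.devices()[:8]).reshape(2, 2, 2),
            axis_names=('z', 'x', 'y'))

@custom_partitioning
def f(x):
  return x

def partition(mesh, arg_shapes, result_shape):
  def lower_fn(x):
    return x
  # Reuse the operand sharding for both the result and the argument.
  return mesh, lower_fn, arg_shapes[0].sharding, (arg_shapes[0].sharding,)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
  return arg_shapes[0].sharding

f.def_partition(partition=partition,
                infer_sharding_from_operands=infer_sharding_from_operands)

@jax.jit
def fwd(a):
  # Every mesh axis is mapped here, so the lowering above hits the new branch
  # and emits the call to f directly.
  return shard_map(f, mesh,
                   in_specs=(P('z', ('x', 'y')),),
                   out_specs=P('z', ('x', 'y')),
                   check_rep=False)(a)

a = jnp.arange(8 * 8.0).reshape(8, 8)
print(fwd(a).addressable_data(0).shape)  # (4, 2) per-device shard on this mesh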

@@ -1235,6 +1235,7 @@ jax_test(
         "notsan",
     ],  # Times out under *SAN.
     deps = [
+        "//jax:experimental",
         "//jax:tree_util",
     ],
 )



@@ -44,6 +44,7 @@ from jax._src import linear_util as lu
 from jax._src import tree_util
 import jax.numpy as jnp
 from jax.experimental.custom_partitioning import custom_partitioning
+from jax.experimental.shard_map import shard_map
 
 config.parse_flags_with_absl()
@@ -1684,6 +1685,61 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
     tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
     self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
 
+
+@jtu.pytest_mark_if_available('multiaccelerator')
+class CustomPartitionerTest(jtu.JaxTestCase):
+
+  def skip_if_custom_partitioning_not_supported(self):
+    if jtu.is_cloud_tpu():
+      raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
+    if xla_bridge.using_pjrt_c_api():
+      raise unittest.SkipTest('custom partitioning not implemented in PJRT C API')
+
+  def test_custom_partitioning(self):
+    self.skip_if_custom_partitioning_not_supported()
+
+    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
+    assert a.addressable_data(0).shape == (4, 2)
+
+    def partition(mesh, arg_shapes, result_shape):
+      def lower_fn(x):
+        return x
+
+      return (
+          mesh,
+          lower_fn,
+          arg_shapes[0].sharding,
+          (arg_shapes[0].sharding,),
+      )
+
+    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      return arg_shapes[0].sharding
+
+    def propagate_user_sharding(mesh, user_shape):
+      return user_shape.sharding
+
+    @custom_partitioning
+    def f(x):
+      return x
+
+    f.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+        propagate_user_sharding=propagate_user_sharding,
+    )
+
+    @jax.jit
+    def fwd(a):
+      c = shard_map(
+          f,
+          mesh,
+          check_rep=False,
+          in_specs=(P('z', ('x', 'y')),),
+          out_specs=P('z', ('x', 'y')))(a)
+      return c
+
+    c = fwd(a)
+    self.assertEqual(c.addressable_data(0).shape, (4, 2))
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())