mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
This commit is contained in:
parent
afdb7370b9
commit
7ba8622719
@ -481,6 +481,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
static_args):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
axis_context = ctx.module_context.axis_context
|
||||
if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
|
||||
set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
|
||||
return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
|
||||
|
||||
if isinstance(axis_context, sharding_impls.ShardingContext):
|
||||
devices = axis_context.device_assignment
|
||||
|
@ -1235,6 +1235,7 @@ jax_test(
|
||||
"notsan",
|
||||
], # Times out under *SAN.
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
"//jax:tree_util",
|
||||
],
|
||||
)
|
||||
|
@ -44,6 +44,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -1684,6 +1685,61 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
|
||||
tol = 1e-2 if jtu.test_device_matches(['tpu']) else None
|
||||
self.assertAllClose(ans, expected, check_dtypes=False, atol=tol, rtol=tol)
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
|
||||
def skip_if_custom_partitioning_not_supported(self):
|
||||
if jtu.is_cloud_tpu():
|
||||
raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
|
||||
if xla_bridge.using_pjrt_c_api():
|
||||
raise unittest.SkipTest('custom partitioning not implemented in PJRT C API')
|
||||
|
||||
def test_custom_partitioning(self):
|
||||
self.skip_if_custom_partitioning_not_supported()
|
||||
|
||||
mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
|
||||
assert a.addressable_data(0).shape == (4, 2)
|
||||
|
||||
def partition(mesh, arg_shapes, result_shape):
|
||||
def lower_fn(x):
|
||||
return x
|
||||
|
||||
return (
|
||||
mesh,
|
||||
lower_fn,
|
||||
arg_shapes[0].sharding,
|
||||
(arg_shapes[0].sharding,),
|
||||
)
|
||||
|
||||
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
|
||||
return arg_shapes[0].sharding
|
||||
|
||||
def propagate_user_sharding(mesh, user_shape):
|
||||
return user_shape.sharding
|
||||
|
||||
@custom_partitioning
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
f.def_partition(
|
||||
infer_sharding_from_operands=infer_sharding_from_operands,
|
||||
partition=partition,
|
||||
propagate_user_sharding=propagate_user_sharding,
|
||||
)
|
||||
|
||||
@jax.jit
|
||||
def fwd(a):
|
||||
c = shard_map(
|
||||
f,
|
||||
mesh,
|
||||
check_rep=False,
|
||||
in_specs=(P('z', ('x', 'y')),),
|
||||
out_specs=P('z', ('x', 'y')))(a)
|
||||
return c
|
||||
|
||||
c = fwd(a)
|
||||
self.assertEqual(c.addressable_data(0).shape, (4, 2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user