mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Thread the mesh context manager through to the place where we recover out_shardings from GSPMDShardings. Previously, if you had a program like this:
```
with mesh:
  out = pjit(lambda: 1)()
```
the sharding of `out` was a `GSPMDSharding`, which is not ideal. This change fixes that and returns a `NamedSharding` instead. This is also required for `Shardy` integration.

PiperOrigin-RevId: 658842350
This commit is contained in:
parent
ac52890e3d
commit
958234a9c1
@ -875,7 +875,7 @@ def _check_lowering(lowering) -> None:
|
||||
# Check that we do not see new compile_args. When we add a compile_args it is
|
||||
# safe to add it to the allowed_compile_args if it does not change the semantics
|
||||
# or the calling convention of the lowered module.
|
||||
allowed_compile_args = [
|
||||
allowed_compile_args = {
|
||||
"backend", "platforms", "mesh", "global_in_avals",
|
||||
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
|
||||
"mut", "spmd_lowering", "auto_spmd_lowering",
|
||||
@ -883,7 +883,7 @@ def _check_lowering(lowering) -> None:
|
||||
"keepalive", "host_callbacks", "pmap_nreps", "committed",
|
||||
"device_assignment", "jaxpr_debug_info", "shape_poly_state",
|
||||
"all_default_mem_kind", "in_layouts", "out_layouts", "all_args_info",
|
||||
"pgle_profiler", "intermediate_shardings"]
|
||||
"pgle_profiler", "intermediate_shardings", "context_mesh"}
|
||||
for compile_arg in lowering.compile_args.keys():
|
||||
if compile_arg not in allowed_compile_args:
|
||||
raise NotImplementedError(f"Unrecognized lowered.compile_args[{compile_arg}]")
|
||||
|
@ -2114,7 +2114,7 @@ def lower_sharding_computation(
|
||||
donated_invars: Sequence[bool],
|
||||
*,
|
||||
keep_unused: bool,
|
||||
devices_from_context: Sequence[xc.Device] | None,
|
||||
context_mesh: mesh_lib.Mesh | None,
|
||||
lowering_platforms: tuple[str, ...] | None,
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None,
|
||||
@ -2157,6 +2157,8 @@ def lower_sharding_computation(
|
||||
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
|
||||
len(out_shardings), len(out_layouts), len(global_out_avals))
|
||||
|
||||
devices_from_context = (None if context_mesh is None or context_mesh.empty
|
||||
else context_mesh._flat_devices_tuple)
|
||||
# Device assignment across all inputs, outputs and shardings inside jaxpr
|
||||
# should be the same.
|
||||
unique_intermediate_shardings = list(util.stable_unique(
|
||||
@ -2253,7 +2255,8 @@ def lower_sharding_computation(
|
||||
all_default_mem_kind=all_default_mem_kind,
|
||||
all_args_info=all_args_info,
|
||||
pgle_profiler=pgle_profiler,
|
||||
intermediate_shardings=[s for s, _ in unique_intermediate_shardings])
|
||||
intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
|
||||
context_mesh=context_mesh)
|
||||
|
||||
|
||||
def _to_logical_sharding(
|
||||
@ -2472,7 +2475,7 @@ def _get_out_sharding_from_orig_sharding(
|
||||
|
||||
def maybe_recover_user_shardings(
|
||||
old_shardings, new_shardings, old_avals, new_avals,
|
||||
intermediate_shardings=None):
|
||||
intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
|
||||
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
|
||||
return new_shardings
|
||||
|
||||
@ -2487,6 +2490,11 @@ def maybe_recover_user_shardings(
|
||||
return _get_out_sharding_from_orig_sharding(
|
||||
new_shardings, new_avals, i, None)
|
||||
|
||||
if context_mesh is not None and not context_mesh.empty:
|
||||
return [sharding_impls._gspmd_to_named_sharding_via_mesh(n, context_mesh)
|
||||
if isinstance(n, GSPMDSharding) else n
|
||||
for n in new_shardings]
|
||||
|
||||
return new_shardings
|
||||
|
||||
def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
|
||||
@ -2775,6 +2783,7 @@ class UnloadedMeshExecutable:
|
||||
compiler_options=None,
|
||||
pgle_profiler: profiler.PGLEProfiler | None = None,
|
||||
intermediate_shardings: Sequence[JSharding] | None = None,
|
||||
context_mesh: mesh_lib.Mesh | None = None
|
||||
) -> MeshExecutable:
|
||||
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
|
||||
hlo = mlir.refine_polymorphic_shapes(hlo)
|
||||
@ -2832,7 +2841,7 @@ class UnloadedMeshExecutable:
|
||||
|
||||
out_shardings = maybe_recover_user_shardings(
|
||||
in_shardings, out_shardings, global_in_avals, global_out_avals,
|
||||
intermediate_shardings)
|
||||
intermediate_shardings, context_mesh)
|
||||
|
||||
out_shardings = finalize_out_shardings(out_shardings, da)
|
||||
|
||||
|
@ -1766,9 +1766,7 @@ def _pjit_lower_cached(
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, api_name, name, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, tuple(donated_invars),
|
||||
keep_unused=keep_unused,
|
||||
devices_from_context=(
|
||||
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
|
||||
keep_unused=keep_unused, context_mesh=mesh,
|
||||
lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=lowering_parameters,
|
||||
pgle_profiler=pgle_profiler)
|
||||
|
@ -2516,6 +2516,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
with mesh:
|
||||
out = pjit(lambda: 1)()
|
||||
self.assertEqual(out.sharding, NamedSharding(mesh, P()))
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
self.assertEqual(out, 1)
|
||||
|
||||
@ -2873,6 +2874,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
f = pjit(mul, device=jax.devices()[1])
|
||||
|
||||
x = jnp.arange(8).reshape(4, 2)
|
||||
f_out = f(x)
|
||||
f_out2 = f(f_out)
|
||||
|
Loading…
x
Reference in New Issue
Block a user